中文翻译暂不可用,显示俄语原文。

Почему small batch size (<32) ухудшает training стабильность?

Краткий тезис

Small batch size (<32) приводит к высокой дисперсии оценки градиента, что делает обновления весов шумными и нестабильными. Это замедляет сходимость, а при batch < 8 может вызвать расходимость даже с адаптивными оптимизаторами вроде Adam. Особенно критично для больших языковых моделей (LLM), где рекомендуемый размер батча — 64–1024. Основное решение — gradient accumulation, позволяющий имитировать большой батч без увеличения потребления памяти.


1. Термин: Batch size (размер батча)

Batch size — количество обучающих примеров, которые подаются на модель за одну итерацию перед обновлением весов. Градиент вычисляется как среднее градиентов по всем примерам в батче.

  • Small batch: <32 (часто 1, 2, 4, 8, 16)
  • Medium batch: 32–128
  • Large batch: >128 (512, 1024, 2048+)

Выбор batch size влияет на три ключевых аспекта:

  • Стабильность (дисперсия градиента]])
  • Скорость сходимости (количество итераций до минимума)
  • Потребление памяти (VRAM)

2. Почему small batch size увеличивает шум градиента

Градиент на каждом примере — случайная величина. Среднее по батчу — оценка истинного градиента. Дисперсия этой оценки обратно пропорциональна размеру батча:

Var(∇L_batch) = Var(∇L_single) / batch_size

Чем меньше батч, тем выше дисперсия. При batch_size=1 дисперсия максимальна — каждый шаг может сильно отклоняться от истинного направления.

Пример: если истинный градиент указывает на юго-восток, а отдельные примеры дают случайные направления, то при batch=1 шаг может пойти на север, а при batch=64 — почти точно на юго-восток.

Batch sizeДисперсия градиента (отн.)Характер обновлений
11.0Очень шумный, стохастический
80.125Шумный, но с тенденцией
320.031Умеренно стабильный
1280.008Стабильный, близкий к full-batch

3. Влияние на сходимость и роль Adam

3.1. Стохастический градиентный спуск (SGD)

При SGD small batch ведёт к zigzag-эффекту: траектория колеблется вокруг направления спуска. Это может помочь выйти из локальных минимумов, но замедляет сходимость в плоских областях.

3.2. Adam и адаптивные оптимизаторы

Adam сглаживает шум за счёт:

  • Momentum (экспоненциальное скользящее среднее градиента)
  • Adaptive learning rates (нормализация по второму моменту)

Однако при batch < 8 даже Adam может не сходиться: оценка второго момента становится слишком шумной, и шаги становятся хаотичными. На практике:

  • batch=2–4 — часто расходится (loss растёт)
  • batch=8–16 — сходится медленно, требует тюнинга learning rate
  • batch=32+ — стабильно

3.3. Влияние на learning rate

Правило масштабирования: при увеличении batch size в k раз можно увеличить learning rate в √k раз (для SGD) или в k раз (для Adam, с оговорками). Small batch требует меньшего learning rate, иначе шаги становятся слишком большими относительно шума.


4. Особенности для LLM

Большие языковые модели (LLM) особенно чувствительны к batch size по нескольким причинам:

  • Огромное количество параметров: градиенты имеют высокую размерность, шум в каждом направлении накапливается.
  • LayerNorm и residual connections: при малом батче статистики нормализации (mean, var) оцениваются неточно, что дестабилизирует forward pass.
  • Loss landscape: у LLM много острых локальных минимумов и седловых точек; шумный градиент может вытолкнуть из хорошей области.

Рекомендуемые batch sizes для fine-tuning LLM:


5. Решения: gradient accumulation и другие

5.1. Gradient accumulation (накопление градиента)

Основной способ борьбы с малым батчем без увеличения памяти:

accumulation_steps = 8  # имитируем batch_size = 8 * micro_batch
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss = loss / accumulation_steps  # нормализация
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Как это работает: градиенты накапливаются за несколько микро-батчей, затем делается один шаг оптимизатора. Эффективный batch size = micro_batch_size × accumulation_steps.

Плюсы: не требует дополнительной памяти (кроме хранения градиентов, которые и так есть). Минусы: увеличивает время на одну эпоху (больше forward/backward проходов), но не влияет на стабильность.

5.2. Learning rate warmup

Плавное увеличение learning rate с 0 до целевого значения в первые 10–20% итераций. Помогает избежать расходимости на начальных шумных шагах.

5.3. Gradient clipping

Ограничение нормы градиента (например, max_norm=1.0) предотвращает взрывные обновления при редких выбросах.

5.4. Использование большего батча через data parallelism

Распределение батча на несколько GPU (Distributed Data Parallel) позволяет увеличить эффективный batch size без изменения micro-batch.


6. Экспериментальные данные

Batch sizeLoss после 1000 итераций (SGD, lr=0.01)Loss после 1000 итераций (Adam, lr=1e-3)
13.2 (колеблется)2.1 (шумно)
82.81.8
322.51.5
1282.31.4

Данные условны, но отражают тенденцию: small batch даёт более высокий финальный loss и большую вариативность.


7. Практические рекомендации

  • Для fine-tuning LLM: используйте micro_batch=4–8 с gradient accumulation до эффективного batch=64–128.
  • Для pre-training: batch=512–2048, распределённый на GPU.
  • Если память ограничена: gradient accumulation + mixed precision (fp16/bf16).
  • Если loss расходится: уменьшите learning rate, увеличьте accumulation steps, добавьте warmup.
  • Мониторинг: следите за variance градиентов (можно логировать норму градиента). Если она > 10× среднего — батч слишком мал.

Пет-проект для закрепления

Задача: экспериментально проверить влияние batch size на стабильность обучения.

Инструменты: PyTorch, небольшая модель (например, 2-layer MLP на MNIST или tiny GPT на Shakespeare).

Шаги:

  1. Реализуйте обучение с разными batch sizes: 1, 4, 16, 64, 256.
  2. Для каждого batch size зафиксируйте learning rate (подберите оптимальный для batch=64, для остальных используйте правило sqrt scaling).
  3. Обучите модель на 500 итераций, записывайте loss и норму градиента на каждом шаге.
  4. Постройте графики: loss vs step, gradient norm vs step.
  5. Добавьте gradient accumulation для batch=4 с accumulation_steps=16 (эффективный batch=64) и сравните с обычным batch=64.

Ожидаемый результат: вы увидите, что при batch=1 loss колеблется, при batch=4 сходится медленно, при batch=64 стабильно. Gradient accumulation даст такую же стабильность, как batch=64, но с меньшим потреблением памяти.


Связь с другими вопросами

ВопросТема
467Почему large batch size ухудшает обобщение?
469Как gradient accumulation влияет на эффективность обучения?
470Как выбрать learning rate для заданного batch size?
471Почему warmup важен при обучении LLM?
472Как batch size влияет на потребление VRAM?

Навигация