中文翻译暂不可用,显示俄语原文。

Почему gradient accumulation эквивалентен большому batch с точки зрения оптимизации?

Краткий тезис

Gradient accumulation — это техника, при которой градиенты вычисляются на нескольких маленьких микро-батчах и суммируются (или усредняются) перед одним шагом оптимизатора. Математически итоговый градиент совпадает с градиентом, полученным на одном большом батче размера micro_batch * accumulation_steps, при условии, что функция потерь является средним по элементам батча. Однако на практике из-за variance градиентов и эффектов batch normalization полной эквивалентности нет, особенно при большом числе шагов накопления.


1. Термин: Gradient Accumulation

Gradient accumulation — это метод, позволяющий симулировать большой batch size при ограниченной памяти GPU. Вместо того чтобы загружать сразу большой батч и вычислять градиент за один проход, мы:

  1. Делим большой батч на N микро-батчей.
  2. Для каждого микро-батча делаем forward + backward, но не обновляем веса.
  3. Накопливаем (суммируем) градиенты в буфере.
  4. После обработки всех микро-батчей делаем один шаг оптимизатора, используя накопленный градиент.

Термин микро-батч (micro-batch) — это минимальный подбатч, который помещается в память GPU. Accumulation steps — количество микро-батчей, градиенты которых суммируются перед обновлением.


2. Математическая эквивалентность

Пусть у нас есть функция потерь L(x, y; θ), где θ — параметры модели. Для батча размера B полная потеря обычно определяется как среднее по элементам:

L_total(θ) = (1/B) Σ_{i=1}^{B} L(x_i, y_i; θ)

Градиент по параметрам:

∇L_total = (1/B) Σ ∇L_i

Теперь разобьём батч на N микро-батчей размера b каждый, так что B = N * b. Для каждого микро-батча j:

∇L_micro_j = (1/b) Σ_{i in micro_j} ∇L_i

Если мы просто просуммируем градиенты микро-батчей:

Σ ∇L_micro_j = Σ (1/b) Σ ∇L_i = (1/b) Σ_{i=1}^{B} ∇L_i = (B/b) * (1/B) Σ ∇L_i = N * ∇L_total

Таким образом, сумма градиентов микро-батчей в N раз больше истинного градиента для полного батча. Чтобы получить эквивалентный градиент, нужно усреднить сумму (разделить на N). Обычно в фреймворках (PyTorch, TensorFlow) градиенты по умолчанию суммируются, поэтому при gradient accumulation мы либо делим на N вручную, либо используем loss / N при backward.

Вывод Если после каждого микро-батча делать loss.backward() (градиенты суммируются), а затем после N шагов вызвать optimizer.step() и поделить накопленный градиент на N, то полученное обновление будет в точности равно обновлению от одного батча размера B = N * b.


3. Почему это работает? Линейность градиента

Ключевое свойство — линейность оператора градиента. Градиент суммы функций равен сумме градиентов. Поскольку функция потерь для батча является средним (линейной комбинацией) индивидуальных потерь, градиент по батчу — это среднее градиентов по элементам. Gradient accumulation использует эту линейность: суммирование градиентов микро-батчей эквивалентно суммированию градиентов всех элементов, а затем делению на количество микро-батчей даёт среднее.

Это справедливо для любой дифференцируемой функции потерь, если она вычисляется как среднее по батчу. Если же используется сумма (например, loss = Σ L_i), то эквивалентность достигается без деления на N.


4. Отличия от реального большого batch

Несмотря на математическую эквивалентность, на практике есть важные различия:

АспектРеальный большой batchGradient accumulation
Variance градиентовГрадиент вычисляется по всем элементам сразу, variance нижеКаждый микро-батч даёт шумный градиент; накопление уменьшает variance, но не полностью идентично из-за порядка обработки
Batch NormalizationСтатистики (mean, var) вычисляются по всему батчуСтатистики вычисляются отдельно для каждого микро-батча; итоговые статистики не эквивалентны полному батчу
DropoutМаска dropout применяется один раз на весь батчМаска применяется отдельно для каждого микро-батча; эквивалентность только если dropout одинаков для всех микро-батчей (обычно так и есть)
LR schedulingШаги оптимизатора соответствуют реальным батчамКоличество шагов в N раз меньше; learning rate может потребовать корректировки (например, linear scaling rule)
ПамятьТребуется много памяти для хранения активаций всего батчаПамять линейно зависит от размера микро-батча; позволяет обучать модели с非常大的 batch size

Variance градиентов — разброс оценок градиента относительно истинного. При gradient accumulation градиенты микро-батчей могут быть более шумными, особенно если микро-батчи маленькие. Хотя среднее по N микро-батчам имеет ту же variance, что и градиент полного батча (при условии независимости), на практике из-за корреляции данных в батче (например, shuffled order) variance может немного отличаться. Однако для большинства задач различие незначительно.

Batch Normalization — серьёзное отличие. Если модель содержит BatchNorm (или LayerNorm с батч-зависимостью), то статистики нормализации вычисляются по каждому микро-батчу отдельно. При gradient accumulation эти статистики не усредняются между микро-батчами, что приводит к другому распределению активаций. Решение: использовать SyncBatchNorm (синхронизация между микро-батчами) или переключиться на GroupNorm / LayerNorm, которые не зависят от размера батча.


5. Практические соображения

  • Память Gradient accumulation позволяет использовать эффективный batch size, превышающий физическую память GPU. Например, с микро-батчем 4 и accumulation steps 8 получаем эффективный batch 32.
  • Скорость Каждый микро-батч требует отдельного forward/backward, что увеличивает общее время обучения (особенно overhead от запуска ядер). Однако это часто приемлемо, если память — узкое место.
  • Distributed training Gradient accumulation часто комбинируется с data parallelism. В распределённой среде каждый GPU обрабатывает свой микро-батч, затем градиенты all-reduce суммируются. Accumulation steps могут быть распределены между GPU или выполняться локально.
  • Learning rate: При увеличении эффективного batch size часто применяют linear scaling rule: увеличивать learning rate пропорционально batch size. При gradient accumulation эффективный batch size = micro_batch * accumulation_steps * num_gpus. Если вы меняете accumulation steps, нужно скорректировать LR.

6. Когда использовать gradient accumulation?

  • Ограниченная память GPU Когда желаемый batch size не помещается в VRAM.
  • Симуляция большого batch Для стабильности обучения (меньше variance градиентов) или для совместимости с предобученными моделями, которые обучались с определённым batch size.
  • Fine-tuning больших моделей LLM (например, LLaMA, GPT) часто fine-tune с batch size 1-4 на GPU, используя gradient accumulation до 64+.
  • Обучение агентов В контексте Agentic RAG может потребоваться fine-tuning retriever или генератора на данных с длинными контекстами; gradient accumulation помогает уместить обучение на одной GPU.

Не рекомендуется использовать gradient accumulation, если:

  • Модель содержит BatchNorm и нет возможности заменить его.
  • Требуется минимальное время обучения (overhead от множества микро-батчей может быть значительным).
  • Размер микро-батча слишком мал (например, 1), что приводит к высокому variance и нестабильности.

7. Пример кода на PyTorch

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4
micro_batch_size = 2
effective_batch_size = micro_batch_size * accumulation_steps

# Предположим, у нас 8 образцов
data = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

optimizer.zero_grad()
for i in range(0, len(data), micro_batch_size):
    micro_batch = data[i:i+micro_batch_size]
    micro_targets = targets[i:i+micro_batch_size]
    
    outputs = model(micro_batch)
    loss = nn.functional.cross_entropy(outputs, micro_targets)
    loss = loss / accumulation_steps  # усредняем, чтобы суммарный градиент был средним
    loss.backward()
    
    # После каждого accumulation_steps делаем шаг
    if (i // micro_batch_size + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Важно Деление loss на accumulation_steps необходимо, чтобы после суммирования градиентов получить среднее по эффективному батчу. Если не делить, то градиент будет в accumulation_steps раз больше, и learning rate нужно соответственно уменьшить.


8. Связь с другими техниками

  • Gradient checkpointing — уменьшает память за счёт пересчёта активаций во время backward. Комбинируется с gradient accumulation для ещё большего снижения потребления памяти.
  • Mixed precision training (FP16/AMP) — ускоряет вычисления и уменьшает память. При gradient accumulation важно использовать scaler из torch.cuda.amp для корректного масштабирования градиентов.
  • Distributed Data Parallel (DDP) — градиенты синхронизируются между GPU. Gradient accumulation может быть реализован как локально (каждый GPU накапливает свои микро-батчи), так и глобально (accumulation steps распределены между GPU).

Пет-проект для закрепления

Задача Обучить небольшую модель (например, ResNet-18 на CIFAR-10) с gradient accumulation и сравнить с обучением на реальном большом batch.

Инструменты PyTorch, torchvision, matplotlib.

Шаги:

  1. Загрузите CIFAR-10, создайте DataLoader с batch size = 64.
  2. Обучите модель с batch size = 64 (реальный большой batch) — baseline.
  3. Обучите ту же модель с micro_batch_size = 16 и accumulation_steps = 4 (эффективный batch 64). Замерьте время и точность.
  4. Повторите для micro_batch_size = 8, accumulation_steps = 8.
  5. Постройте графики loss и accuracy, сравните сходимость.
  6. Добавьте BatchNorm и повторите эксперимент — заметьте разницу.

Ожидаемый результат При одинаковом эффективном batch size кривые обучения будут близки, но при малом micro_batch_size (например, 2) может наблюдаться больший шум. BatchNorm вызовет расхождение, если не использовать SyncBatchNorm.


Связь с другими вопросами

ВопросТема
470Как работает gradient checkpointing?
471Сравнение методов уменьшения памяти при обучении
473Влияние batch size на обобщающую способность
474Linear scaling rule для learning rate
475Особенности distributed training с gradient accumulation
480Оптимизация fine-tuning LLM для RAG

Навигация