中文翻译暂不可用,显示俄语原文。

Как работает distributed optimizer в PyTorch (torch.distributed.optim)?

Краткий тезис

Distributed optimizer в PyTorch, в первую очередь ZeroRedundancyOptimizer (ZeRO-1), решает проблему переполнения памяти при обучении больших моделей (30B+ параметров) за счёт шардирования состояния оптимизатора между GPU. Каждый GPU хранит и обновляет только свою часть параметров, а после обновления синхронизирует полные веса через AllGather. Это позволяет сократить потребление памяти на GPU в число раз, равное размеру группы, ценой небольшого увеличения коммуникации.


1. Термин: Distributed optimizer (распределённый оптимизатор)

Distributed optimizer — это компонент, который позволяет выполнять шаги оптимизации (например, SGD, Adam) в распределённой среде, где модель разбита или реплицирована на несколько устройств (GPU). В отличие от обычного оптимизатора, который хранит полное состояние (веса, моменты, дисперсии) на каждом устройстве, distributed optimizer распределяет это состояние между устройствами, экономя память.

Зачем нужен При обучении больших моделей (например, GPT-3 с 175B параметров) состояние оптимизатора]] может занимать в 2–3 раза больше памяти, чем сами веса модели. Например, для Adam требуется хранить два момента (m и v) для каждого параметра, что при 30B параметров даёт ~360 ГБ только для состояния оптимизатора (при float16). Распределение этого состояния — критическая оптимизация.


2. Проблема: Memory overhead состояния оптимизатора

Рассмотрим модель с N параметров, обучаемую с оптимизатором Adam в float16:

КомпонентРазмер на один параметрРазмер на 30B параметров
Веса модели (fp16)2 байта60 ГБ
Градиенты (fp16)2 байта60 ГБ
Момент m (fp32)4 байта120 ГБ
Момент v (fp32)4 байта120 ГБ
Итого12 байт360 ГБ

Даже на одном GPU с 80 ГБ памяти это невозможно. Distributed optimizer решает эту проблему, шардируя состояние оптимизатора]].


3. ZeRO-1: Шардирование состояния оптимизатора

ZeRO (Zero Redundancy Optimizer) — это техника из библиотеки DeepSpeed, адаптированная в PyTorch как ZeroRedundancyOptimizer. Уровень ZeRO-1 шардирует только состояние оптимизатора]] (моменты, дисперсии), оставляя веса и градиенты реплицированными на каждом GPU.

Как работает

  1. Параметры модели делятся на K равных частей, где K — размер группы GPU.
  2. Каждый GPU отвечает за свою часть параметров: хранит для них состояние оптимизатора]] (m, v) и выполняет шаг оптимизации.
  3. После обновления весов своей части, GPU синхронизирует полные веса со всеми остальными через коллективную операцию AllGather.

Преимущество Потребление памяти на состояние оптимизатора]] уменьшается в K раз. Для 8 GPU — в 8 раз (с 240 ГБ до 30 ГБ на GPU).


4. Как работает ZeroRedundancyOptimizer в PyTorch

ZeroRedundancyOptimizer — это класс из torch.distributed.optim, который реализует ZeRO-1. Он принимает базовый оптимизатор (например, Adam) и распределяет его состояние.

Базовый код

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import Adam

# Инициализация группы процессов
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

# Модель
model = MyLargeModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(model)

# ZeroRedundancyOptimizer оборачивает Adam
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=Adam,
    lr=1e-4,
    parameters_as_bucket_view=True,  # оптимизация памяти
)

# Цикл обучения
for data, target in dataloader:
    optimizer.zero_grad()
    loss = model(data)
    loss.backward()
    optimizer.step()  # внутри: AllGather весов после обновления

Важные параметры

  • optimizer_class: базовый оптимизатор (Adam, SGD и т.д.).
  • parameters_as_bucket_view: объединяет параметры в buckets для эффективной коммуникации.
  • group: группа процессов (по умолчанию dist.group.WORLD).

5. Коммуникация: AllGather после update

После того как каждый GPU обновил свою часть весов, необходимо, чтобы все GPU имели полную копию модели для следующего forward pass. Для этого используется AllGather:

  1. Каждый GPU отправляет свои обновлённые веса (свою часть) всем остальным.
  2. Каждый GPU получает от всех остальных их части и собирает полный тензор весов.

Сложность коммуникации O(K) данных на GPU (каждый отправляет 1/K часть, получает (K-1)/K частей). Для больших K это может стать узким местом, но на практике (до 64 GPU) это приемлемо.

Сравнение с DDP

  • DDP (DistributedDataParallel): каждый GPU хранит полную копию модели и состояния оптимизатора. После backward выполняется AllReduce градиентов. Память: полное состояние на каждом GPU.
  • ZeroRedundancyOptimizer: после backward каждый GPU обновляет только свою часть весов, затем AllGather для синхронизации полных весов. Память: состояние оптимизатора шардировано.
АспектDDPZeroRedundancyOptimizer (ZeRO-1)
Память на состояние optimizerПолная (K * размер)1/K размера
Коммуникация на шагAllReduce градиентов (1 операция)AllGather весов (1 операция)
СкоростьБыстрее (меньше данных)Медленнее на 5–10% из-за AllGather
Когда использоватьМодели < 10B параметровМодели > 30B параметров

6. Когда использовать distributed optimizer

Основной сценарий модели с числом параметров 30B+, где состояние оптимизатора превышает 100 ГБ на GPU. Например:

  • GPT-3 (175B) — состояние Adam ~1.4 ТБ.
  • LLaMA-65B — состояние Adam ~520 ГБ.
  • Модели для научных расчётов (например, в биоинформатике).

Когда НЕ использовать:

  • Модели меньше 10B параметров — оверхед коммуникации не оправдан.
  • Если доступно много GPU с большой памятью (например, 8 × A100 80 ГБ) — можно обойтись DDP.
  • Если важна скорость обучения больше, чем экономия памяти.

7. Ограничения и альтернативы (ZeRO-2, ZeRO-3)

ZeroRedundancyOptimizer (ZeRO-1) — только первый уровень ZeRO. Существуют более продвинутые уровни:

УровеньЧто шардируетсяЭкономия памятиКоммуникация
ZeRO-1Состояние optimizer~4x (для Adam)AllGather весов
ZeRO-2Состояние optimizer + градиенты~6xAllReduce градиентов + AllGather
ZeRO-3Состояние optimizer + градиенты + веса~8xAllGather весов и градиентов

ZeRO-2 и ZeRO-3 доступны через библиотеку DeepSpeed, которая интегрируется с PyTorch. ZeroRedundancyOptimizer — это встроенная альтернатива для ZeRO-1 без внешних зависимостей.

Ограничения ZeroRedundancyOptimizer

  • Не поддерживает все оптимизаторы (только те, что совместимы с torch.optim).
  • Требует, чтобы модель была обёрнута в DistributedDataParallel.
  • Не оптимизирует коммуникацию для градиентов (как ZeRO-2).

8. Пример кода с замером памяти

Продемонстрируем разницу в памяти между DDP и ZeroRedundancyOptimizer:

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import Adam
import os

def train_with_optimizer(use_zero):
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    
    model = torch.nn.Linear(10000, 10000).cuda()  # ~400M параметров
    if not use_zero:
        model = torch.nn.parallel.DistributedDataParallel(model)
    
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=Adam,
            lr=1e-4,
        )
    else:
        optimizer = Adam(model.parameters(), lr=1e-4)
    
    # Замер памяти до обучения
    mem_before = torch.cuda.memory_allocated()
    
    # Один шаг
    x = torch.randn(32, 10000).cuda()
    y = model(x).sum()
    y.backward()
    optimizer.step()
    
    mem_after = torch.cuda.memory_allocated()
    if rank == 0:
        print(f"Zero={use_zero}: memory delta = {(mem_after - mem_before) / 1e9:.2f} GB")
    
    dist.destroy_process_group()

Ожидаемый вывод (для 2 GPU):

Zero=False: memory delta = 1.52 GB
Zero=True:  memory delta = 0.76 GB

9. Пет-проект для закрепления

Задача Реализовать distributed training для модели BERT-large (340M параметров) с использованием ZeroRedundancyOptimizer и сравнить потребление памяти с DDP.

Инструменты

Шаги:

  1. Создайте скрипт train_bert_distributed.py с поддержкой аргумента --use-zero.
  2. Загрузите модель bert-large-uncased (340M параметров).
  3. Обучите на синтетических данных (batch_size=8, sequence_length=128) на 2 GPU.
  4. Замерьте torch.cuda.max_memory_allocated() до и после обучения.
  5. Постройте таблицу сравнения:
МетодПамять на GPU (max)Время на шаг (ms)
DDP4.2 ГБ120
ZeRO-12.8 ГБ135

Ожидаемый результат Вы увидите снижение памяти на ~30% и небольшое замедление (5–10%). Это подтвердит теорию.


10. Связь с другими вопросами

ВопросТема
477Как работает DistributedDataParallel (DDP)
479Что такое ZeRO-2 и ZeRO-3 в DeepSpeed
480Как профилировать память при distributed training
481Сравнение FSDP и ZeroRedundancyOptimizer
482Как выбрать размер группы для шардирования
483Как работает pipeline parallelism

11. Навигация


Навигация