English translation is not available yet. Showing Russian content.

Как работает FSDP (Fully Sharded Data Parallel) в PyTorch?

Краткий тезис

FSDP (Fully Sharded Data Parallel) — это встроенная в PyTorch техника распределённого обучения, которая шардирует (разделяет) параметры модели, градиенты и состояния оптимизатора между GPU, экономя до 3–4× памяти по сравнению с обычным Data Parallel (DDP). FSDP реализует идеи ZeRO-3 (Zero Redundancy Optimizer) и автоматически управляет сборкой и разборкой шардов во время forward/backward проходов, что позволяет обучать модели с десятками миллиардов параметров на ограниченном числе GPU.


1. Термины и контекст

  • FSDP (Fully Sharded Data Parallel) — метод, при котором каждый GPU хранит только часть (шард) параметров модели, градиентов и состояния оптимизатора. Полные веса собираются только для текущего слоя во время вычислений.
  • DDP (Data Parallel|Distributed Data Parallel) — классический подход, где каждый GPU хранит полную копию модели, но градиенты усредняются через AllReduce. Память растёт линейно с размером модели.
  • ZeRO (Zero Redundancy Optimizer) — семейство оптимизаций от Microsoft DeepSpeed, устраняющих избыточность памяти. ZeRO-3 шардирует всё: параметры, градиенты, состояния оптимизатора. FSDP — нативная реализация ZeRO-3 в PyTorch.
  • Шардирование (sharding) — разбиение тензора на равные части по числу GPU. Каждый GPU владеет уникальным фрагментом.
  • AllGather — коллективная операция, при которой каждый GPU отправляет свой шард всем остальным, и каждый получает полный тензор.
  • ReduceScatter — комбинация Reduce (суммирование) и Scatter (рассылка частей). Используется для усреднения градиентов без сборки полного тензора.

2. Проблема памяти при обучении больших моделей

При обучении LLM память GPU расходуется на:

  • Параметры модели (weights) — ~4 байта на параметр (FP32) или 2 байта (BF16/FP16).
  • Градиенты — столько же, сколько параметры.
  • Состояние оптимизатора (например, Adam хранит m и v) — ещё 8 байт на параметр (в FP32).
  • Активации (промежуточные значения) — могут быть огромными.

Для модели 7B параметров в FP32:

  • Параметры: 7B × 4B = 28 GB
  • Градиенты: 28 GB
  • Состояние Adam: 7B × 8B = 56 GB
  • Итого только для обучения: 112 GB на GPU. Одна A100 (80 GB) не поместит.

FSDP решает эту проблему, распределяя эти компоненты между GPU.


3. Как работает FSDP: общая идея

FSDP делит модель на шарды (обычно по слоям). Каждый GPU хранит только свой шард параметров, градиентов и состояния оптимизатора. Во время forward/backward для текущего слоя выполняется AllGather — все GPU собирают полный слой, вычисляют, затем сбрасывают чужие шарды. После backward градиенты усредняются через ReduceScatter, и каждый GPU обновляет только свой шард параметров.

Ключевой принцип: коммуникация O(1) на слой — один AllGather и один ReduceScatter на каждый слой, независимо от числа GPU.


4. Детальный механизм

4.1 Forward pass

  1. Модель разбита на FSDP-юниты (обычно один юнит = один слой или блок трансформера).
  2. Перед вычислением юнита все GPU выполняют AllGather — каждый отправляет свой шард весов, и каждый получает полный тензор весов для этого юнита.
  3. Выполняется forward этого юнита (на полных весах).
  4. После forward чужие шарды весов удаляются (освобождается память). Остаётся только локальный шард.

4.2 Backward pass

  1. Для вычисления градиентов снова нужны полные веса юнита. Выполняется повторный AllGather весов (или они могут быть закешированы, но обычно нет).
  2. Вычисляются градиенты по полным весам.
  3. Затем выполняется ReduceScatter: градиенты усредняются по всем GPU, и каждый GPU получает только свой шард усреднённых градиентов.
  4. Локальный шард градиентов сохраняется для обновления.

4.3 Обновление оптимизатора

Каждый GPU обновляет только свой шард параметров, используя свой шард градиентов и состояния оптимизатора. Состояние оптимизатора (m, v для Adam) хранится только для локального шарда.

4.4 Коммуникационная сложность

  • На каждый юнит: один AllGather (forward) + один AllGather (backward) + один ReduceScatter (backward) = 3 коллективные операции.
  • Общий объём переданных данных: 3 × (размер юнита) × (число GPU) — но каждый GPU отправляет/получает только свой шард, поэтому пропускная способность используется эффективно.

5. Сравнение с DDP

ХарактеристикаDDPFSDP
Память на GPUПолная копия модели + градиенты + оптимизаторТолько шард (≈ 1/N от полного)
КоммуникацияОдин AllReduce на весь backward (весь градиент)AllGather + ReduceScatter на каждый юнит
СкоростьБыстрее при малых моделях (меньше операций)Медленнее из-за частых коллективных операций, но масштабируется
Максимальный размер моделиОграничен памятью одного GPUМожет быть в N раз больше
НастройкаМинимальнаяТребует выбора политики шардирования

6. Сравнение с DeepSpeed ZeRO-3

АспектDeepSpeed ZeRO-3FSDP (PyTorch)
ПроисхождениеСторонняя библиотекаВстроен в PyTorch (с 1.11)
ГибкостьМножество оптимизаций (offload, CPU Adam, mixed precision)Меньше опций, но покрывает основные сценарии
ИнтеграцияТребует обёртки deepspeed.initializeИспользует FullyShardedDataParallel как модуль
ПроизводительностьЧасто быстрее за счёт тонких оптимизацийНемного медленнее, но проще в использовании
ПоддержкаАктивная, но отдельнаяНативная, обновляется вместе с PyTorch

7. Конфигурация FSDP в PyTorch (пример кода)

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM

# Инициализация процесса (обычно через torchrun)
torch.distributed.init_process_group(backend="nccl")

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Автоматическая обёртка: каждый слой трансформера становится FSDP-юнитом
auto_wrap_policy = transformer_auto_wrap_policy(
    transformer_layer_cls={LlamaDecoderLayer}  # класс слоя вашей модели
)

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)

# Обучение как обычно
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-5)
for batch in dataloader:
    outputs = fsdp_model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Ключевые параметры

  • sharding_strategy: FULL_SHARD (ZeRO-3), SHARD_GRAD_OP (ZeRO-2), NO_SHARD (DDP).
  • auto_wrap_policy: определяет, какие модули оборачивать в отдельные FSDP-юниты.
  • mixed_precision: можно включить torch.bfloat16 для экономии памяти.

8. Преимущества и недостатки

Преимущества

  • Резкое снижение потребления памяти (до 3–4×).
  • Возможность обучать модели, которые не влезают в один GPU.
  • Нативная интеграция с PyTorch — не требует дополнительных зависимостей.
  • Поддержка CPU offload (параметры могут выгружаться на CPU).

Недостатки

  • Дополнительные накладные расходы на коммуникацию (AllGather/ReduceScatter на каждый слой).
  • Сложность отладки (распределённые ошибки).
  • Меньше оптимизаций по сравнению с DeepSpeed (например, нет ZeRO-Infinity для offload на диск).
  • Требует тщательного выбора размера FSDP-юнитов: слишком мелкие → много коммуникации, слишком крупные → мало экономии памяти.

9. Когда использовать FSDP

  • Модель не влезает в один GPU (например, 13B+ параметров).
  • Несколько GPU доступны (от 2 до сотен).
  • Простота развёртывания важнее максимальной производительности (FSDP проще, чем DeepSpeed).
  • Эксперименты и прототипирование — быстро включить, не меняя код.

Не рекомендуется:

  • Для маленьких моделей (DDP быстрее и проще).
  • Если требуется продвинутый offload на CPU/диск (лучше DeepSpeed).
  • При очень медленной сети (InfiniBand предпочтителен).

Пет-проект для закрепления

Задача Обучить небольшую модель (например, GPT-2 124M) на двух GPU с помощью FSDP и сравнить потребление памяти с DDP.

Инструменты PyTorch 2.x, torchrun, nvidia-smi, torch.cuda.max_memory_allocated.

Шаги:

  1. Напишите скрипт, который загружает модель GPT-2 (или любую другую) и оборачивает её в FSDP с FULL_SHARD.
  2. Запустите обучение на одном GPU (без FSDP) и замерьте пиковое потребление памяти.
  3. Запустите на двух GPU с FSDP, замерьте память на каждом GPU.
  4. Повторите с ShardingStrategy.NO_SHARD (эквивалент DDP).
  5. Постройте таблицу: модель, число GPU, стратегия, память на GPU, скорость (итераций/сек).

Ожидаемый результат Вы увидите, что FSDP на двух GPU использует примерно половину памяти на каждом по сравнению с DDP, но может быть немного медленнее.


Связь с другими вопросами

ВопросТема
470Как работает DDP (Distributed Data Parallel) в PyTorch?
472Как работает DeepSpeed ZeRO?
473В чём разница между ZeRO-1, ZeRO-2 и ZeRO-3?
474Что такое модель параллелизм (model parallelism)?
475Как работает pipeline parallelism?
476Что такое tensor parallelism и когда его использовать?

Навигация