Как работает FSDP (Fully Sharded Data Parallel) в PyTorch?
Краткий тезис
FSDP (Fully Sharded Data Parallel) — это встроенная в PyTorch техника распределённого обучения, которая шардирует (разделяет) параметры модели, градиенты и состояния оптимизатора между GPU, экономя до 3–4× памяти по сравнению с обычным Data Parallel (DDP). FSDP реализует идеи ZeRO-3 (Zero Redundancy Optimizer) и автоматически управляет сборкой и разборкой шардов во время forward/backward проходов, что позволяет обучать модели с десятками миллиардов параметров на ограниченном числе GPU.
1. Термины и контекст
- FSDP (Fully Sharded Data Parallel) — метод, при котором каждый GPU хранит только часть (шард) параметров модели, градиентов и состояния оптимизатора. Полные веса собираются только для текущего слоя во время вычислений.
- DDP (Data Parallel|Distributed Data Parallel) — классический подход, где каждый GPU хранит полную копию модели, но градиенты усредняются через AllReduce. Память растёт линейно с размером модели.
- ZeRO (Zero Redundancy Optimizer) — семейство оптимизаций от Microsoft DeepSpeed, устраняющих избыточность памяти. ZeRO-3 шардирует всё: параметры, градиенты, состояния оптимизатора. FSDP — нативная реализация ZeRO-3 в PyTorch.
- Шардирование (sharding) — разбиение тензора на равные части по числу GPU. Каждый GPU владеет уникальным фрагментом.
- AllGather — коллективная операция, при которой каждый GPU отправляет свой шард всем остальным, и каждый получает полный тензор.
- ReduceScatter — комбинация Reduce (суммирование) и Scatter (рассылка частей). Используется для усреднения градиентов без сборки полного тензора.
2. Проблема памяти при обучении больших моделей
При обучении LLM память GPU расходуется на:
- Параметры модели (weights) — ~4 байта на параметр (FP32) или 2 байта (BF16/FP16).
- Градиенты — столько же, сколько параметры.
- Состояние оптимизатора (например, Adam хранит m и v) — ещё 8 байт на параметр (в FP32).
- Активации (промежуточные значения) — могут быть огромными.
Для модели 7B параметров в FP32:
- Параметры: 7B × 4B = 28 GB
- Градиенты: 28 GB
- Состояние Adam: 7B × 8B = 56 GB
- Итого только для обучения: 112 GB на GPU. Одна A100 (80 GB) не поместит.
FSDP решает эту проблему, распределяя эти компоненты между GPU.
3. Как работает FSDP: общая идея
FSDP делит модель на шарды (обычно по слоям). Каждый GPU хранит только свой шард параметров, градиентов и состояния оптимизатора. Во время forward/backward для текущего слоя выполняется AllGather — все GPU собирают полный слой, вычисляют, затем сбрасывают чужие шарды. После backward градиенты усредняются через ReduceScatter, и каждый GPU обновляет только свой шард параметров.
Ключевой принцип: коммуникация O(1) на слой — один AllGather и один ReduceScatter на каждый слой, независимо от числа GPU.
4. Детальный механизм
4.1 Forward pass
- Модель разбита на FSDP-юниты (обычно один юнит = один слой или блок трансформера).
- Перед вычислением юнита все GPU выполняют AllGather — каждый отправляет свой шард весов, и каждый получает полный тензор весов для этого юнита.
- Выполняется forward этого юнита (на полных весах).
- После forward чужие шарды весов удаляются (освобождается память). Остаётся только локальный шард.
4.2 Backward pass
- Для вычисления градиентов снова нужны полные веса юнита. Выполняется повторный AllGather весов (или они могут быть закешированы, но обычно нет).
- Вычисляются градиенты по полным весам.
- Затем выполняется ReduceScatter: градиенты усредняются по всем GPU, и каждый GPU получает только свой шард усреднённых градиентов.
- Локальный шард градиентов сохраняется для обновления.
4.3 Обновление оптимизатора
Каждый GPU обновляет только свой шард параметров, используя свой шард градиентов и состояния оптимизатора. Состояние оптимизатора (m, v для Adam) хранится только для локального шарда.
4.4 Коммуникационная сложность
- На каждый юнит: один AllGather (forward) + один AllGather (backward) + один ReduceScatter (backward) = 3 коллективные операции.
- Общий объём переданных данных: 3 × (размер юнита) × (число GPU) — но каждый GPU отправляет/получает только свой шард, поэтому пропускная способность используется эффективно.
5. Сравнение с DDP
| Характеристика | DDP | FSDP |
|---|---|---|
| Память на GPU | Полная копия модели + градиенты + оптимизатор | Только шард (≈ 1/N от полного) |
| Коммуникация | Один AllReduce на весь backward (весь градиент) | AllGather + ReduceScatter на каждый юнит |
| Скорость | Быстрее при малых моделях (меньше операций) | Медленнее из-за частых коллективных операций, но масштабируется |
| Максимальный размер модели | Ограничен памятью одного GPU | Может быть в N раз больше |
| Настройка | Минимальная | Требует выбора политики шардирования |
6. Сравнение с DeepSpeed ZeRO-3
| Аспект | DeepSpeed ZeRO-3 | FSDP (PyTorch) |
|---|---|---|
| Происхождение | Сторонняя библиотека | Встроен в PyTorch (с 1.11) |
| Гибкость | Множество оптимизаций (offload, CPU Adam, mixed precision) | Меньше опций, но покрывает основные сценарии |
| Интеграция | Требует обёртки deepspeed.initialize | Использует FullyShardedDataParallel как модуль |
| Производительность | Часто быстрее за счёт тонких оптимизаций | Немного медленнее, но проще в использовании |
| Поддержка | Активная, но отдельная | Нативная, обновляется вместе с PyTorch |
7. Конфигурация FSDP в PyTorch (пример кода)
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
# Инициализация процесса (обычно через torchrun)
torch.distributed.init_process_group(backend="nccl")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Автоматическая обёртка: каждый слой трансформера становится FSDP-юнитом
auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={LlamaDecoderLayer} # класс слоя вашей модели
)
fsdp_model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
# Обучение как обычно
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-5)
for batch in dataloader:
outputs = fsdp_model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
Ключевые параметры
sharding_strategy:FULL_SHARD(ZeRO-3),SHARD_GRAD_OP(ZeRO-2),NO_SHARD(DDP).- auto_wrap_policy: определяет, какие модули оборачивать в отдельные FSDP-юниты.
- mixed_precision: можно включить torch.bfloat16 для экономии памяти.
8. Преимущества и недостатки
Преимущества
- Резкое снижение потребления памяти (до 3–4×).
- Возможность обучать модели, которые не влезают в один GPU.
- Нативная интеграция с PyTorch — не требует дополнительных зависимостей.
- Поддержка CPU offload (параметры могут выгружаться на CPU).
Недостатки
- Дополнительные накладные расходы на коммуникацию (AllGather/ReduceScatter на каждый слой).
- Сложность отладки (распределённые ошибки).
- Меньше оптимизаций по сравнению с DeepSpeed (например, нет ZeRO-Infinity для offload на диск).
- Требует тщательного выбора размера FSDP-юнитов: слишком мелкие → много коммуникации, слишком крупные → мало экономии памяти.
9. Когда использовать FSDP
- Модель не влезает в один GPU (например, 13B+ параметров).
- Несколько GPU доступны (от 2 до сотен).
- Простота развёртывания важнее максимальной производительности (FSDP проще, чем DeepSpeed).
- Эксперименты и прототипирование — быстро включить, не меняя код.
Не рекомендуется:
- Для маленьких моделей (DDP быстрее и проще).
- Если требуется продвинутый offload на CPU/диск (лучше DeepSpeed).
- При очень медленной сети (InfiniBand предпочтителен).
Пет-проект для закрепления
Задача Обучить небольшую модель (например, GPT-2 124M) на двух GPU с помощью FSDP и сравнить потребление памяти с DDP.
Инструменты PyTorch 2.x, torchrun, nvidia-smi, torch.cuda.max_memory_allocated.
Шаги:
- Напишите скрипт, который загружает модель GPT-2 (или любую другую) и оборачивает её в
FSDPсFULL_SHARD. - Запустите обучение на одном GPU (без FSDP) и замерьте пиковое потребление памяти.
- Запустите на двух GPU с FSDP, замерьте память на каждом GPU.
- Повторите с
ShardingStrategy.NO_SHARD(эквивалент DDP). - Постройте таблицу: модель, число GPU, стратегия, память на GPU, скорость (итераций/сек).
Ожидаемый результат Вы увидите, что FSDP на двух GPU использует примерно половину памяти на каждом по сравнению с DDP, но может быть немного медленнее.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 470 | Как работает DDP (Distributed Data Parallel) в PyTorch? |
| 472 | Как работает DeepSpeed ZeRO? |
| 473 | В чём разница между ZeRO-1, ZeRO-2 и ZeRO-3? |
| 474 | Что такое модель параллелизм (model parallelism)? |
| 475 | Как работает pipeline parallelism? |
| 476 | Что такое tensor parallelism и когда его использовать? |
Навигация
- Предыдущий: 470
- Следующий: 472
- Индекс: 00. Индекс разборов