中文翻译暂不可用,显示俄语原文。
Что такое activation recomputation (checkpointing) и зачем оно нужно?
Краткий тезис
Activation recomputation (также gradient checkpointing) — это техника экономии видеопамяти при обучении глубоких нейронных сетей, при которой промежуточные активации не хранятся всё время, а пересчитываются заново во время обратного прохода из сохранённых «контрольных точек» (checkpoints). Это позволяет обучать значительно более крупные модели на том же GPU ценой дополнительных вычислений (обычно +20–30% времени). Техника особенно востребована для больших языковых моделей (LLM) и длинных последовательностей, где объём активаций доминирует над памятью параметров.
1. Проблема памяти при обучении глубоких сетей
При обучении нейронной сети с помощью backpropagation (обратного распространения ошибки) требуется хранить промежуточные активации каждого слоя, вычисленные во время прямого прохода (forward pass). Эти активации необходимы для вычисления градиентов на этапе обратного прохода (backward pass).
- Параметры модели (веса) занимают память один раз.
- Активации — это выходы каждого слоя для каждого элемента батча. Для больших моделей (например, GPT-3 с 175B параметров) активации могут занимать в десятки раз больше памяти, чем сами веса.
Пример:
- Модель с 12 слоями, размер скрытого слоя 4096, длина последовательности 1024, батч 8.
- Активации одного слоя: 8 × 1024 × 4096 × 4 байта (float32) ≈ 128 MB.
- Для 12 слоёв:
12 × 128 MB = 1.5 GBтолько активаций. - Если добавить ещё и буферы для dropout, attention scores]] и т.д., память растёт линейно с числом слоёв O(L).
Термин Активации (activations) — выходные тензоры после применения функции активации или других операций слоя. Они являются узлами вычислительного графа.
2. Что такое activation recomputation (checkpointing)
Activation recomputation (также gradient checkpointing) — это метод, при котором мы не храним все активации, а сохраняем только входы определённых сегментов сети (checkpoints). Во время backward pass, когда требуются активации для вычисления градиентов, мы пересчитываем forward pass внутри сегмента, начиная с сохранённого checkpoint.
Идея
- Разбиваем сеть на сегменты (например, каждые K слоёв).
- Для каждого сегмента сохраняем только его входной тензор (checkpoint).
- При обратном проходе, когда нужно вычислить градиенты для слоёв внутри сегмента, мы заново выполняем forward pass этого сегмента, получая все промежуточные активации, используем их для backward и затем отбрасываем.
Результат
- Память для активаций снижается с O(L) до O(L / K) (если K — размер сегмента) или до O(sqrt(L)) при оптимальном выборе числа сегментов.
- Время обучения увеличивается на 20–30% из-за повторных вычислений.
Термин Checkpoint — сохранённый вход сегмента, с которого начинается пересчёт.
3. Как это работает: граф вычислений и чекпоинты
Рассмотрим простую сеть из 4 слоёв. Обычный forward/backward:
Forward: x → f1 → a1 → f2 → a2 → f3 → a3 → f4 → loss
Backward: loss → ∂f4 → ∂a3 → ∂f3 → ∂a2 → ∂f2 → ∂a1 → ∂f1
- Хранятся все
a1, a2, a3(активации после каждого слоя). - Память: 3 тензора.
С checkpointing (сегмент из 2 слоёв, checkpoint на входе в сегмент):
Forward: x → [f1 → a1 → f2 → a2] → [f3 → a3 → f4 → loss]
checkpoint1 = x checkpoint2 = a2
- Хранятся только
checkpoint1иcheckpoint2(входы сегментов). - При backward для первого сегмента:
- Берём
checkpoint1 = x. - Пересчитываем forward:
x → f1 → a1 → f2 → a2. - Теперь имеем
a1, a2в памяти. - Вычисляем градиенты для
f1, f2. - Отбрасываем
a1, a2.
- Берём
- Аналогично для второго сегмента.
Память 2 тензора (checkpoint) + временно 2 активации внутри сегмента (но они не накапливаются между сегментами). Итого ~2 тензора вместо 3.
Термин Вычислительный граф (computation graph) — направленный ациклический граф, где узлы — операции, а рёбра — тензоры (активации).
4. Варианты checkpointing
4.1 Full checkpointing (каждый слой — сегмент)
- Сохраняем вход каждого слоя.
- Память: O(L) (но каждый checkpoint — это вход слоя, а не активация после него, что может быть меньше по размеру, если вход меньше выхода).
- Пересчёт: каждый слой пересчитывается один раз за backward.
- Применяется редко, так как выигрыш в памяти невелик.
4.2 Selective checkpointing (выборочные сегменты)
- Сохраняем checkpoint только для некоторых «тяжёлых» операций (например, attention в трансформере).
- Позволяет точно контролировать trade-off между памятью и временем.
4.3 Optimal checkpointing (теоретически оптимальный)
- Использует динамическое программирование для выбора сегментов, минимизирующих пиковую память при заданном бюджете вычислений.
- Даёт O(sqrt(L)) памяти.
- Реализован в библиотеках (например, torch.utils.checkpoint с параметром
use_reentrant).
Таблица сравнения
| Стратегия | Память (активации) | Доп. время | Сложность реализации |
|---|---|---|---|
| Без checkpointing | O(L) | 0% | Нет |
| Full (каждый слой) | O(L) (но меньше) | ~100% | Низкая |
| Selective | O(L/K) | 20–50% | Средняя |
| Optimal | O(sqrt(L)) | 20–30% | Высокая (используется встроенная) |
5. Реализация в PyTorch: torch.utils.checkpoint
PyTorch предоставляет удобный интерфейс через модуль torch.utils.checkpoint.
Базовый пример
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(1024, 1024)
self.layer2 = nn.Linear(1024, 1024)
self.layer3 = nn.Linear(1024, 1024)
def forward(self, x):
# Применяем checkpoint к последовательности слоёв
x = checkpoint(self._forward_segment, x)
return x
def _forward_segment(self, x):
x = self.layer1(x)
x = torch.relu(x)
x = self.layer2(x)
x = torch.relu(x)
x = self.layer3(x)
return x
- checkpoint принимает функцию и её аргументы.
- Во время forward он сохраняет только входные тензоры.
- Во время backward он пересчитывает
_forward_segmentи использует полученные активации для вычисления градиентов.
Важно
- Функция, переданная в checkpoint, не должна иметь побочных эффектов (например, изменять внешние состояния).
- Для рекуррентных моделей или операций с сохранением состояния (например, BatchNorm) нужно использовать
use_reentrant=False.
Продвинутое использование
from torch.utils.checkpoint import checkpoint_sequential
# Для последовательных моделей
segments = 4
output = checkpoint_sequential(model, segments, input)
6. Влияние на память и время
Память
- Без checkpointing:
M_params + M_activations * L(где L — число слоёв). - С optimal checkpointing:
M_params + M_activations * sqrt(L).
Пример:
- L = 100, M_activations = 1 GB.
- Без:
M_params + 100 GB. - С checkpointing:
M_params + 10 GB. - Выигрыш в 10 раз.
Время
- Дополнительные вычисления: каждый сегмент пересчитывается один раз за backward.
- Если сегментов K, то forward выполняется K+1 раз (один основной + K пересчётов).
- Обычно overhead составляет 20–30% от общего времени обучения (зависит от размера сегментов и скорости вычислений).
Термин Overhead — дополнительные накладные расходы (время, память).
7. Сравнение с другими техниками экономии памяти
| Техника | Механизм | Экономия памяти | Overhead по времени | Совместимость |
|---|---|---|---|---|
| Gradient checkpointing | Пересчёт активаций | O(L) → O(sqrt(L)) | 20–30% | Любая модель |
| Gradient accumulation | Накопление градиентов за несколько микро-батчей | Позволяет использовать эффективный batch size больше физического | Нет (но увеличивает шаги) | Любая модель |
| Mixed precision (FP16/BF16) | Хранение активаций и весов в половинной точности | ~2x | Небольшой (если поддерживается оборудованием) | Требует совместимости |
| Model parallelism | Распределение слоёв по разным GPU | Память делится между устройствами | Коммуникация между GPU | Сложная реализация |
| Pipeline parallelism | Разбиение батча на микробатчи, конвейерная обработка | Память на один микробатч | Коммуникация и простои | Средняя сложность |
Комбинация На практике часто используют gradient checkpointing вместе с mixed precision и gradient accumulation для обучения самых больших моделей.
8. Когда использовать activation recomputation
- Очень большие модели (например, GPT-3, LLaMA-65B), где активации не помещаются даже в H100 (80 GB).
- Длинные последовательности (например, документы > 8K токенов) — активации растут квадратично из-за attention.
- Ограниченный бюджет GPU — одна карта вместо нескольких.
- Эксперименты с архитектурой — когда нужно быстро проверить гипотезу без покупки дополнительных GPU.
Когда НЕ стоит использовать
- Модель маленькая (активации < 20% памяти) — overhead не оправдан.
- Время обучения критично (например, production fine-tuning с жёстким SLA).
- Используется pipeline parallelism, где активации уже распределены.
9. Практические рекомендации
- Начинайте с
torch.utils.checkpoint— это просто и эффективно. - Выбирайте размер сегмента эмпирически: слишком много сегментов → большой overhead, слишком мало → малый выигрыш в памяти.
- Используйте
checkpoint_sequentialдля моделей типаnn.Sequential. - Комбинируйте с mixed precision — FP16/BF16 уменьшает размер активаций, checkpointing ещё больше.
- Мониторьте использование памяти с помощью
torch.cuda.memory_summary(). - Для трансформеров часто checkpoint применяют только к блокам attention (самые тяжёлые по памяти), а FFN оставляют без checkpoint.
Пример настройки для GPT-подобной модели
class TransformerBlock(nn.Module):
def forward(self, x):
# checkpoint только для attention
attn_out = checkpoint(self.attention, x)
x = x + attn_out
x = x + self.ffn(x)
return x
10. Пет-проект для закрепления
Задача Обучить небольшую модель (например, 6-слойный трансформер) на задаче классификации текста с и без gradient checkpointing, сравнить пиковое потребление памяти и время эпохи.
Инструменты
- PyTorch,
torch.utils.checkpoint torch.cuda.max_memory_allocated()для замера памяти- Датасет: IMDB (или любой другой)
Шаги:
- Реализовать простой трансформер (6 слоёв, 4 головы, d_model=256).
- Написать цикл обучения без checkpointing, замерить память и время.
- Добавить checkpointing для каждого второго слоя (сегмент из 2 слоёв).
- Повторить замеры.
- Построить таблицу:
| Конфигурация | Пиковая память (MB) | Время эпохи (сек) |
|---|---|---|
| Без checkpoint | 1200 | 45 |
| С checkpoint (сегмент 2) | 800 | 58 |
| С checkpoint (сегмент 1) | 600 | 72 |
Ожидаемый результат
- Память снижается на 30–50%, время растёт на 20–30%.
- Понимание trade-off на практике.
11. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 460 | Что такое gradient accumulation и как он помогает при ограниченной памяти? |
| 461 | Как работает mixed precision training (FP16/BF16)? |
| 462 | В чём разница между data parallelism и model parallelism? |
| 464 | Что такое pipeline parallelism и как он сочетается с checkpointing? |
| 465 | Как оптимизировать память при обучении LLM с длинным контекстом? |
| 470 | Какие техники используются для распределённого обучения больших моделей? |
12. Навигация
- Предыдущий: 462
- Следующий: 464
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 462
- Следующий: 464
- Индекс: 00. Индекс разборов