Что такое activation recomputation (checkpointing) и зачем оно нужно?

Краткий тезис

Activation recomputation (также gradient checkpointing) — это техника экономии видеопамяти при обучении глубоких нейронных сетей, при которой промежуточные активации не хранятся всё время, а пересчитываются заново во время обратного прохода из сохранённых «контрольных точек» (checkpoints). Это позволяет обучать значительно более крупные модели на том же GPU ценой дополнительных вычислений (обычно +20–30% времени). Техника особенно востребована для больших языковых моделей (LLM) и длинных последовательностей, где объём активаций доминирует над памятью параметров.


1. Проблема памяти при обучении глубоких сетей

При обучении нейронной сети с помощью backpropagation (обратного распространения ошибки) требуется хранить промежуточные активации каждого слоя, вычисленные во время прямого прохода (forward pass). Эти активации необходимы для вычисления градиентов на этапе обратного прохода (backward pass).

  • Параметры модели (веса) занимают память один раз.
  • Активации — это выходы каждого слоя для каждого элемента батча. Для больших моделей (например, GPT-3 с 175B параметров) активации могут занимать в десятки раз больше памяти, чем сами веса.

Пример:

  • Модель с 12 слоями, размер скрытого слоя 4096, длина последовательности 1024, батч 8.
  • Активации одного слоя: 8 × 1024 × 4096 × 4 байта (float32) ≈ 128 MB.
  • Для 12 слоёв: 12 × 128 MB = 1.5 GB только активаций.
  • Если добавить ещё и буферы для dropout, attention scores]] и т.д., память растёт линейно с числом слоёв O(L).

Термин Активации (activations) — выходные тензоры после применения функции активации или других операций слоя. Они являются узлами вычислительного графа.


2. Что такое activation recomputation (checkpointing)

Activation recomputation (также gradient checkpointing) — это метод, при котором мы не храним все активации, а сохраняем только входы определённых сегментов сети (checkpoints). Во время backward pass, когда требуются активации для вычисления градиентов, мы пересчитываем forward pass внутри сегмента, начиная с сохранённого checkpoint.

Идея

  • Разбиваем сеть на сегменты (например, каждые K слоёв).
  • Для каждого сегмента сохраняем только его входной тензор (checkpoint).
  • При обратном проходе, когда нужно вычислить градиенты для слоёв внутри сегмента, мы заново выполняем forward pass этого сегмента, получая все промежуточные активации, используем их для backward и затем отбрасываем.

Результат

  • Память для активаций снижается с O(L) до O(L / K) (если K — размер сегмента) или до O(sqrt(L)) при оптимальном выборе числа сегментов.
  • Время обучения увеличивается на 20–30% из-за повторных вычислений.

Термин Checkpoint — сохранённый вход сегмента, с которого начинается пересчёт.


3. Как это работает: граф вычислений и чекпоинты

Рассмотрим простую сеть из 4 слоёв. Обычный forward/backward:

Forward:  x → f1 → a1 → f2 → a2 → f3 → a3 → f4 → loss
Backward: loss → ∂f4 → ∂a3 → ∂f3 → ∂a2 → ∂f2 → ∂a1 → ∂f1
  • Хранятся все a1, a2, a3 (активации после каждого слоя).
  • Память: 3 тензора.

С checkpointing (сегмент из 2 слоёв, checkpoint на входе в сегмент):

Forward:  x → [f1 → a1 → f2 → a2] → [f3 → a3 → f4 → loss]
         checkpoint1 = x          checkpoint2 = a2
  • Хранятся только checkpoint1 и checkpoint2 (входы сегментов).
  • При backward для первого сегмента:
    1. Берём checkpoint1 = x.
    2. Пересчитываем forward: x → f1 → a1 → f2 → a2.
    3. Теперь имеем a1, a2 в памяти.
    4. Вычисляем градиенты для f1, f2.
    5. Отбрасываем a1, a2.
  • Аналогично для второго сегмента.

Память 2 тензора (checkpoint) + временно 2 активации внутри сегмента (но они не накапливаются между сегментами). Итого ~2 тензора вместо 3.

Термин Вычислительный граф (computation graph) — направленный ациклический граф, где узлы — операции, а рёбра — тензоры (активации).


4. Варианты checkpointing

4.1 Full checkpointing (каждый слой — сегмент)

  • Сохраняем вход каждого слоя.
  • Память: O(L) (но каждый checkpoint — это вход слоя, а не активация после него, что может быть меньше по размеру, если вход меньше выхода).
  • Пересчёт: каждый слой пересчитывается один раз за backward.
  • Применяется редко, так как выигрыш в памяти невелик.

4.2 Selective checkpointing (выборочные сегменты)

  • Сохраняем checkpoint только для некоторых «тяжёлых» операций (например, attention в трансформере).
  • Позволяет точно контролировать trade-off между памятью и временем.

4.3 Optimal checkpointing (теоретически оптимальный)

  • Использует динамическое программирование для выбора сегментов, минимизирующих пиковую память при заданном бюджете вычислений.
  • Даёт O(sqrt(L)) памяти.
  • Реализован в библиотеках (например, torch.utils.checkpoint с параметром use_reentrant).

Таблица сравнения

СтратегияПамять (активации)Доп. времяСложность реализации
Без checkpointingO(L)0%Нет
Full (каждый слой)O(L) (но меньше)~100%Низкая
SelectiveO(L/K)20–50%Средняя
OptimalO(sqrt(L))20–30%Высокая (используется встроенная)

5. Реализация в PyTorch: torch.utils.checkpoint

PyTorch предоставляет удобный интерфейс через модуль torch.utils.checkpoint.

Базовый пример

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
        self.layer3 = nn.Linear(1024, 1024)

    def forward(self, x):
        # Применяем checkpoint к последовательности слоёв
        x = checkpoint(self._forward_segment, x)
        return x

    def _forward_segment(self, x):
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        x = torch.relu(x)
        x = self.layer3(x)
        return x
  • checkpoint принимает функцию и её аргументы.
  • Во время forward он сохраняет только входные тензоры.
  • Во время backward он пересчитывает _forward_segment и использует полученные активации для вычисления градиентов.

Важно

  • Функция, переданная в checkpoint, не должна иметь побочных эффектов (например, изменять внешние состояния).
  • Для рекуррентных моделей или операций с сохранением состояния (например, BatchNorm) нужно использовать use_reentrant=False.

Продвинутое использование

from torch.utils.checkpoint import checkpoint_sequential

# Для последовательных моделей
segments = 4
output = checkpoint_sequential(model, segments, input)

6. Влияние на память и время

Память

  • Без checkpointing: M_params + M_activations * L (где L — число слоёв).
  • С optimal checkpointing: M_params + M_activations * sqrt(L).

Пример:

  • L = 100, M_activations = 1 GB.
  • Без: M_params + 100 GB.
  • С checkpointing: M_params + 10 GB.
  • Выигрыш в 10 раз.

Время

  • Дополнительные вычисления: каждый сегмент пересчитывается один раз за backward.
  • Если сегментов K, то forward выполняется K+1 раз (один основной + K пересчётов).
  • Обычно overhead составляет 20–30% от общего времени обучения (зависит от размера сегментов и скорости вычислений).

Термин Overhead — дополнительные накладные расходы (время, память).


7. Сравнение с другими техниками экономии памяти

ТехникаМеханизмЭкономия памятиOverhead по времениСовместимость
Gradient checkpointingПересчёт активацийO(L) → O(sqrt(L))20–30%Любая модель
Gradient accumulationНакопление градиентов за несколько микро-батчейПозволяет использовать эффективный batch size больше физическогоНет (но увеличивает шаги)Любая модель
Mixed precision (FP16/BF16)Хранение активаций и весов в половинной точности~2xНебольшой (если поддерживается оборудованием)Требует совместимости
Model parallelismРаспределение слоёв по разным GPUПамять делится между устройствамиКоммуникация между GPUСложная реализация
Pipeline parallelismРазбиение батча на микробатчи, конвейерная обработкаПамять на один микробатчКоммуникация и простоиСредняя сложность

Комбинация На практике часто используют gradient checkpointing вместе с mixed precision и gradient accumulation для обучения самых больших моделей.


8. Когда использовать activation recomputation

  • Очень большие модели (например, GPT-3, LLaMA-65B), где активации не помещаются даже в H100 (80 GB).
  • Длинные последовательности (например, документы > 8K токенов) — активации растут квадратично из-за attention.
  • Ограниченный бюджет GPU — одна карта вместо нескольких.
  • Эксперименты с архитектурой — когда нужно быстро проверить гипотезу без покупки дополнительных GPU.

Когда НЕ стоит использовать

  • Модель маленькая (активации < 20% памяти) — overhead не оправдан.
  • Время обучения критично (например, production fine-tuning с жёстким SLA).
  • Используется pipeline parallelism, где активации уже распределены.

9. Практические рекомендации

  1. Начинайте с torch.utils.checkpoint — это просто и эффективно.
  2. Выбирайте размер сегмента эмпирически: слишком много сегментов → большой overhead, слишком мало → малый выигрыш в памяти.
  3. Используйте checkpoint_sequential для моделей типа nn.Sequential.
  4. Комбинируйте с mixed precisionFP16/BF16 уменьшает размер активаций, checkpointing ещё больше.
  5. Мониторьте использование памяти с помощью torch.cuda.memory_summary().
  6. Для трансформеров часто checkpoint применяют только к блокам attention (самые тяжёлые по памяти), а FFN оставляют без checkpoint.

Пример настройки для GPT-подобной модели

class TransformerBlock(nn.Module):
    def forward(self, x):
        # checkpoint только для attention
        attn_out = checkpoint(self.attention, x)
        x = x + attn_out
        x = x + self.ffn(x)
        return x

10. Пет-проект для закрепления

Задача Обучить небольшую модель (например, 6-слойный трансформер) на задаче классификации текста с и без gradient checkpointing, сравнить пиковое потребление памяти и время эпохи.

Инструменты

  • PyTorch, torch.utils.checkpoint
  • torch.cuda.max_memory_allocated() для замера памяти
  • Датасет: IMDB (или любой другой)

Шаги:

  1. Реализовать простой трансформер (6 слоёв, 4 головы, d_model=256).
  2. Написать цикл обучения без checkpointing, замерить память и время.
  3. Добавить checkpointing для каждого второго слоя (сегмент из 2 слоёв).
  4. Повторить замеры.
  5. Построить таблицу:
КонфигурацияПиковая память (MB)Время эпохи (сек)
Без checkpoint120045
С checkpoint (сегмент 2)80058
С checkpoint (сегмент 1)60072

Ожидаемый результат

  • Память снижается на 30–50%, время растёт на 20–30%.
  • Понимание trade-off на практике.

11. Связь с другими вопросами

ВопросТема
460Что такое gradient accumulation и как он помогает при ограниченной памяти?
461Как работает mixed precision training (FP16/BF16)?
462В чём разница между data parallelism и model parallelism?
464Что такое pipeline parallelism и как он сочетается с checkpointing?
465Как оптимизировать память при обучении LLM с длинным контекстом?
470Какие техники используются для распределённого обучения больших моделей?

12. Навигация


Навигация