中文翻译暂不可用,显示俄语原文。
Как работает selective activation recomputation?
Краткий тезис
Selective activation recomputation — это техника оптимизации памяти при обучении и инференсе больших языковых моделей (LLM), которая пересчитывает не все промежуточные активации, а только часть (например, каждый второй слой трансформера). Это компромисс между полным activation checkpointing (пересчёт всех активаций, максимальная экономия памяти, но замедление) и отсутствием пересчёта (максимальная скорость, но огромное потребление памяти). Выбор, какие активации пересчитывать, позволяет точно настроить баланс между памятью и временем вычислений, что критично для Agentic RAG-систем, где агенты выполняют множество последовательных вызовов LLM.
1. Зачем нужна оптимизация памяти в LLM
Современные LLM (например, GPT-3, LLaMA) содержат миллиарды параметров. При прямом проходе (forward pass) для каждого слоя сохраняются активации — промежуточные тензоры, необходимые для обратного распространения ошибки (backward pass). Эти активации могут занимать в десятки раз больше памяти, чем сами веса модели. Например, для модели с 7B параметров и длиной контекста 2048 активации могут потребовать > 100 ГБ памяти GPU. Без оптимизации обучение или инференс с большим батч-сайзом невозможны.
Термин: активации — выходные тензоры каждого слоя (например, после attention и feed-forward), которые сохраняются для вычисления градиентов.
Термин: обратное распространение (backward pass) — этап обучения, на котором по цепному правилу вычисляются градиенты весов; для этого нужны активации из forward pass.
2. Полный activation checkpointing (recomputation)
Базовая техника — activation checkpointing (также называемый gradient checkpointing). Идея: не сохранять все активации, а пересчитывать их заново во время backward pass из предыдущих сохранённых состояний. Обычно сохраняются активации только на границах блоков (например, после каждого трансформер-слоя), а внутри блока активации пересчитываются.
Плюсы радикальная экономия памяти (до 50–80%). Минусы дополнительное время — каждый пересчёт требует повторного forward pass для выбранных блоков. Время обучения может вырасти на 20–30%.
Термин: checkpoint — точка, в которой активации сохраняются (не пересчитываются). Обычно это выходы целых слоёв или блоков.
3. Selective activation recomputation — идея
Selective activation recomputation (выборочный пересчёт активаций) — это более гибкий подход. Вместо того чтобы пересчитывать все активации внутри checkpoint-интервала, мы выбираем только некоторые слои или операции для пересчёта. Остальные активации сохраняются в памяти.
Пример: в трансформере с 12 слоями можно сохранять активации для каждого второго слоя (слои 1,3,5,...), а для остальных пересчитывать. Или можно пересчитывать только активации attention, а feed-forward сохранять.
Ключевое преимущество мы можем точно настроить, сколько памяти сэкономить и насколько замедлить вычисления. Это позволяет, например, увеличить батч-сайз ровно настолько, насколько нужно, без излишнего замедления.
4. Стратегии выбора активаций для пересчёта
Существует несколько стратегий, какие активации пересчитывать:
| Стратегия | Описание | Пример |
|---|---|---|
| По слоям | Пересчитывать каждый N-й слой | Каждый 2-й слой трансформера |
| По типу операции | Пересчитывать только дорогие по памяти операции (например, attention) | Сохранять активации feed-forward, пересчитывать attention |
| Адаптивная | Анализировать вклад слоя в память и скорость, выбирать оптимальный набор | Использовать профилировщик памяти |
| По важности | Пересчитывать слои с наименьшим влиянием на точность (например, первые слои) | Сохранять последние слои, пересчитывать первые |
На практике часто используется равномерная стратегия «каждый K-й слой», так как она проста в реализации и даёт предсказуемый эффект.
5. Влияние на память и скорость
Рассмотрим модель с 12 слоями. Пусть полное сохранение всех активаций требует 12 единиц памяти, полный пересчёт (checkpointing) — 2 единицы (сохраняем только входы и выходы модели). Selective recomputation с сохранением каждого второго слоя (6 слоёв) потребует примерно 6 + 1 = 7 единиц памяти (сохранённые активации + входы). Время: пересчитывается 6 слоёв, что добавляет ~50% к времени backward pass (по сравнению с полным сохранением).
Таблица сравнения (гипотетические цифры для 12-слойной модели):
| Метод | Память (усл. ед.) | Время (усл. ед.) | Прирост времени |
|---|---|---|---|
| Без recomputation | 12 | 1.0 | 0% |
| Полный checkpointing | 2 | 1.3 | +30% |
| Selective (каждый 2-й слой) | 7 | 1.15 | +15% |
| Selective (каждый 3-й слой) | 9 | 1.07 | +7% |
Термин: trade-off — компромисс между памятью и скоростью. Selective recomputation позволяет выбрать точку на этом компромиссе.
6. Реализация в PyTorch (пример)
PyTorch предоставляет функцию torch.utils.checkpoint.checkpoint для создания checkpoint-блоков. Для selective recomputation можно вручную разбить модель на блоки и применить checkpoint только к выбранным.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.attn = nn.MultiheadAttention(dim, num_heads=8)
self.ffn = nn.Sequential(
nn.Linear(dim, dim*4),
nn.ReLU(),
nn.Linear(dim*4, dim)
)
def forward(self, x):
# Обычный forward
attn_out, _ = self.attn(x, x, x)
x = x + attn_out
x = x + self.ffn(x)
return x
class SelectiveCheckpointModel(nn.Module):
def __init__(self, num_layers=12, checkpoint_every=2):
super().__init__()
self.layers = nn.ModuleList([TransformerBlock(512) for _ in range(num_layers)])
self.checkpoint_every = checkpoint_every
def forward(self, x):
for i, layer in enumerate(self.layers):
if i % self.checkpoint_every == 0:
# Сохраняем активации (не пересчитываем)
x = layer(x)
else:
# Пересчитываем активации через checkpoint
x = checkpoint(layer, x)
return x
В этом примере каждый второй слой (индекс 0,2,4,...) выполняется без checkpoint (активации сохраняются), а остальные — с checkpoint (пересчитываются при backward). Это простейшая реализация selective recomputation.
Важно checkpoint в PyTorch автоматически сохраняет входные тензоры и пересчитывает forward внутри блока при backward. Мы можем комбинировать такие блоки для выборочного пересчёта.
7. Связь с Agentic RAG
В Agentic RAG агенты (LLM) выполняют множество шагов: планирование, поиск, генерация ответа, вызов инструментов. Каждый шаг требует forward pass (и иногда backward при обучении). Если агент использует большую модель, память GPU быстро заполняется, особенно при параллельной обработке нескольких запросов. Selective activation recomputation позволяет:
- Увеличить batch size для параллельной обработки запросов агентов.
- Снизить пиковое потребление памяти при длинных контекстах (агенты могут передавать друг другу большие истории диалогов).
- Ускорить инференс за счёт меньшего числа пересчётов по сравнению с полным checkpointing.
Таким образом, эта техника напрямую повышает эффективность и масштабируемость Agentic RAG-систем.
8. Trade-offs и настройка
Выбор параметра checkpoint_every (или другой стратегии) зависит от:
- Доступной памяти GPU — чем меньше памяти, тем чаще нужно пересчитывать.
- Допустимого замедления — если время критично, лучше сохранять больше активаций.
- Длины контекста — для длинных контекстов активации attention занимают много памяти (квадратично от длины), поэтому их часто пересчитывают.
- Архитектуры модели — в некоторых моделях (например, с FlashAttention) память на attention уже оптимизирована, и пересчёт может быть невыгоден.
Рекомендуется профилировать модель с разными стратегиями и выбирать ту, которая даёт нужный баланс. Инструменты: torch.cuda.memory_summary(), torch.profiler.
Пет-проект для закрепления
Задача Реализовать selective activation recomputation для небольшой модели (например, 6-слойный трансформер) и сравнить потребление памяти и время обучения при разных стратегиях.
Инструменты PyTorch, torch.utils.checkpoint, torch.cuda.max_memory_allocated, time.
Шаги:
- Создать модель из 6 TransformerBlock (как в примере выше).
- Написать функцию обучения на синтетических данных (batch_size=8, seq_len=128).
- Измерить пиковое использование памяти и время одного шага для:
- Без checkpoint (все активации сохраняются).
- Полный checkpoint (каждый блок обёрнут в
checkpoint). - Selective: сохранять каждый 2-й слой (checkpoint_every=2).
- Selective: сохранять каждый 3-й слой (checkpoint_every=3).
- Построить таблицу и график trade-off.
Ожидаемый результат Вы увидите, как меняется память и скорость. Например, selective с checkpoint_every=2 даст ~50% экономии памяти при ~10% замедлении по сравнению с полным сохранением.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 479 | Activation checkpointing (полный пересчёт) |
| 481 | Gradient accumulation как альтернатива экономии памяти |
| 485 | Оптимизация памяти при инференсе LLM |
| 490 | FlashAttention и его влияние на память |
| 500 | Параллелизм моделей (model parallelism) |
| 510 | Смешанная точность (mixed precision) |
Навигация
- Предыдущий: 479
- Следующий: 481
- Индекс: 00. Индекс разборов