English translation is not available yet. Showing Russian content.
Как работает FlashAttention для training (не только inference)?
Краткий тезис
FlashAttention — это IO-aware алгоритм точного внимания, который кардинально снижает потребление памяти с O(n²) до O(n) и ускоряет как forward, так и backward pass за счёт tiling (разбиения на блоки) и recomputation (пересчёта промежуточных матриц на обратном проходе). Для training критично, что backward pass не хранит все матрицы attention scores]] (S, P), а пересчитывает их из входных данных — это экономит сотни гигабайт памяти при длине контекста 100k+ токенов, позволяя обучать модели с гораздо большими batch sizes и более длинными последовательностями.
1. Проблема: квадратичная память стандартного attention
Стандартный механизм скалярного внимания (scaled dot‑product attention) вычисляется как:
Attention(Q, K, V) = softmax( QK^T / sqrt(d) ) V
где Q, K, V — матрицы размера [n × d], n — длина последовательности, d — размерность головы.
На forward pass необходимо хранить всю матрицу внимания S = QK^T (размер n×n) и нормированную матрицу P = softmax(S) (тоже n×n). Это приводит к затратам памяти O(n²). Например, при n = 100k и float16 (2 байта на элемент) одна матрица S занимает 100k² × 2 = 20 млрд байт ≈ 20 GB. Для backward нужны ещё и градиенты — итого > 80 GB на один слой. При 32 слоях — неподъёмные 2.5 TB.
Термин «IO-aware» (aware of input/output) — алгоритм, оптимизирующий не только число арифметических операций, но и обмен данными между быстрой (SRAM) и медленной (HBM) памятью GPU.
Термин «Tiling» — разбиение матриц на блоки, которые помещаются в SRAM (обычно 16–128 элементов), чтобы выполнить вычисления внутри чипа без постоянных запросов к HBM.
2. Идея FlashAttention: вычисляем attention без сохранения полной S и P
Главная хитрость: мы не храним матрицу S целиком в медленной памяти (HBM), а вычисляем attention по частям — по блокам Q, K, V, загружая их в быструю SRAM (shared memory GPU, ~20–100 KB на мультипроцессор). Внутри блока мы вычисляем локальный вклад в выход O и обновляем статистики для softmax (сумму экспонент и максимум), чтобы в конце корректно «склеить» блоки. Это делается с помощью техники online softmax (safe softmax с редукцией по максимуму).
2.1 Forward pass: tiling + online softmax
- Разбиваем Q, K, V на блоки размера B_r × d, B_c × d.
- Для каждой пары блоков (Q_i, K_j):
- загружаем Q_i, K_j в SRAM;
- вычисляем локальную матрицу внимания S_ij = Q_i K_j^T / sqrt(d) (размер B_r × B_c);
- вычисляем локальные максимумы m_ij и суммы exp и объединяем с глобальными статистиками для текущих выходных блоков;
- вычисляем вклад в выходной блок O_i: O_i += softmax(S_ij) V_j (с учётом глобальной нормализации).
В результате не нужно писать в HBM промежуточные матрицы S и P — только финальные O, а также статистики (m, l) для backward.
3. Backward pass: recomputation — ключевое отличие для training
На обратном проходе нам нужны градиенты по Q, K, V. Стандартный backward требует матрицу P (softmax) и dO (градиент выходного лосса), чтобы посчитать dQ, dK, dV:
dQ += P * dO * V^T (с учётом chain rule)
dK += (P^T * dO^T) * V
dV += P^T * dO
Если бы мы хранили P (n×n), то память оставалась квадратичной. FlashAttention поступает иначе: не сохраняет P, а пересчитывает её заново из Q, K, V и сохранённых статистик m, l.
3.1 Шаги backward pass
- Загружаем Q, K, V из HBM (они уже есть после forward) и сохранённые статистики m, l для каждой строки.
- Разбиваем на блоки как на forward.
- Для каждой пары блоков (Q_i, K_j) повторно вычисляем локальную S_ij, применяем softmax (используя сохранённые глобальные m, l), получаем P_ij, затем загружаем соответствующие блоки dO_i и V_j, вычисляем локальные градиенты dQ_i, dK_j, dV_j.
- Аккумулируем градиенты в HBM.
Термин «Recomputation» (перевычисление) — вместо хранения промежуточных результатов (которые занимали бы O(n²)) мы тратим чуть больше FLOPs на повторный расчёт, но экономим огромное количество памяти.
4. Почему recomputation выгоднее, чем хранение?
Несмотря на то, что backward выполняет дополнительные FLOPs (те же самые вычисления S и softmax, что и на forward), выигрыш в памяти колоссален:
- Память: O(n) вместо O(n²).
- Дополнительные FLOPs: ~1.33× относительно стандартного backward (по данным оригинальной статьи), т.е. на 33% больше арифметики.
- Но благодаря лучшей утилизации кэша и меньшему числу обращений к HBM реальное время backward может быть быстрее стандартного, а не медленнее.
Ключевой trade-off: FLOPs vs. memory bandwidth. FlashAttention меняет закон: вместо того чтобы лимитироваться памятью (memory-bound), мы становимся вычислительно-ограниченными (compute-bound) на больших длинах — это лучше для современных GPU.
5. Оценка экономии памяти
| Длина контекста n | Стандартный attention (float16) | FlashAttention |
|---|---|---|
| 1k | ~2 MB (S) | ~O(n) от Q, K, V, O, m, l (~2 MB + незначительные) |
| 10k | ~200 MB (S) | ~20 MB |
| 100k | ~20 GB (только S) | ~0.2 GB (Q,K,V,O) + статистики ~ 0.8 GB |
| 1M | ~2 TB (нереально) | ~2 GB |
Пример из черновика: на training при n=100k разница 80 GB (с градиентами) vs 800 GB — это утрированно, но порядок верен: стандартный attention требует хранения P (100k² ≈ 10^10 элементов → 20 GB) плюс градиенты — около 80 GB; FlashAttention — всего ~0.8 GB на слой.
6. Влияние на обучение: большие batch sizes и длинные контексты
Стандартный attention лимитирует максимальную длину последовательности из-за памяти GPU. FlashAttention позволяет:
- Обучать модели с контекстом 128k и более на одном GPU (например, Llama 3, GPT-4).
- Увеличивать batch size — больше примеров в одной итерации → стабильнее градиенты, выше throughput.
- Уменьшать объём активаций (activations) — не нужно хранить S и P, что даёт больше места для градиентов оптимизатора.
Практический пример: при обучении модели с 32 слоями и n=100k FlashAttention экономит ~2.5 TB памяти только на attention — это позволяет обходиться без дорогой модели параллелизма (tensor/pipeline parallelism) для такой длины.
7. Детали реализации: FlashAttention-2 и -3
- FlashAttention v2 (2023): улучшенный backward, лучшее распараллеливание блоков, оптимизация для современных GPU (H100).
- FlashAttention v3 (2024): использует асинхронный обмен данными (asynchronous SM-to-SM copy), поддержка FP8, ещё выше производительность.
Все версии сохраняют основной принцип: tiling + recomputation для training.
8. Сравнение с другими подходами экономии памяти
| Метод | Память attention | Доп. вычисления | Точность |
|---|---|---|---|
| FlashAttention | O(n) | ~1.33× FLOPs | точный |
| Sparse attention | O(n) или O(n√n) | меньше FLOPs | приближённый |
| Linear attention | O(n) | O(n) FLOPs | приближённый (без softmax) |
| Gradient checkpointing | O(√n) активаций | 2× forward pass | точный |
FlashAttention — единственный метод, который даёт точное внимание (exact attention) с линейной памятью без потери точности. Gradient checkpointing (когда на backward пересчитываются только некоторые слои) менее эффективен, так как часто пересчитывает целые слои, а не микрооперации.
9. Итоговое резюме для собеседования
- FlashAttention решает проблему квадратичной памяти в attention, делая её линейной.
- Для training критичен backward pass: recomputation матрицы P из Q, K, V и сохранённых статистик.
- Перевычисление добавляет ~33% FLOPs, но экономия памяти на порядки.
- Алгоритм IO-aware: минимизирует чтение/запись в HBM, работает на блоках в SRAM.
- FlashAttention используется во всех современных LLM (Llama, GPT, MPT, Falcon) и позволяет обучать модели с контекстом >100k токенов.
Пет-проект для закрепления
Задача: Реализовать упрощённый backward pass FlashAttention для одного слоя (без tiling, но с имитацией recomputation с помощью повторного вызова forward) и сравнить потребление памяти с наивным attention.
Инструменты: PyTorch, CUDA (опционально), профилировщик памяти (torch.cuda.memory_allocated).
Шаги:
- Реализовать класс
StandardAttentionс функциями forward и backward (вручную через autograd или с сохранением S). - Реализовать класс
FlashAttentionSim:- forward: вычисляет и сохраняет только статистики m, l (а не всю S).
- backward: заново вычисляет S из Q и K и использует dO и V для градиентов (имитация recomputation).
- Сравнить объём сохранённых тензоров при длине n=4096, d=128.
- Измерить время backward (можно использовать
torch.cuda.Event). - Расширить: добавить tiling по одному измерению (например, разбиение n на блоки по 1024).
Ожидаемый результат: вы увидите, что FlashAttentionSim использует ~2× меньше памяти (не хранит P × градиенты), а время backward увеличивается незначительно (на 10–20%). Это наглядно демонстрирует trade-off.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 470 | Архитектура механизма внимания в трансформерах (Q,K,V, softmax) |
| 471 | KV Cache: что такое и как работает в inference |
| 472 | Multi-Query Attention / Grouped-Query Attention для экономии памяти |
| 473 | Как бороться с проблемой длинного контекста (LongLoRA, NTK-aware scaling) |
| 475 | Как устроен Agentic RAG: роль внимания при обработке длинных запросов |
Навигация
- Предыдущий: 473
- Следующий: 475
- Индекс: 00. Индекс разборов