中文翻译暂不可用,显示俄语原文。

Как работает FlashAttention для training (не только inference)?

Краткий тезис

FlashAttention — это IO-aware алгоритм точного внимания, который кардинально снижает потребление памяти с O(n²) до O(n) и ускоряет как forward, так и backward pass за счёт tiling (разбиения на блоки) и recomputation (пересчёта промежуточных матриц на обратном проходе). Для training критично, что backward pass не хранит все матрицы attention scores]] (S, P), а пересчитывает их из входных данных — это экономит сотни гигабайт памяти при длине контекста 100k+ токенов, позволяя обучать модели с гораздо большими batch sizes и более длинными последовательностями.


1. Проблема: квадратичная память стандартного attention

Стандартный механизм скалярного внимания (scaled dot‑product attention) вычисляется как:

Attention(Q, K, V) = softmax( QK^T / sqrt(d) ) V

где Q, K, V — матрицы размера [n × d], n — длина последовательности, d — размерность головы.

На forward pass необходимо хранить всю матрицу внимания S = QK^T (размер n×n) и нормированную матрицу P = softmax(S) (тоже n×n). Это приводит к затратам памяти O(n²). Например, при n = 100k и float16 (2 байта на элемент) одна матрица S занимает 100k² × 2 = 20 млрд байт ≈ 20 GB. Для backward нужны ещё и градиенты — итого > 80 GB на один слой. При 32 слоях — неподъёмные 2.5 TB.

Термин «IO-aware» (aware of input/output) — алгоритм, оптимизирующий не только число арифметических операций, но и обмен данными между быстрой (SRAM) и медленной (HBM) памятью GPU.

Термин «Tiling» — разбиение матриц на блоки, которые помещаются в SRAM (обычно 16–128 элементов), чтобы выполнить вычисления внутри чипа без постоянных запросов к HBM.


2. Идея FlashAttention: вычисляем attention без сохранения полной S и P

Главная хитрость: мы не храним матрицу S целиком в медленной памяти (HBM), а вычисляем attention по частям — по блокам Q, K, V, загружая их в быструю SRAM (shared memory GPU, ~20–100 KB на мультипроцессор). Внутри блока мы вычисляем локальный вклад в выход O и обновляем статистики для softmax (сумму экспонент и максимум), чтобы в конце корректно «склеить» блоки. Это делается с помощью техники online softmax (safe softmax с редукцией по максимуму).

2.1 Forward pass: tiling + online softmax

  1. Разбиваем Q, K, V на блоки размера B_r × d, B_c × d.
  2. Для каждой пары блоков (Q_i, K_j):
    • загружаем Q_i, K_j в SRAM;
    • вычисляем локальную матрицу внимания S_ij = Q_i K_j^T / sqrt(d) (размер B_r × B_c);
    • вычисляем локальные максимумы m_ij и суммы exp и объединяем с глобальными статистиками для текущих выходных блоков;
    • вычисляем вклад в выходной блок O_i: O_i += softmax(S_ij) V_j (с учётом глобальной нормализации).

В результате не нужно писать в HBM промежуточные матрицы S и P — только финальные O, а также статистики (m, l) для backward.


3. Backward pass: recomputation — ключевое отличие для training

На обратном проходе нам нужны градиенты по Q, K, V. Стандартный backward требует матрицу P (softmax) и dO (градиент выходного лосса), чтобы посчитать dQ, dK, dV:

dQ += P * dO * V^T   (с учётом chain rule)
dK += (P^T * dO^T) * V
dV += P^T * dO

Если бы мы хранили P (n×n), то память оставалась квадратичной. FlashAttention поступает иначе: не сохраняет P, а пересчитывает её заново из Q, K, V и сохранённых статистик m, l.

3.1 Шаги backward pass

  1. Загружаем Q, K, V из HBM (они уже есть после forward) и сохранённые статистики m, l для каждой строки.
  2. Разбиваем на блоки как на forward.
  3. Для каждой пары блоков (Q_i, K_j) повторно вычисляем локальную S_ij, применяем softmax (используя сохранённые глобальные m, l), получаем P_ij, затем загружаем соответствующие блоки dO_i и V_j, вычисляем локальные градиенты dQ_i, dK_j, dV_j.
  4. Аккумулируем градиенты в HBM.

Термин «Recomputation» (перевычисление) — вместо хранения промежуточных результатов (которые занимали бы O(n²)) мы тратим чуть больше FLOPs на повторный расчёт, но экономим огромное количество памяти.


4. Почему recomputation выгоднее, чем хранение?

Несмотря на то, что backward выполняет дополнительные FLOPs (те же самые вычисления S и softmax, что и на forward), выигрыш в памяти колоссален:

  • Память: O(n) вместо O(n²).
  • Дополнительные FLOPs: ~1.33× относительно стандартного backward (по данным оригинальной статьи), т.е. на 33% больше арифметики.
  • Но благодаря лучшей утилизации кэша и меньшему числу обращений к HBM реальное время backward может быть быстрее стандартного, а не медленнее.

Ключевой trade-off: FLOPs vs. memory bandwidth. FlashAttention меняет закон: вместо того чтобы лимитироваться памятью (memory-bound), мы становимся вычислительно-ограниченными (compute-bound) на больших длинах — это лучше для современных GPU.


5. Оценка экономии памяти

Длина контекста nСтандартный attention (float16)FlashAttention
1k~2 MB (S)~O(n) от Q, K, V, O, m, l (~2 MB + незначительные)
10k~200 MB (S)~20 MB
100k~20 GB (только S)~0.2 GB (Q,K,V,O) + статистики ~ 0.8 GB
1M~2 TB (нереально)~2 GB

Пример из черновика: на training при n=100k разница 80 GB (с градиентами) vs 800 GB — это утрированно, но порядок верен: стандартный attention требует хранения P (100k² ≈ 10^10 элементов → 20 GB) плюс градиенты — около 80 GB; FlashAttention — всего ~0.8 GB на слой.


6. Влияние на обучение: большие batch sizes и длинные контексты

Стандартный attention лимитирует максимальную длину последовательности из-за памяти GPU. FlashAttention позволяет:

  • Обучать модели с контекстом 128k и более на одном GPU (например, Llama 3, GPT-4).
  • Увеличивать batch size — больше примеров в одной итерации → стабильнее градиенты, выше throughput.
  • Уменьшать объём активаций (activations) — не нужно хранить S и P, что даёт больше места для градиентов оптимизатора.

Практический пример: при обучении модели с 32 слоями и n=100k FlashAttention экономит ~2.5 TB памяти только на attention — это позволяет обходиться без дорогой модели параллелизма (tensor/pipeline parallelism) для такой длины.


7. Детали реализации: FlashAttention-2 и -3

  • FlashAttention v2 (2023): улучшенный backward, лучшее распараллеливание блоков, оптимизация для современных GPU (H100).
  • FlashAttention v3 (2024): использует асинхронный обмен данными (asynchronous SM-to-SM copy), поддержка FP8, ещё выше производительность.

Все версии сохраняют основной принцип: tiling + recomputation для training.


8. Сравнение с другими подходами экономии памяти

МетодПамять attentionДоп. вычисленияТочность
FlashAttentionO(n)~1.33× FLOPsточный
Sparse attentionO(n) или O(n√n)меньше FLOPsприближённый
Linear attentionO(n)O(n) FLOPsприближённый (без softmax)
Gradient checkpointingO(√n) активаций2× forward passточный

FlashAttention — единственный метод, который даёт точное внимание (exact attention) с линейной памятью без потери точности. Gradient checkpointing (когда на backward пересчитываются только некоторые слои) менее эффективен, так как часто пересчитывает целые слои, а не микрооперации.


9. Итоговое резюме для собеседования

  • FlashAttention решает проблему квадратичной памяти в attention, делая её линейной.
  • Для training критичен backward pass: recomputation матрицы P из Q, K, V и сохранённых статистик.
  • Перевычисление добавляет ~33% FLOPs, но экономия памяти на порядки.
  • Алгоритм IO-aware: минимизирует чтение/запись в HBM, работает на блоках в SRAM.
  • FlashAttention используется во всех современных LLM (Llama, GPT, MPT, Falcon) и позволяет обучать модели с контекстом >100k токенов.

Пет-проект для закрепления

Задача: Реализовать упрощённый backward pass FlashAttention для одного слоя (без tiling, но с имитацией recomputation с помощью повторного вызова forward) и сравнить потребление памяти с наивным attention.

Инструменты: PyTorch, CUDA (опционально), профилировщик памяти (torch.cuda.memory_allocated).

Шаги:

  1. Реализовать класс StandardAttention с функциями forward и backward (вручную через autograd или с сохранением S).
  2. Реализовать класс FlashAttentionSim:
    • forward: вычисляет и сохраняет только статистики m, l (а не всю S).
    • backward: заново вычисляет S из Q и K и использует dO и V для градиентов (имитация recomputation).
  3. Сравнить объём сохранённых тензоров при длине n=4096, d=128.
  4. Измерить время backward (можно использовать torch.cuda.Event).
  5. Расширить: добавить tiling по одному измерению (например, разбиение n на блоки по 1024).

Ожидаемый результат: вы увидите, что FlashAttentionSim использует ~2× меньше памяти (не хранит P × градиенты), а время backward увеличивается незначительно (на 10–20%). Это наглядно демонстрирует trade-off.


Связь с другими вопросами

ВопросТема
470Архитектура механизма внимания в трансформерах (Q,K,V, softmax)
471KV Cache: что такое и как работает в inference
472Multi-Query Attention / Grouped-Query Attention для экономии памяти
473Как бороться с проблемой длинного контекста (LongLoRA, NTK-aware scaling)
475Как устроен Agentic RAG: роль внимания при обработке длинных запросов

Навигация