中文翻译暂不可用,显示俄语原文。
Как работает FlashAttention математически (tiling, recomputation, не материализуя S)?
Краткий тезис
FlashAttention — это алгоритм точного вычисления attention, который не хранит матрицу S = QK^T размером n×n (где n — длина последовательности). Вместо этого он использует tiling (разбиение на блоки) и online softmax с сохранением статистики (max, sum) для коррекции результата. На обратном проходе (backward) применяется recomputation — матрица S пересчитывается заново, что экономит память. Итог: сложность по памяти снижается с O(n²) до O(n), а скорость вычислений увеличивается в 2–4 раза за счёт эффективного использования GPU-памяти.
1. Проблема: квадратичная память стандартного attention
Стандартное внимание (например, в Transformer) вычисляется по формуле:
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
Q, K, V— матрицы размеромn × d(n — длина последовательности, d — размерность головы).- S = QK^T — матрица
n × n. - P = softmax(S) — матрица
n × n. - Итоговый выход
O = P V— матрицаn × d.
Проблема хранение S и P требует O(n²) памяти. Для n=128K (типичный long context) это ~16 GB в float16 — непозволительно для одного слоя. Кроме того, операция QK^T — это матричное умножение, которое загружает и выгружает большие объёмы данных из медленной глобальной памяти GPU.
Термин: материализация матрицы S — это явное создание и хранение матрицы S в памяти. FlashAttention избегает этого.
2. Основная идея: tiling (разбиение на блоки)
Вместо того чтобы вычислять S целиком, FlashAttention разбивает Q, K, V на блоки (tiles) фиксированного размера B (например, 128×128). Вычисления проводятся поблочно, и результат каждого блока сразу участвует в формировании выхода O.
Как это работает
- Разбиение по Q
Qделится на блокиQ_iразмеромB_r × d. - Разбиение по K, V:
KиVделятся на блокиK_j,V_jразмеромB_c × d. - Внешний цикл по j (блоки K, V), внутренний по i (блоки Q). Для каждой пары блоков
(Q_i, K_j)вычисляется локальная матрицаS_ij = Q_i K_j^TразмеромB_r × B_c. - Softmax на лету для каждого блока
Q_isoftmax вычисляется инкрементально, накапливая статистику по всемj. Это требует специального алгоритма — online softmax.
Термин: Tiling (блочная обработка) — разбиение больших матриц на маленькие блоки, которые помещаются в быструю память GPU (shared memory). Это уменьшает количество обращений к глобальной памяти.
3. Online softmax: сохранение статистики (max, sum)
Стандартный softmax требует знания всех элементов строки S, чтобы вычислить max и sum для нормализации. При поблочной обработке мы не видим всю строку сразу. Решение — использовать online softmax, который обновляет статистику по мере поступления блоков.
Алгоритм для одной строки (для простоты):
Пусть для строки i мы уже обработали блоки j=1..m-1 и имеем:
m_prev— текущий максимум среди обработанных элементов.l_prev— текущая суммаexp(S_ij - m_prev).
Приходит новый блок j=m с элементами s_1, ..., s_Bc.
- Вычисляем новый максимум:
m_new = max(m_prev, max(s)). - Корректируем старую сумму:
l_prev = l_prev * exp(m_prev - m_new). - Добавляем вклад нового блока:
l_new = l_prev + sum(exp(s - m_new)). - Обновляем максимум:
m_prev = m_new,l_prev = l_new.
Для матрицы статистики хранятся для каждого блока Q_i отдельно (векторы m_i и l_i размером B_r). После обработки всех j получаем итоговые m_i и l_i для каждой строки i.
Термин: Online softmax — метод вычисления softmax без полной материализации матрицы S, используя инкрементальное обновление максимума и суммы экспонент.
4. Вычисление выхода O с коррекцией
Выход O также накапливается поблочно. Для каждого блока Q_i и K_j, V_j:
- Вычисляем
S_ij = Q_i K_j^T. - Применяем online softmax с учётом текущей статистики
m_i, l_i. - Получаем локальный
P_ij = exp(S_ij - m_new) / l_new(но не храним, а сразу используем). - Обновляем выход:
O_i += P_ij V_j.
Коррекция так как статистика m_i, l_i меняется после каждого блока, ранее добавленный вклад O_i нужно скорректировать. Это делается умножением на exp(m_prev - m_new) / (l_new / l_prev). На практике алгоритм хранит O_i в неоткорректированном виде и применяет финальную нормализацию после всех блоков.
Итоговая формула для одного блока Q_i:
O_i = (l_prev / l_new) * exp(m_prev - m_new) * O_i_prev + (1 / l_new) * exp(S_ij - m_new) * V_j
Термин: Коррекция (rescaling) — поправка накопленного выхода при изменении глобального максимума и суммы.
5. Recomputation на обратном проходе (backward)
При обратном распространении ошибки стандартный attention требует градиенты по Q, K, V. Для этого нужны матрицы S и P. FlashAttention не хранит их, а пересчитывает заново (recomputation) на backward.
Как это работает
- На forward сохраняются только статистики
m_i, l_iи случайные dropout-маски (если используются) — этоO(n)памяти. - На backward повторяется тот же поблочный цикл: для каждого блока
(Q_i, K_j)заново вычисляетсяS_ij, затемP_ijс использованием сохранённыхm_i, l_i, и на основе этого считаются градиенты.
Термин: Recomputation — повторное вычисление промежуточных результатов на обратном проходе вместо их хранения. Это trade-off: время вычислений увеличивается (примерно на 30%), но память экономится радикально.
Почему это выгодно для long context узким местом является память, а не вычисления. Recomputation позволяет обрабатывать последовательности в 10–100 раз длиннее при том же объёме GPU.
6. Математическая формализация алгоритма
Приведём псевдокод для одного блока Q_i (упрощённо):
# Инициализация
m_i = -inf # вектор размера B_r
l_i = 0 # вектор размера B_r
O_i = 0 # матрица B_r x d
# Цикл по блокам K, V
for j in range(0, n, B_c):
K_j = K[j:j+B_c, :] # B_c x d
V_j = V[j:j+B_c, :] # B_c x d
S_ij = Q_i @ K_j.T # B_r x B_c
# Online softmax
m_new = max(m_i, rowmax(S_ij)) # поэлементно по строкам
l_new = exp(m_i - m_new) * l_i + rowsum(exp(S_ij - m_new))
# Коррекция и накопление O_i
O_i = exp(m_i - m_new) * (l_i / l_new) * O_i + (1 / l_new) * (exp(S_ij - m_new) @ V_j)
# Обновление статистики
m_i = m_new
l_i = l_new
# После цикла O_i уже нормализован (деление на l_i уже учтено)
Термин: rowmax, rowsum — поэлементные операции по строкам: максимум и сумма экспонент.
7. Сравнение со стандартным attention
| Характеристика | Стандартный attention | FlashAttention |
|---|---|---|
| Память (forward) | O(n²) | O(n) |
| Память (backward) | O(n²) | O(n) |
| Время (forward) | ~2x быстрее (без overhead) | ~2-4x быстрее (за счёт эффективного использования shared memory) |
| Время (backward) | ~2x быстрее | ~1.3x медленнее (из-за recomputation) |
| Точность | Точный | Точный (не приближение) |
| Поддержка dropout | Да | Да (маска хранится) |
Почему FlashAttention быстрее на forward Основное узкое место — чтение/запись в глобальную память. FlashAttention загружает блоки K, V в shared memory и переиспользует их для нескольких блоков Q. Это уменьшает количество обращений к HBM (High Bandwidth Memory) в ~2-4 раза.
8. Применение в Agentic RAG и long context
В архитектурах Agentic RAG агенты часто обрабатывают длинные контексты: историю диалога, множество документов, результаты вызовов инструментов. FlashAttention позволяет:
- Использовать модели с контекстом до 128K+ токенов без увеличения памяти.
- Ускорять инференс за счёт меньшего числа операций с памятью.
- Обучать модели на длинных последовательностях (fine-tuning RAG-агентов).
Пример: RAG-агент, который читает 50 документов по 2000 токенов каждый (100K токенов). Без FlashAttention потребовалось бы ~10 GB только на матрицу S для одного слоя. С FlashAttention — ~80 MB.
9. Вариации и расширения
- FlashAttention-2 улучшенная версия с более эффективным tiling (меньше синхронизаций, лучшее использование warp-level primitives). Скорость увеличена ещё на ~2x.
- FlashAttention-3 (Hopper): использует новые инструкции NVIDIA Hopper (FP8, Tensor Core) для ещё большей производительности.
- Block-sparse FlashAttention для разреженных паттернов внимания (например, sliding window, dilated attention).
- FlashDecoding оптимизация для инференса (генерации), где
K, Vкэшируются, аQ— один токен.
Термин: Warp-level primitives — низкоуровневые операции на GPU, выполняемые группой потоков (warp). FlashAttention использует их для быстрых редукций (max, sum).
Пет-проект для закрепления
Задача Реализовать упрощённую версию FlashAttention на Python с NumPy, которая не материализует S, и сравнить её со стандартным attention по памяти и времени.
Инструменты Python, NumPy, timeit, memory_profiler (опционально).
Шаги:
- Сгенерируйте случайные
Q, K, Vразмеромn=1024,d=64. - Реализуйте стандартный attention:
S = Q @ K.T,P = softmax(S),O = P @ V. Замерьте время и пиковую память (например, черезsys.getsizeofилиmemory_profiler). - Реализуйте FlashAttention:
- Выберите размер блока
B=128. - Напишите цикл по блокам
Q_iиK_j. - Используйте online softmax с коррекцией.
- Не храните
Sцеликом.
- Выберите размер блока
- Проверьте, что выход
Oсовпадает со стандартным (с точностью до1e-5). - Замерьте время и память. Убедитесь, что память ~
O(n*d)(линейная), а неO(n^2).
Ожидаемый результат Вы увидите, что при n=4096 стандартный attention потребляет ~128 MB (матрица S), а FlashAttention — ~2 MB. Время может быть немного больше из-за циклов на Python, но на GPU разница будет в пользу FlashAttention.
Дополнительно Реализуйте backward с recomputation: повторите forward на обратном проходе, используя сохранённые m_i, l_i.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 665 | Как работает attention в Transformer (Q, K, V, softmax) |
| 666 | Проблема квадратичной сложности attention и пути решения |
| 668 | Что такое KV-cache и как он используется в инференсе |
| 669 | Как устроен Multi-Query Attention (MQA) и Grouped-Query Attention (GQA) |
| 670 | Как работает Sliding Window Attention |
| 675 | Как оптимизировать RAG для работы с длинными документами (long context) |
Навигация
- Предыдущий: 666
- Следующий: 668
- Индекс: 00. Индекс разборов