English translation is not available yet. Showing Russian content.

Как работает FlashAttention математически (tiling, recomputation, не материализуя S)?

Краткий тезис

FlashAttention — это алгоритм точного вычисления attention, который не хранит матрицу S = QK^T размером n×n (где n — длина последовательности). Вместо этого он использует tiling (разбиение на блоки) и online softmax с сохранением статистики (max, sum) для коррекции результата. На обратном проходе (backward) применяется recomputation — матрица S пересчитывается заново, что экономит память. Итог: сложность по памяти снижается с O(n²) до O(n), а скорость вычислений увеличивается в 2–4 раза за счёт эффективного использования GPU-памяти.


1. Проблема: квадратичная память стандартного attention

Стандартное внимание (например, в Transformer) вычисляется по формуле:

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
  • Q, K, V — матрицы размером n × d (n — длина последовательности, d — размерность головы).
  • S = QK^T — матрица n × n.
  • P = softmax(S) — матрица n × n.
  • Итоговый выход O = P V — матрица n × d.

Проблема хранение S и P требует O(n²) памяти. Для n=128K (типичный long context) это ~16 GB в float16 — непозволительно для одного слоя. Кроме того, операция QK^T — это матричное умножение, которое загружает и выгружает большие объёмы данных из медленной глобальной памяти GPU.

Термин: материализация матрицы S — это явное создание и хранение матрицы S в памяти. FlashAttention избегает этого.


2. Основная идея: tiling (разбиение на блоки)

Вместо того чтобы вычислять S целиком, FlashAttention разбивает Q, K, V на блоки (tiles) фиксированного размера B (например, 128×128). Вычисления проводятся поблочно, и результат каждого блока сразу участвует в формировании выхода O.

Как это работает

  1. Разбиение по Q Q делится на блоки Q_i размером B_r × d.
  2. Разбиение по K, V: K и V делятся на блоки K_j, V_j размером B_c × d.
  3. Внешний цикл по j (блоки K, V), внутренний по i (блоки Q). Для каждой пары блоков (Q_i, K_j) вычисляется локальная матрица S_ij = Q_i K_j^T размером B_r × B_c.
  4. Softmax на лету для каждого блока Q_i softmax вычисляется инкрементально, накапливая статистику по всем j. Это требует специального алгоритма — online softmax.

Термин: Tiling (блочная обработка) — разбиение больших матриц на маленькие блоки, которые помещаются в быструю память GPU (shared memory). Это уменьшает количество обращений к глобальной памяти.


3. Online softmax: сохранение статистики (max, sum)

Стандартный softmax требует знания всех элементов строки S, чтобы вычислить max и sum для нормализации. При поблочной обработке мы не видим всю строку сразу. Решение — использовать online softmax, который обновляет статистику по мере поступления блоков.

Алгоритм для одной строки (для простоты):

Пусть для строки i мы уже обработали блоки j=1..m-1 и имеем:

  • m_prev — текущий максимум среди обработанных элементов.
  • l_prev — текущая сумма exp(S_ij - m_prev).

Приходит новый блок j=m с элементами s_1, ..., s_Bc.

  1. Вычисляем новый максимум: m_new = max(m_prev, max(s)).
  2. Корректируем старую сумму: l_prev = l_prev * exp(m_prev - m_new).
  3. Добавляем вклад нового блока: l_new = l_prev + sum(exp(s - m_new)).
  4. Обновляем максимум: m_prev = m_new, l_prev = l_new.

Для матрицы статистики хранятся для каждого блока Q_i отдельно (векторы m_i и l_i размером B_r). После обработки всех j получаем итоговые m_i и l_i для каждой строки i.

Термин: Online softmax — метод вычисления softmax без полной материализации матрицы S, используя инкрементальное обновление максимума и суммы экспонент.


4. Вычисление выхода O с коррекцией

Выход O также накапливается поблочно. Для каждого блока Q_i и K_j, V_j:

  1. Вычисляем S_ij = Q_i K_j^T.
  2. Применяем online softmax с учётом текущей статистики m_i, l_i.
  3. Получаем локальный P_ij = exp(S_ij - m_new) / l_new (но не храним, а сразу используем).
  4. Обновляем выход: O_i += P_ij V_j.

Коррекция так как статистика m_i, l_i меняется после каждого блока, ранее добавленный вклад O_i нужно скорректировать. Это делается умножением на exp(m_prev - m_new) / (l_new / l_prev). На практике алгоритм хранит O_i в неоткорректированном виде и применяет финальную нормализацию после всех блоков.

Итоговая формула для одного блока Q_i:

O_i = (l_prev / l_new) * exp(m_prev - m_new) * O_i_prev + (1 / l_new) * exp(S_ij - m_new) * V_j

Термин: Коррекция (rescaling) — поправка накопленного выхода при изменении глобального максимума и суммы.


5. Recomputation на обратном проходе (backward)

При обратном распространении ошибки стандартный attention требует градиенты по Q, K, V. Для этого нужны матрицы S и P. FlashAttention не хранит их, а пересчитывает заново (recomputation) на backward.

Как это работает

  • На forward сохраняются только статистики m_i, l_i и случайные dropout-маски (если используются) — это O(n) памяти.
  • На backward повторяется тот же поблочный цикл: для каждого блока (Q_i, K_j) заново вычисляется S_ij, затем P_ij с использованием сохранённых m_i, l_i, и на основе этого считаются градиенты.

Термин: Recomputation — повторное вычисление промежуточных результатов на обратном проходе вместо их хранения. Это trade-off: время вычислений увеличивается (примерно на 30%), но память экономится радикально.

Почему это выгодно для long context узким местом является память, а не вычисления. Recomputation позволяет обрабатывать последовательности в 10–100 раз длиннее при том же объёме GPU.


6. Математическая формализация алгоритма

Приведём псевдокод для одного блока Q_i (упрощённо):

# Инициализация
m_i = -inf  # вектор размера B_r
l_i = 0     # вектор размера B_r
O_i = 0     # матрица B_r x d

# Цикл по блокам K, V
for j in range(0, n, B_c):
    K_j = K[j:j+B_c, :]   # B_c x d
    V_j = V[j:j+B_c, :]   # B_c x d
    S_ij = Q_i @ K_j.T    # B_r x B_c

    # Online softmax
    m_new = max(m_i, rowmax(S_ij))   # поэлементно по строкам
    l_new = exp(m_i - m_new) * l_i + rowsum(exp(S_ij - m_new))

    # Коррекция и накопление O_i
    O_i = exp(m_i - m_new) * (l_i / l_new) * O_i + (1 / l_new) * (exp(S_ij - m_new) @ V_j)

    # Обновление статистики
    m_i = m_new
    l_i = l_new

# После цикла O_i уже нормализован (деление на l_i уже учтено)

Термин: rowmax, rowsum — поэлементные операции по строкам: максимум и сумма экспонент.


7. Сравнение со стандартным attention

ХарактеристикаСтандартный attentionFlashAttention
Память (forward)O(n²)O(n)
Память (backward)O(n²)O(n)
Время (forward)~2x быстрее (без overhead)~2-4x быстрее (за счёт эффективного использования shared memory)
Время (backward)~2x быстрее~1.3x медленнее (из-за recomputation)
ТочностьТочныйТочный (не приближение)
Поддержка dropoutДаДа (маска хранится)

Почему FlashAttention быстрее на forward Основное узкое место — чтение/запись в глобальную память. FlashAttention загружает блоки K, V в shared memory и переиспользует их для нескольких блоков Q. Это уменьшает количество обращений к HBM (High Bandwidth Memory) в ~2-4 раза.


8. Применение в Agentic RAG и long context

В архитектурах Agentic RAG агенты часто обрабатывают длинные контексты: историю диалога, множество документов, результаты вызовов инструментов. FlashAttention позволяет:

  • Использовать модели с контекстом до 128K+ токенов без увеличения памяти.
  • Ускорять инференс за счёт меньшего числа операций с памятью.
  • Обучать модели на длинных последовательностях (fine-tuning RAG-агентов).

Пример: RAG-агент, который читает 50 документов по 2000 токенов каждый (100K токенов). Без FlashAttention потребовалось бы ~10 GB только на матрицу S для одного слоя. С FlashAttention — ~80 MB.


9. Вариации и расширения

  • FlashAttention-2 улучшенная версия с более эффективным tiling (меньше синхронизаций, лучшее использование warp-level primitives). Скорость увеличена ещё на ~2x.
  • FlashAttention-3 (Hopper): использует новые инструкции NVIDIA Hopper (FP8, Tensor Core) для ещё большей производительности.
  • Block-sparse FlashAttention для разреженных паттернов внимания (например, sliding window, dilated attention).
  • FlashDecoding оптимизация для инференса (генерации), где K, V кэшируются, а Q — один токен.

Термин: Warp-level primitives — низкоуровневые операции на GPU, выполняемые группой потоков (warp). FlashAttention использует их для быстрых редукций (max, sum).


Пет-проект для закрепления

Задача Реализовать упрощённую версию FlashAttention на Python с NumPy, которая не материализует S, и сравнить её со стандартным attention по памяти и времени.

Инструменты Python, NumPy, timeit, memory_profiler (опционально).

Шаги:

  1. Сгенерируйте случайные Q, K, V размером n=1024, d=64.
  2. Реализуйте стандартный attention: S = Q @ K.T, P = softmax(S), O = P @ V. Замерьте время и пиковую память (например, через sys.getsizeof или memory_profiler).
  3. Реализуйте FlashAttention:
    • Выберите размер блока B=128.
    • Напишите цикл по блокам Q_i и K_j.
    • Используйте online softmax с коррекцией.
    • Не храните S целиком.
  4. Проверьте, что выход O совпадает со стандартным (с точностью до 1e-5).
  5. Замерьте время и память. Убедитесь, что память ~O(n*d) (линейная), а не O(n^2).

Ожидаемый результат Вы увидите, что при n=4096 стандартный attention потребляет ~128 MB (матрица S), а FlashAttention — ~2 MB. Время может быть немного больше из-за циклов на Python, но на GPU разница будет в пользу FlashAttention.

Дополнительно Реализуйте backward с recomputation: повторите forward на обратном проходе, используя сохранённые m_i, l_i.


Связь с другими вопросами

ВопросТема
665Как работает attention в Transformer (Q, K, V, softmax)
666Проблема квадратичной сложности attention и пути решения
668Что такое KV-cache и как он используется в инференсе
669Как устроен Multi-Query Attention (MQA) и Grouped-Query Attention (GQA)
670Как работает Sliding Window Attention
675Как оптимизировать RAG для работы с длинными документами (long context)

Навигация