English translation is not available yet. Showing Russian content.

Как работает FlashAttention-3 математически?

Краткий тезис

FlashAttention-3 (FA3, 2025) — это алгоритм точного внимания (attention) для GPU архитектуры Hopper (H100 и новее), который использует специализированные аппаратные блоки: WGMMA (Warp Group Matrix Multiply-Accumulate) для быстрых матричных умножений размером 64×64×16 за одну инструкцию и TMA (Tensor Memory Accelerator) для асинхронного копирования данных между глобальной и разделяемой памятью. Благодаря полному перекрытию вычислений и передачи данных, а также поддержке FP8, FA3 достигает 2× ускорения относительно Attention 2|FlashAttention-2 и 4× относительно стандартного внимания, сохраняя линейную сложность по памяти O(n).

1. Введение: почему нужен FlashAttention?

Стандартное (vanilla) внимание (Vaswani et al., 2017) вычисляет

Softmax(Q K^T / √d) V, где Q, K, V ∈ R^{n×d}.
Сложность O(n² d) по времени и O(n²) по памяти (из-за матрицы S = Q K^T). Для последовательностей длиной 64K токенов это десятки гигабайт – непрактично.

FlashAttention (Dao et al., 2022) решает проблему, разбивая вход на блоки и выполняя вычисления «на лету» без материализации полной матрицы S. Вместо этого каждое softmax применяется к блоку, а результат накапливается с коррекцией. Это снижает память до O(n) и ускоряет работу за счёт эффективного использования быстрой shared memory GPU.

FlashAttention-2 (2023) улучшил параллелизм, уменьшил количество синхронизаций и оптимизировал под Ampere архитектуру (A100). FlashAttention-3 (2025) — следующий шаг, заточенный под Hopper (H100/H200).

2. Ключевые аппаратные возможности Hopper, используемые FA3

КомпонентНазначениеХарактеристики
WGMMA (Warp Group MMA)Выполнение матричного умножения-накопления размером 64×64×16 за одну инструкциюВысокая пропускная способность, без участия скалярных регистров
TMA (Tensor Memory Accelerator)Асинхронное копирование блоков (2D/3D) из глобальной памяти в shared memoryАппаратный DMA-движок, освобождает warp для вычислений
FP8 тензорные ядраВыполнение умножений и накоплений в 8-битной плавающей точке (E4M3 / E5M2)Удвоение пиковой производительности по сравнению с FP16/BF16
Asynchronous transaction barriersСинхронизация асинхронных копий без ожидания всех потоковПозволяет перекрывать копирование и вычисления

Эти механизмы позволяют FA3 одновременно загружать данные из HBM в shared memory (TMA), выполнять матричное умножение (WGMMA) и накапливать результат — почти полностью скрывая задержку памяти.

3. Математическая основа: tiling + online softmax

FlashAttention любого поколения использует tiling (разбиение на блоки) и online softmax (SafeSoftmax). Рассмотрим на примере одного блока запросов Q_i и блоков ключей K_j, значений V_j.

  1. Инициализация
    Для блока Q_i размером B_r × d:
    O_i = 0, m_i = -∞ (максимум по строкам), ℓ_i = 0 (сумма экспонент).

  2. Для каждого блока K_j, V_j (размер B_c × d):

    • Загружаем K_j, V_j в shared memory (через TMA).
    • Считаем блок матрицы S_ij = Q_i K_j^T / √d (WGMMA).
    • Обновляем m_i_new = max(m_i, rowmax(S_ij)).
    • Корректируем старые веса: ℓ_i = ℓ_i * exp(m_i - m_i_new) + rowsum(exp(S_ij - m_i_new)).
    • Обновляем вывод:
      O_i = O_i * exp(m_i - m_i_new) + exp(S_ij - m_i_new) · V_j (WGMMA).
    • Устанавливаем m_i = m_i_new.
  3. После всех блоков
    Нормализуем O_i = O_i / ℓ_i.

Ключевое отличие от FA2: в FA3 вычисление S_ij и обновление O_i выполняется через WGMMA, а копирование K_j, V_j из глобальной памяти — асинхронно через TMA, так что к моменту завершения умножения следующий блок уже загружен.

4. Работа WGMMA (Warp Group MMA)

Warp Group — это группа из 4 варпов (всего 128 потоков) на H100. WGMMA — инструкция (например, wgmma.mma_async), которая выполняет:

  • Вход: два фрагмента (например, кусок Q_i размером 64×d и кусок K_j размером d×64), хранящиеся в фрагментных регистрах (RF).
  • Выход: аккумулированный результат 64×64 (часть матрицы S или накопление O_i).
  • За один такт выполняется умножение 64×64×16 (16 — размер накопления), т.е. 64·64·16 = 65 536 FMAD операций.

FA3 разбивает вычисления так, чтобы WGMMA обрабатывал блоки 64×64×16, что идеально ложится на архитектуру. Для остатков (некратных 64) используется маскировка.

5. Роль TMA (Tensor Memory Accelerator)

TMA — аппаратный блок, который может копировать многомерные фрагменты (2D-срез) из глобальной памяти в shared memory, не занимая инструкций warp'ов.

  • TMA программируется через дескриптор (указатель, shape, strides).
  • Запуск копирования — одна инструкция cp.async.bulk.
  • После запуска warp продолжает вычисления.
  • Синхронизация — через барьер: можно дождаться, пока нужный блок будет готов.

В FA3 TMA используется для загрузки K_j и V_j с перекрытием: пока WGMMA считает текущий блок S_ij, TMA уже загружает следующий блок K_{j+1}, V_{j+1}. Благодаря двойной буферизации в shared memory (ping-pong) копирование полностью скрывается.

6. FP8 поддержка и квантизация ошибок

Hopper поддерживает два формата FP8:

  • E4M3 (4 бита экспонента, 3 мантисса) — для K и V (поскольку они обычно имеют меньший диапазон).
  • E5M2 (5 экспонента, 2 мантисса) — для Q и накоплений S (больший диапазон).

Для корректного softmax FA3 автоматически масштабирует Q и K перед умножением, чтобы избежать переполнения.
Q_scaled = Q * (1/√d), K_scaled = K.

При накоплении O_i используется FP16/BF32 (внутренняя точность тензорных ядер), а финальный результат можно сохранить в FP16/FP32. Экспериментально показано, что для большинства моделей (BERT, Llama) точность не падает, а производительность удваивается относительно FA2 в BF16.

7. Псевдокод ядра FA3 (упрощённый)

# Псевдокод для одного блока Q (B_r=64)
for j in range(0, N, B_c):
    # Асинхронно загружаем K_j, V_j через TMA (двойной буфер)
    tma_copy(K_block[j], shared_mem_ping)
    tma_copy(V_block[j], shared_mem_pong)
    
    wait_for_previous_copy()   # ждём, если нужно
    
    # WGMMA: S_j = Q @ K_j^T (оба 64x64)
    wgmma(S_j, Q_frag, K_frag)
    S_j = S_j / sqrt(d)
    
    # Online softmax + накопление O
    m_new = max(m, rowmax(S_j))
    l = l * exp(m - m_new) + rowsum(exp(S_j - m_new))
    O = O * exp(m - m_new) + wgmma(exp(S_j - m_new), V_j)
    m = m_new

8. Сложность и производительность

МетрикаStandardFA2FA3 (FP16)FA3 (FP8)
Время для 8K seq., 128 dim (H100)~25 ms~2.8 ms~1.4 ms~0.7 ms
Память (для Q, K, V)O(n²)O(n)O(n)O(n)
Использование HBM bandwidth~20%~60%~80%~90%

FA3 достигает ~300 TFLOPS для FP8 на H100, что близко к пику (989 TFLOPS) с учётом overhead softmax и нормализации.

9. Применение в Agentic RAG

Agentic RAG — системы, где агент многократно обращается к LLM с длинными контекстами (результаты поиска, история действий). FlashAttention-3 критичен, так как:

  • Длинные контексты (32K+ токенов) — без FA3 внимание становится узким местом.
  • Агенты с памятью — каждый шаг добавляет токены, и FA3 позволяет держать низкую задержку.
  • Параллельные вызовы — на H100 можно запустить несколько голов внимания в одном потоке, используя TMA для предзагрузки ключей разных вызовов.

Например, в RAG-агенте с 8 вызовами по 16K контекста каждый, FA3 может сократить latency на 60% по сравнению с FA2.

Пет-проект для закрепления

Задача Реализовать упрощённый симулятор FA3 на CPU (без аппаратных инструкций) для понимания tiling и online softmax.

Инструменты Python + NumPy (для матричных операций), Numba (для имитации блоков).

Шаги:

  1. Сгенерируйте случайные Q, K, V (B=1, N=4096, d=128).
  2. Реализуйте vanilla attention как эталон.
  3. Реализуйте FA2-стиль (tiling 64x64, online softmax).
  4. Добавьте «асинхронное копирование» — используйте потоки Python для имитации перекрытия: в одном потоке считайте S, в другом — загружайте следующий блок.
  5. Замерьте время (для симуляции используйте time.sleep(0.01) как latency копирования).
  6. Сравните с vanilla — убедитесь, что результат совпадает (с плавающей точкой).

Ожидаемый результат — скрипт, который воспроизводит логику FA3 и демонстрирует, почему перекрытие снижает суммарное время.

Связь с другими вопросами

ВопросТема
204FlashAttention-2: как работает мягкое внимание с памятью O(n)
843Архитектура Hopper: отличия от Ampere
845FP8 обучение и inference
720Что такое тензорные ядра и как они ускоряют матричные умножения
735Оптимизация внимания для длинных контекстов (Sparse Attention, Linear Attention)
836Профилирование LLM на GPU с помощью Nsight

Навигация