中文翻译暂不可用,显示俄语原文。
Как работает FlashAttention-3 математически?
Краткий тезис
FlashAttention-3 (FA3, 2025) — это алгоритм точного внимания (attention) для GPU архитектуры Hopper (H100 и новее), который использует специализированные аппаратные блоки: WGMMA (Warp Group Matrix Multiply-Accumulate) для быстрых матричных умножений размером 64×64×16 за одну инструкцию и TMA (Tensor Memory Accelerator) для асинхронного копирования данных между глобальной и разделяемой памятью. Благодаря полному перекрытию вычислений и передачи данных, а также поддержке FP8, FA3 достигает 2× ускорения относительно Attention 2|FlashAttention-2 и 4× относительно стандартного внимания, сохраняя линейную сложность по памяти O(n).
1. Введение: почему нужен FlashAttention?
Стандартное (vanilla) внимание (Vaswani et al., 2017) вычисляет
Softmax(Q K^T / √d) V, где Q, K, V ∈ R^{n×d}.
Сложность O(n² d) по времени и O(n²) по памяти (из-за матрицы S = Q K^T). Для последовательностей длиной 64K токенов это десятки гигабайт – непрактично.
FlashAttention (Dao et al., 2022) решает проблему, разбивая вход на блоки и выполняя вычисления «на лету» без материализации полной матрицы S. Вместо этого каждое softmax применяется к блоку, а результат накапливается с коррекцией. Это снижает память до O(n) и ускоряет работу за счёт эффективного использования быстрой shared memory GPU.
FlashAttention-2 (2023) улучшил параллелизм, уменьшил количество синхронизаций и оптимизировал под Ampere архитектуру (A100). FlashAttention-3 (2025) — следующий шаг, заточенный под Hopper (H100/H200).
2. Ключевые аппаратные возможности Hopper, используемые FA3
| Компонент | Назначение | Характеристики |
|---|---|---|
| WGMMA (Warp Group MMA) | Выполнение матричного умножения-накопления размером 64×64×16 за одну инструкцию | Высокая пропускная способность, без участия скалярных регистров |
| TMA (Tensor Memory Accelerator) | Асинхронное копирование блоков (2D/3D) из глобальной памяти в shared memory | Аппаратный DMA-движок, освобождает warp для вычислений |
| FP8 тензорные ядра | Выполнение умножений и накоплений в 8-битной плавающей точке (E4M3 / E5M2) | Удвоение пиковой производительности по сравнению с FP16/BF16 |
| Asynchronous transaction barriers | Синхронизация асинхронных копий без ожидания всех потоков | Позволяет перекрывать копирование и вычисления |
Эти механизмы позволяют FA3 одновременно загружать данные из HBM в shared memory (TMA), выполнять матричное умножение (WGMMA) и накапливать результат — почти полностью скрывая задержку памяти.
3. Математическая основа: tiling + online softmax
FlashAttention любого поколения использует tiling (разбиение на блоки) и online softmax (SafeSoftmax). Рассмотрим на примере одного блока запросов Q_i и блоков ключей K_j, значений V_j.
-
Инициализация
Для блока Q_i размером B_r × d:
O_i = 0,m_i = -∞(максимум по строкам),ℓ_i = 0(сумма экспонент). -
Для каждого блока K_j, V_j (размер B_c × d):
- Загружаем K_j, V_j в shared memory (через TMA).
- Считаем блок матрицы S_ij = Q_i K_j^T / √d (WGMMA).
- Обновляем m_i_new = max(m_i, rowmax(S_ij)).
- Корректируем старые веса:
ℓ_i = ℓ_i * exp(m_i - m_i_new) + rowsum(exp(S_ij - m_i_new)). - Обновляем вывод:
O_i = O_i * exp(m_i - m_i_new) + exp(S_ij - m_i_new) · V_j(WGMMA). - Устанавливаем
m_i = m_i_new.
-
После всех блоков
НормализуемO_i = O_i / ℓ_i.
Ключевое отличие от FA2: в FA3 вычисление S_ij и обновление O_i выполняется через WGMMA, а копирование K_j, V_j из глобальной памяти — асинхронно через TMA, так что к моменту завершения умножения следующий блок уже загружен.
4. Работа WGMMA (Warp Group MMA)
Warp Group — это группа из 4 варпов (всего 128 потоков) на H100. WGMMA — инструкция (например, wgmma.mma_async), которая выполняет:
- Вход: два фрагмента (например, кусок Q_i размером 64×d и кусок K_j размером d×64), хранящиеся в фрагментных регистрах (RF).
- Выход: аккумулированный результат 64×64 (часть матрицы S или накопление O_i).
- За один такт выполняется умножение 64×64×16 (16 — размер накопления), т.е. 64·64·16 = 65 536 FMAD операций.
FA3 разбивает вычисления так, чтобы WGMMA обрабатывал блоки 64×64×16, что идеально ложится на архитектуру. Для остатков (некратных 64) используется маскировка.
5. Роль TMA (Tensor Memory Accelerator)
TMA — аппаратный блок, который может копировать многомерные фрагменты (2D-срез) из глобальной памяти в shared memory, не занимая инструкций warp'ов.
- TMA программируется через дескриптор (указатель, shape, strides).
- Запуск копирования — одна инструкция cp.async.bulk.
- После запуска warp продолжает вычисления.
- Синхронизация — через барьер: можно дождаться, пока нужный блок будет готов.
В FA3 TMA используется для загрузки K_j и V_j с перекрытием: пока WGMMA считает текущий блок S_ij, TMA уже загружает следующий блок K_{j+1}, V_{j+1}. Благодаря двойной буферизации в shared memory (ping-pong) копирование полностью скрывается.
6. FP8 поддержка и квантизация ошибок
Hopper поддерживает два формата FP8:
- E4M3 (4 бита экспонента, 3 мантисса) — для K и V (поскольку они обычно имеют меньший диапазон).
- E5M2 (5 экспонента, 2 мантисса) — для Q и накоплений S (больший диапазон).
Для корректного softmax FA3 автоматически масштабирует Q и K перед умножением, чтобы избежать переполнения.
Q_scaled = Q * (1/√d), K_scaled = K.
При накоплении O_i используется FP16/BF32 (внутренняя точность тензорных ядер), а финальный результат можно сохранить в FP16/FP32. Экспериментально показано, что для большинства моделей (BERT, Llama) точность не падает, а производительность удваивается относительно FA2 в BF16.
7. Псевдокод ядра FA3 (упрощённый)
# Псевдокод для одного блока Q (B_r=64)
for j in range(0, N, B_c):
# Асинхронно загружаем K_j, V_j через TMA (двойной буфер)
tma_copy(K_block[j], shared_mem_ping)
tma_copy(V_block[j], shared_mem_pong)
wait_for_previous_copy() # ждём, если нужно
# WGMMA: S_j = Q @ K_j^T (оба 64x64)
wgmma(S_j, Q_frag, K_frag)
S_j = S_j / sqrt(d)
# Online softmax + накопление O
m_new = max(m, rowmax(S_j))
l = l * exp(m - m_new) + rowsum(exp(S_j - m_new))
O = O * exp(m - m_new) + wgmma(exp(S_j - m_new), V_j)
m = m_new
8. Сложность и производительность
| Метрика | Standard | FA2 | FA3 (FP16) | FA3 (FP8) |
|---|---|---|---|---|
| Время для 8K seq., 128 dim (H100) | ~25 ms | ~2.8 ms | ~1.4 ms | ~0.7 ms |
| Память (для Q, K, V) | O(n²) | O(n) | O(n) | O(n) |
| Использование HBM bandwidth | ~20% | ~60% | ~80% | ~90% |
FA3 достигает ~300 TFLOPS для FP8 на H100, что близко к пику (989 TFLOPS) с учётом overhead softmax и нормализации.
9. Применение в Agentic RAG
Agentic RAG — системы, где агент многократно обращается к LLM с длинными контекстами (результаты поиска, история действий). FlashAttention-3 критичен, так как:
- Длинные контексты (32K+ токенов) — без FA3 внимание становится узким местом.
- Агенты с памятью — каждый шаг добавляет токены, и FA3 позволяет держать низкую задержку.
- Параллельные вызовы — на H100 можно запустить несколько голов внимания в одном потоке, используя TMA для предзагрузки ключей разных вызовов.
Например, в RAG-агенте с 8 вызовами по 16K контекста каждый, FA3 может сократить latency на 60% по сравнению с FA2.
Пет-проект для закрепления
Задача Реализовать упрощённый симулятор FA3 на CPU (без аппаратных инструкций) для понимания tiling и online softmax.
Инструменты Python + NumPy (для матричных операций), Numba (для имитации блоков).
Шаги:
- Сгенерируйте случайные Q, K, V (B=1, N=4096, d=128).
- Реализуйте vanilla attention как эталон.
- Реализуйте FA2-стиль (tiling 64x64, online softmax).
- Добавьте «асинхронное копирование» — используйте потоки Python для имитации перекрытия: в одном потоке считайте S, в другом — загружайте следующий блок.
- Замерьте время (для симуляции используйте
time.sleep(0.01)как latency копирования). - Сравните с vanilla — убедитесь, что результат совпадает (с плавающей точкой).
Ожидаемый результат — скрипт, который воспроизводит логику FA3 и демонстрирует, почему перекрытие снижает суммарное время.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 204 | FlashAttention-2: как работает мягкое внимание с памятью O(n) |
| 843 | Архитектура Hopper: отличия от Ampere |
| 845 | FP8 обучение и inference |
| 720 | Что такое тензорные ядра и как они ускоряют матричные умножения |
| 735 | Оптимизация внимания для длинных контекстов (Sparse Attention, Linear Attention) |
| 836 | Профилирование LLM на GPU с помощью Nsight |
Навигация
- Предыдущий: 843
- Следующий: 845
- Индекс: 00. Индекс разборов