Что такое FlashAttention-3 и какие improvements он принес по сравнению с FA2?
Краткий тезис
FlashAttention-3 (FA3) — это третье поколение алгоритма точного attention с линейной по длине последовательности памятью, оптимизированное для архитектуры Hopper (NVIDIA H100/H800). Главные улучшения по сравнению с FlashAttention-2 (FA2): асинхронное выполнение с использованием инструкций WGMMA, улучшенная партицирование для длинных последовательностей и поддержка вычислений в FP8, что даёт прирост скорости до 2–3 раз на H100.
1. Что такое FlashAttention и зачем он нужен
FlashAttention — это алгоритм точного вычисления механизма attention (softmax(QK^T)V) без материализации полной матрицы QK^T в памяти GPU. Вместо этого attention вычисляется поблочно (tiling) с пересчётом softmax на лету. Это позволяет:
- Снизить сложность по памяти с O(N²) до O(N) (N — длина последовательности).
- Ускорить вычисления за счёт эффективного использования быстрой памяти (SRAM) GPU.
- Обучать и инференсить модели с контекстом до 128K+ токенов без переполнения VRAM.
Первая версия (FA1, 2022) показала ускорение в 2–3 раза. Вторая (FA2, 2023) улучшила параллелизм и поддержку разных типов attention (causal, masked). FA3 (2025) — следующий шаг, заточенный под новое поколение GPU.
2. FlashAttention-2: достижения и ограничения
FA2 (Tri Dao et al., 2023) принёс:
- Лучший параллелизм по длине последовательности и по головам (head dimension).
- Поддержка causal masking без дополнительных затрат.
- Оптимизация для A100 (Ampere) с использованием тензорных ядер (Tensor Cores) и инструкций HMMA.
- Скорость до 2× быстрее FA1.
Ограничения FA2:
- Не полностью использует возможности Hopper (H100): инструкции WGMMA, асинхронное копирование, тензорные ядра четвёртого поколения.
- Нет поддержки FP8 (только FP16/BF16).
- Партицирование (разбиение на блоки) не оптимально для очень длинных последовательностей (>32K токенов).
3. FlashAttention-3: ключевые improvements
FA3 (2025, Tri Dao, Jay Shah и др.) решает эти ограничения. Основные нововведения:
- Асинхронное выполнение (asynchronous execution) с использованием инструкций WGMMA.
- Улучшенное партицирование (improved partitioning) для long sequences.
- Поддержка FP8 (8-bit floating point).
- Оптимизация под архитектуру Hopper (H100/H800).
Рассмотрим каждое подробнее.
4. Асинхронное выполнение на Hopper (WGMMA)
WGMMA (Warp Group Matrix Multiply-Accumulate) — новая инструкция в архитектуре Hopper, позволяющая выполнять матричное умножение и накопление на уровне warp-групп (4 warp = 128 потоков) с асинхронным переносом данных.
В FA2 все операции были синхронными: загрузка блока из HBM в SRAM, вычисление, запись результата. FA3 использует асинхронное копирование (async copy) и асинхронное выполнение WGMMA, что позволяет:
- Перекрывать загрузку данных с вычислениями (overlap).
- Увеличить утилизацию тензорных ядер.
- Снизить latency за счёт конвейерной обработки.
Пример псевдокода (сравнение FA2 и FA3):
# FA2 (синхронный подход)
for block in blocks:
load_block_to_sram(block) # синхронно
compute_attention_on_block(block) # синхронно
write_result_to_hbm(block) # синхронно
# FA3 (асинхронный подход)
async_load_next_block() # запускаем загрузку следующего блока
for block in blocks:
wait_async_load_complete() # ждём завершения загрузки текущего
compute_attention_on_block(block) # вычисляем
async_load_next_block() # начинаем загрузку следующего (overlap)
write_result_to_hbm(block) # запись может быть асинхронной
Это даёт прирост скорости ~1.5–2× на H100.
5. Улучшенное партицирование для длинных последовательностей
FA2 разбивал последовательность на блоки фиксированного размера (например, 128 токенов). Для длинных последовательностей (>32K) это приводило к большому числу блоков и неэффективному использованию SRAM.
FA3 вводит адаптивное партицирование (adaptive partitioning):
- Размер блока динамически выбирается в зависимости от длины последовательности и доступной SRAM.
- Для очень длинных последовательностей используется двухуровневое партицирование: сначала разбиваем на крупные сегменты (segment), внутри — на блоки.
- Это уменьшает количество обращений к HBM и улучшает локальность данных.
Результат FA3 может эффективно обрабатывать последовательности до 256K токенов на одном H100 (с ограничением по памяти), тогда как FA2 на 128K уже сильно деградировал.
6. Поддержка FP8
FP8 (8-bit floating point) — формат чисел с плавающей точкой, поддерживаемый тензорными ядрами Hopper (H100). Существует два варианта: E4M3 (4 бита экспонента, 3 бита мантисса) и E5M2 (5 бит экспонента, 2 бита мантисса). FP8 позволяет:
- Удвоить пропускную способность тензорных ядер по сравнению с FP16 (2× больше операций в секунду).
- Снизить объём памяти для хранения весов и активаций.
FA3 поддерживает FP8 для всех операций attention: Q, K, V могут быть в FP8, а softmax и накопление — в FP16/BF16 для точности. Это даёт дополнительное ускорение ~1.5–2× на H100 при незначительной потере качества (если используется правильное масштабирование).
Важно FA3 не просто использует FP8, а интегрирует его с асинхронным выполнением и адаптивным партицированием, чтобы минимизировать overhead от преобразования форматов.
7. Сравнительная таблица FA2 vs FA3
| Характеристика | FlashAttention-2 | FlashAttention-3 |
|---|---|---|
| Архитектура GPU | Ampere (A100), Turing (V100) | Hopper (H100/H800) |
| Инструкции | HMMA (синхронные) | WGMMA (асинхронные) |
| Партицирование | Фиксированные блоки | Адаптивное, двухуровневое |
| Поддержка FP8 | Нет | Да (E4M3/E5M2) |
| Максимальная длина | ~64K (эффективно) | ~256K (эффективно) |
| Ускорение на H100 | 1× (базовый) | 2–3× (в зависимости от длины) |
| Потребление памяти | O(N) | O(N) (но меньше константа) |
| Causal masking | Поддерживается | Поддерживается, оптимизирован |
8. Влияние на производительность
На H100 FA3 показывает следующие результаты (по данным авторов):
- Скорость forward + backward для последовательности 8K: ~2.5× быстрее FA2.
- Скорость forward для последовательности 64K: ~3× быстрее FA2.
- Память: при длине 128K FA3 использует на 30% меньше VRAM, чем FA2 (за счёт лучшего партицирования).
Формула прироста (приблизительно):
Speedup ≈ (FA2_time / FA3_time) = 1.5 (async) × 1.5 (FP8) × 1.2 (partitioning) ≈ 2.7×
9. Применение в LLM
FA3 особенно полезен для:
- Тренировки моделей с длинным контекстом (GPT-4, Llama 3, Gemini) — ускорение в 2–3 раза сокращает время обучения.
- Inference с длинными последовательностями (например, анализ документов, код-генерация) — снижает latency и позволяет обслуживать больше запросов.
- Fine-tuning LoRA/QLoRA — FA3 работает с любыми вариантами attention, включая GQA (Grouped Query Attention) и MQA (Multi-Query Attention).
Ограничение FA3 требует GPU архитектуры Hopper (H100/H800). На A100 или V70 он не работает (падает на CPU fallback или использует FA2).
10. Пет-проект для закрепления
Задача Сравнить производительность FA2 и FA3 на синтетических данных.
Инструменты Python, PyTorch 2.x, NVIDIA H100 (можно через облако Lambda Labs или RunPod), библиотека flash-attn (версия 2.6+ для FA3).
Шаги:
- Установить
pip install flash-attn --no-build-isolation. - Создать случайные тензоры Q, K, V размера (batch=1, heads=32, seq_len=8192, head_dim=128) в FP16.
- Замерить время forward для FA2 (используя
flash_attn_funcсcausal=True). - Замерить время forward для FA3 (тот же вызов, но на H100 он автоматически использует FA3).
- Повторить для seq_len = 16384, 32768, 65536.
- Построить график зависимости времени от длины последовательности.
Ожидаемый результат На H100 FA3 будет в 2–3 раза быстрее FA2, особенно на длинных последовательностях. Потребление памяти будет ниже для FA3.
11. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 203 | FlashAttention-2: детали реализации |
| 205 | PagedAttention: управление памятью в LLM inference |
| 206 | Виды attention (full, sparse, linear) |
| 202 | LongLoRA: efficient fine-tuning с длинным контекстом |
| 201 | Sparse Attention: разреженные механизмы |
12. Навигация
- Предыдущий: 203
- Следующий: 205
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 203
- Следующий: 205
- Индекс: 00. Индекс разборов