Что такое FlashAttention с точки зрения CUDA programming?
Краткий тезис
FlashAttention — это IO-aware алгоритм вычисления attention, который переворачивает традиционный подход: вместо материализации большой матрицы S (scores) в медленной HBM (High Bandwidth Memory), он разбивает вычисления на блоки (tiling) и выполняет их в быстрой SRAM (shared memory) GPU. С точки зрения CUDA programming, это означает ручное управление памятью, kernel fusion, recomputationградиентов и оптимизациюoccupancy для максимальной пропускной способности. FlashAttention позволяет обрабатывать последовательности длиной до 128K токенов на одном GPU, что критически важно для современных LLM и RAG-систем с длинным контекстом.
1. Проблема стандартного attention: узкое место — HBM
Стандартное scaled dot-product attention (формула: Attention(Q,K,V) = softmax(QK^T/√d)V) требует:
- Вычислить матрицу S размера
N×N(гдеN— длина последовательности). - Сохранить её в HBM (глобальная память GPU).
- Прочитать её обратно для softmax и умножения на V.
Проблема HBM имеет высокую пропускную способность (~1.5 TB/s на A100), но latency и bandwidth всё равно на порядок меньше, чем у SRAM (~19 TB/s). Для N=64K матрица S занимает 16 GB (в float16) — это больше, чем вся память многих GPU. Даже для меньших N чтение/запись S доминирует во времени.
Термин HBM (High Bandwidth Memory) — основная память GPU, большая (до 80 GB), но относительно медленная. SRAM (shared memory) — кэш на чипе, маленький (обычно 48–192 KB на SM), но очень быстрый.
2. Основная идея FlashAttention: не материализовать S
FlashAttention (Dao et al., 2022) предлагает не записывать S в HBM. Вместо этого:
- Разбить Q, K, V на блоки (tiles), которые помещаются в SRAM.
- Для каждого блока вычислить частичное внимание, обновляя выход O на лету.
- Использовать online softmax (алгоритм, который вычисляет softmax по частям без хранения всех scores).
Ключевой принцип IO-aware — алгоритм оптимизирует не количество FLOPs, а количество обращений к HBM. FlashAttention делает в 2–4 раза меньше чтений/записей в HBM, чем стандартная реализация.
3. Tiling: разбиение на блоки
Tiling — техника разбиения данных на небольшие фрагменты, которые помещаются в SRAM.
Как это работает
- Пусть размер блока Br (для Q) и Bc (для K,V) выбираются так, чтобы
Br * d + Bc * dумещалось в SRAM (например, 64×64 при d=128). - Загружаем один блок Q и один блок K в SRAM.
- Вычисляем частичную матрицу scores
S_block = Q_block * K_block^T(размерBr × Bc). - Применяем softmax к этому блоку, используя online softmax (храним running max и сумму экспонент).
- Умножаем результат на соответствующий блок V и накапливаем в выходной блок O (который тоже хранится в SRAM и периодически сбрасывается в HBM).
Псевдокод (упрощённо):
for q_block in range(0, N, Br):
load Q_block to SRAM
O_block = zeros(Br, d) # в SRAM
for kv_block in range(0, N, Bc):
load K_block, V_block to SRAM
S = Q_block @ K_block.T # Br x Bc
# online softmax
m_prev = row_max(O_block) # на самом деле храним отдельно
l_prev = row_sum_exp(O_block)
# обновляем m, l, O
m_new = max(m_prev, row_max(S))
l_new = exp(m_prev - m_new) * l_prev + row_sum_exp(S - m_new)
P = exp(S - m_new) # softmax scores
O_block = (l_prev / l_new) * exp(m_prev - m_new) * O_block + (1/l_new) * P @ V_block
write O_block to HBM
Термин Online softmax — алгоритм, который вычисляет softmax по частям, обновляя максимум и сумму экспонент, не требуя хранения всех scores.
4. Recomputation: не хранить S для backward pass
В стандартном attention для обратного распространения ошибки нужны матрицы P (softmax output) и S (scores). FlashAttention не сохраняет их в HBM, а пересчитывает (recompute) во время backward pass.
Как это делается
- Во время forward сохраняются только статистики (row max и row sum) для каждого блока — это
O(N)памяти вместоO(N^2). - Во время backward блоки Q, K, V снова загружаются в SRAM, и по сохранённым статистикам восстанавливаются матрицы P и S (или их части) для вычисления градиентов.
Термин Recomputation — техника, когда промежуточные результаты не сохраняются, а пересчитываются заново. Это увеличивает FLOPs (на ~30%), но радикально снижает потребление памяти и обращения к HBM.
5. Kernel fusion: объединение операций
Kernel fusion — объединение нескольких последовательных операций (matmul, softmax, dropout, matmul) в один CUDA kernel. Это уменьшает количество запусков ядер и синхронизаций, а главное — позволяет держать данные в SRAM между операциями, не сбрасывая в HBM.
В стандартном подходе:
- Kernel 1:
S = Q @ K^T→ запись S в HBM - Kernel 2:
P = softmax(S)→ чтение S, запись P - Kernel 3:
O = P @ V→ чтение P, запись O
FlashAttention делает всё в одном kernel: загрузил блоки, вычислил, обновил O, записал только финальный O.
6. CUDA programming: occupancy и shared memory
Occupancy — отношение числа активных warps на SM к максимально возможному. Высокий occupancy помогает скрыть latency (задержки) при обращении к HBM.
FlashAttention использует shared memory для хранения блоков Q, K, V и промежуточных результатов. Размер shared memory на SM ограничен (например, 48 KB на SM в A100). Поэтому размер блока Br и Bc выбирается так, чтобы не превысить лимит, иначе occupancy упадёт.
Оптимизации:
- Использование async copy (например,
cp.asyncна Ampere) для загрузки данных из HBM в shared memory без блокировки. - Warp-level primitives (shuffle, reduce) для быстрого вычисления softmax внутри warp.
- Bank conflicts — при работе с shared memory нужно избегать конфликтов банков, выравнивая доступ.
Пример конфигурации для A100 (d=128, float16):
- Br = 64, Bc = 64 → Q_block: 64×128 = 8 KB, K_block: 64×128 = 8 KB, V_block: 64×128 = 8 KB, O_block: 64×128 = 8 KB, плюс статистики — всего ~36 KB, укладывается в 48 KB.
7. Сравнение: FlashAttention vs стандартная реализация
| Характеристика | Стандартный attention | FlashAttention |
|---|---|---|
| Память для S | O(N²) | O(N) (только статистики) |
| Обращения к HBM | O(N²) | O(N² / Br) — в разы меньше |
| FLOPs | O(N²d) | O(N²d) + overhead recompute |
| Скорость (N=4096, A100) | ~1.0x | ~2-3x быстрее |
| Макс. длина контекста | Ограничен памятью (N ~ 8K) | До 128K+ |
8. FlashAttention v2 и v3: эволюция
- FlashAttention v2 (2023): улучшенная организация циклов (параллелизм по головам), меньше регистров, поддержка multi-query attention.
- FlashAttention v3 (2024): использует Hopper архитектуру (H100), warp specialization (одни warps загружают данные, другие считают), asynchronous execution.
- FlashAttention-2 уже стала стандартом в PyTorch (через
torch.nn.functional.scaled_dot_product_attention).
9. Влияние на RAG и LLM
FlashAttention позволяет:
- Использовать длинные контексты (до 128K токенов) в RAG-системах без потери производительности.
- Обучать модели с long-context (например, 32K, 64K) на одном GPU.
- Ускорять inference в 2-3 раза, что критично для Agentic RAG, где агент может делать много шагов с attention.
Пет-проект для закрепления
Задача Реализовать упрощённую версию FlashAttention на Python с имитацией tiling и online softmax (без CUDA, но с пониманием алгоритма).
Инструменты Python, NumPy, PyTorch (для проверки).
Шаги:
- Напишите функцию
flash_attention_simulated(Q, K, V, block_size=64). - Разбейте Q на блоки по
block_size(по строкам), K и V — по столбцам. - Для каждого блока Q загрузите соответствующие блоки K, V.
- Реализуйте online softmax с хранением
mиl. - Обновите выходной блок O.
- Сравните результат с
torch.nn.functional.scaled_dot_product_attention(которая уже использует FlashAttention). - Измерьте количество операций чтения/записи (имитация) и сравните с наивной реализацией.
Ожидаемый результат Вы убедитесь, что результат совпадает с точностью до float16, а количество обращений к «HBM» (в вашей симуляции — к большому массиву) уменьшилось в N/block_size раз.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 302 | Как работает attention в transformer? |
| 303 | Что такое KV cache и как он оптимизирует inference? |
| 305 | Как устроена архитектура Mixture of Experts (MoE)? |
| 306 | Какие техники используются для обработки длинных контекстов? |
| 310 | Как работает speculative decoding? |
| 315 | Что такое PagedAttention и vLLM? |
Навигация
- Предыдущий: 303
- Следующий: 305
- Индекс: 00. Индекс разборов