English translation is not available yet. Showing Russian content.

Что такое FlashAttention с точки зрения CUDA programming?

Краткий тезис

FlashAttention — это IO-aware алгоритм вычисления attention, который переворачивает традиционный подход: вместо материализации большой матрицы S (scores) в медленной HBM (High Bandwidth Memory), он разбивает вычисления на блоки (tiling) и выполняет их в быстрой SRAM (shared memory) GPU. С точки зрения CUDA programming, это означает ручное управление памятью, kernel fusion, recomputationградиентов и оптимизациюoccupancy для максимальной пропускной способности. FlashAttention позволяет обрабатывать последовательности длиной до 128K токенов на одном GPU, что критически важно для современных LLM и RAG-систем с длинным контекстом.


1. Проблема стандартного attention: узкое место — HBM

Стандартное scaled dot-product attention (формула: Attention(Q,K,V) = softmax(QK^T/√d)V) требует:

  • Вычислить матрицу S размера N×N (где N — длина последовательности).
  • Сохранить её в HBM (глобальная память GPU).
  • Прочитать её обратно для softmax и умножения на V.

Проблема HBM имеет высокую пропускную способность (~1.5 TB/s на A100), но latency и bandwidth всё равно на порядок меньше, чем у SRAM (~19 TB/s). Для N=64K матрица S занимает 16 GB (в float16) — это больше, чем вся память многих GPU. Даже для меньших N чтение/запись S доминирует во времени.

Термин HBM (High Bandwidth Memory) — основная память GPU, большая (до 80 GB), но относительно медленная. SRAM (shared memory) — кэш на чипе, маленький (обычно 48–192 KB на SM), но очень быстрый.


2. Основная идея FlashAttention: не материализовать S

FlashAttention (Dao et al., 2022) предлагает не записывать S в HBM. Вместо этого:

  • Разбить Q, K, V на блоки (tiles), которые помещаются в SRAM.
  • Для каждого блока вычислить частичное внимание, обновляя выход O на лету.
  • Использовать online softmax (алгоритм, который вычисляет softmax по частям без хранения всех scores).

Ключевой принцип IO-aware — алгоритм оптимизирует не количество FLOPs, а количество обращений к HBM. FlashAttention делает в 2–4 раза меньше чтений/записей в HBM, чем стандартная реализация.


3. Tiling: разбиение на блоки

Tiling — техника разбиения данных на небольшие фрагменты, которые помещаются в SRAM.

Как это работает

  • Пусть размер блока Br (для Q) и Bc (для K,V) выбираются так, чтобы Br * d + Bc * d умещалось в SRAM (например, 64×64 при d=128).
  • Загружаем один блок Q и один блок K в SRAM.
  • Вычисляем частичную матрицу scores S_block = Q_block * K_block^T (размер Br × Bc).
  • Применяем softmax к этому блоку, используя online softmax (храним running max и сумму экспонент).
  • Умножаем результат на соответствующий блок V и накапливаем в выходной блок O (который тоже хранится в SRAM и периодически сбрасывается в HBM).

Псевдокод (упрощённо):

for q_block in range(0, N, Br):
    load Q_block to SRAM
    O_block = zeros(Br, d)  # в SRAM
    for kv_block in range(0, N, Bc):
        load K_block, V_block to SRAM
        S = Q_block @ K_block.T  # Br x Bc
        # online softmax
        m_prev = row_max(O_block)  # на самом деле храним отдельно
        l_prev = row_sum_exp(O_block)
        # обновляем m, l, O
        m_new = max(m_prev, row_max(S))
        l_new = exp(m_prev - m_new) * l_prev + row_sum_exp(S - m_new)
        P = exp(S - m_new)  # softmax scores
        O_block = (l_prev / l_new) * exp(m_prev - m_new) * O_block + (1/l_new) * P @ V_block
    write O_block to HBM

Термин Online softmax — алгоритм, который вычисляет softmax по частям, обновляя максимум и сумму экспонент, не требуя хранения всех scores.


4. Recomputation: не хранить S для backward pass

В стандартном attention для обратного распространения ошибки нужны матрицы P (softmax output) и S (scores). FlashAttention не сохраняет их в HBM, а пересчитывает (recompute) во время backward pass.

Как это делается

  • Во время forward сохраняются только статистики (row max и row sum) для каждого блока — это O(N) памяти вместо O(N^2).
  • Во время backward блоки Q, K, V снова загружаются в SRAM, и по сохранённым статистикам восстанавливаются матрицы P и S (или их части) для вычисления градиентов.

Термин Recomputation — техника, когда промежуточные результаты не сохраняются, а пересчитываются заново. Это увеличивает FLOPs (на ~30%), но радикально снижает потребление памяти и обращения к HBM.


5. Kernel fusion: объединение операций

Kernel fusion — объединение нескольких последовательных операций (matmul, softmax, dropout, matmul) в один CUDA kernel. Это уменьшает количество запусков ядер и синхронизаций, а главное — позволяет держать данные в SRAM между операциями, не сбрасывая в HBM.

В стандартном подходе:

  • Kernel 1: S = Q @ K^T → запись S в HBM
  • Kernel 2: P = softmax(S) → чтение S, запись P
  • Kernel 3: O = P @ V → чтение P, запись O

FlashAttention делает всё в одном kernel: загрузил блоки, вычислил, обновил O, записал только финальный O.


6. CUDA programming: occupancy и shared memory

Occupancy — отношение числа активных warps на SM к максимально возможному. Высокий occupancy помогает скрыть latency (задержки) при обращении к HBM.

FlashAttention использует shared memory для хранения блоков Q, K, V и промежуточных результатов. Размер shared memory на SM ограничен (например, 48 KB на SM в A100). Поэтому размер блока Br и Bc выбирается так, чтобы не превысить лимит, иначе occupancy упадёт.

Оптимизации:

  • Использование async copy (например, cp.async на Ampere) для загрузки данных из HBM в shared memory без блокировки.
  • Warp-level primitives (shuffle, reduce) для быстрого вычисления softmax внутри warp.
  • Bank conflicts — при работе с shared memory нужно избегать конфликтов банков, выравнивая доступ.

Пример конфигурации для A100 (d=128, float16):

  • Br = 64, Bc = 64 → Q_block: 64×128 = 8 KB, K_block: 64×128 = 8 KB, V_block: 64×128 = 8 KB, O_block: 64×128 = 8 KB, плюс статистики — всего ~36 KB, укладывается в 48 KB.

7. Сравнение: FlashAttention vs стандартная реализация

ХарактеристикаСтандартный attentionFlashAttention
Память для SO(N²)O(N) (только статистики)
Обращения к HBMO(N²)O(N² / Br) — в разы меньше
FLOPsO(N²d)O(N²d) + overhead recompute
Скорость (N=4096, A100)~1.0x~2-3x быстрее
Макс. длина контекстаОграничен памятью (N ~ 8K)До 128K+

8. FlashAttention v2 и v3: эволюция

  • FlashAttention v2 (2023): улучшенная организация циклов (параллелизм по головам), меньше регистров, поддержка multi-query attention.
  • FlashAttention v3 (2024): использует Hopper архитектуру (H100), warp specialization (одни warps загружают данные, другие считают), asynchronous execution.
  • FlashAttention-2 уже стала стандартом в PyTorch (через torch.nn.functional.scaled_dot_product_attention).

9. Влияние на RAG и LLM

FlashAttention позволяет:

  • Использовать длинные контексты (до 128K токенов) в RAG-системах без потери производительности.
  • Обучать модели с long-context (например, 32K, 64K) на одном GPU.
  • Ускорять inference в 2-3 раза, что критично для Agentic RAG, где агент может делать много шагов с attention.

Пет-проект для закрепления

Задача Реализовать упрощённую версию FlashAttention на Python с имитацией tiling и online softmax (без CUDA, но с пониманием алгоритма).

Инструменты Python, NumPy, PyTorch (для проверки).

Шаги:

  1. Напишите функцию flash_attention_simulated(Q, K, V, block_size=64).
  2. Разбейте Q на блоки по block_size (по строкам), K и V — по столбцам.
  3. Для каждого блока Q загрузите соответствующие блоки K, V.
  4. Реализуйте online softmax с хранением m и l.
  5. Обновите выходной блок O.
  6. Сравните результат с torch.nn.functional.scaled_dot_product_attention (которая уже использует FlashAttention).
  7. Измерьте количество операций чтения/записи (имитация) и сравните с наивной реализацией.

Ожидаемый результат Вы убедитесь, что результат совпадает с точностью до float16, а количество обращений к «HBM» (в вашей симуляции — к большому массиву) уменьшилось в N/block_size раз.


Связь с другими вопросами

ВопросТема
302Как работает attention в transformer?
303Что такое KV cache и как он оптимизирует inference?
305Как устроена архитектура Mixture of Experts (MoE)?
306Какие техники используются для обработки длинных контекстов?
310Как работает speculative decoding?
315Что такое PagedAttention и vLLM?

Навигация