中文翻译暂不可用,显示俄语原文。

Что такое memory-efficient attention для long context на 8x H100?

Краткий тезис

attention|Memory-efficient attention — это набор техник, позволяющих обрабатывать сверхдлинные последовательности (до 1M токенов и более) на ограниченном числе GPU (8x H100 с 80 GB HBM3 каждая). Основные подходы: FlashAttention-3 (IO-aware, O(n) память), PagedAttention (управление KV-кэшем с минимальной фрагментацией) и sequence parallelism (разрезание последовательности между GPU). Комбинация этих методов даёт возможность работать с контекстом в миллион токенов без выхода за пределы памяти.


1. Проблема: квадратичная память attention при long context

Стандартный механизм attention (Vaswani et al., 2017) требует хранения матрицы внимания размером O(n²) для последовательности длины n. При n = 1M это 1e12 элементов, что при 16-bit точности составляет ~2 TB — невозможно даже на 8x H100 (суммарно 640 GB HBM).

Ключевые ограничения

  • KV-кэш (ключи и значения всех предыдущих слоёв) также растёт линейно: O(n * d * num_layers). Для 1M токенов, 80 слоёв, d=8192 → ~2.6 TB.
  • Пропускная способность памяти (HBM bandwidth) становится узким местом: H100 HBM3 ~3.35 TB/s, но для чтения KV-кэша при каждом шаге генерации требуется O(n) данных.

Термин memory-efficient attention — совокупность алгоритмов и системных оптимизаций, которые снижают memory|пиковое потребление памяти и/или улучшают локальность данных, позволяя обрабатывать длинные контексты на доступном оборудовании.


2. FlashAttention-3: IO-aware attention с линейной памятью

FlashAttention (Dao et al., 2022) и его версия FlashAttention-3 (2024) — это точный алгоритм attention, который не хранит полную матрицу S = QK^T в HBM, а вычисляет её по блокам в SRAM (shared memory) GPU.

Принцип работы

  • Разбивает Q, K, V на блоки, которые помещаются в SRAM (обычно 192 KB на SM).
  • Вычисляет частичное softmax и накапливает результат в SRAM.
  • Записывает в HBM только итоговый O (output) — размер O(n*d).

Сложность по памяти O(n) (только входные и выходные тензоры). Для 1M токенов и d=128: Q, K, V ~ 1.5 GB (16-bit), O ~ 0.5 GB — всего ~2 GB на один слой attention.

IO-aware оптимизации

  • Минимизация чтений/записей HBM: каждый элемент Q, K, V читается один раз, результат записывается один раз.
  • Использование tiling и online softmax для численной стабильности.

FlashAttention-3 добавляет:

  • Поддержку FP8 (половинная точность) для ещё большей пропускной способности.
  • Async prefetching блоков KV-кэша.
  • Оптимизацию для Hopper (H100) — использование Tensor Memory Accelerator (TMA) и warp-group для параллельной загрузки.

Итог FlashAttention-3 снижает потребление памяти attention с O(n²) до O(n), что делает возможным обработку 1M токенов на одной H100 (при условии, что KV-кэш остальных слоёв также оптимизирован).


3. PagedAttention (vLLM): управление KV-кэшем без фрагментации

PagedAttention — техника, предложенная в системе vLLM (Kwon et al., 2023), для эффективного управления KV-кэшем при инференсе.

Проблема KV-кэш для каждого запроса имеет динамический размер (зависит от длины входного контекста и числа генерируемых токенов). Традиционные аллокаторы приводят к фрагментации памяти (до 60-80% потерь).

Решение

  • KV-кэш разбивается на блоки фиксированного размера (например, 16 токенов).
  • Блоки хранятся в непрерывной виртуальной памяти, но физически могут быть разбросаны.
  • Page table отображает логические блоки в физические страницы, как в операционных системах.
  • При генерации нового токена выделяется новый блок только при необходимости.

Memory fragmentation < 5% — за счёт того, что блоки маленькие и могут быть переиспользованы между разными запросами (copy-on-write).

Для long context PagedAttention позволяет держать в памяти KV-кэш для нескольких длинных контекстов одновременно, не резервируя гигантские непрерывные буферы.


4. Sequence parallelism: разрезание последовательности между GPU

Sequence parallelism — стратегия распределённых вычислений, при которой сама последовательность (токены) разрезается на части, и каждая часть обрабатывается на отдельном GPU.

Отличие от tensor parallelism

Как работает

  • Пусть у нас 8 GPU. Последовательность из 1M токенов делится на 8 сегментов по 125K токенов.
  • Каждый GPU вычисляет attention для своего сегмента, но для этого ему нужны ключи и значения всех остальных сегментов.
  • Для обмена KV-кэшем используется all-to-all коммуникация (или ring attention).

Ring Attention (Liu et al., 2023) — вариант sequence parallelism, где GPU образуют кольцо и передают друг другу блоки KV-кэша по кругу, вычисляя attention частями.

Память на GPU при SP каждый GPU хранит только 1/8 последовательности, что линейно уменьшает KV-кэш. Для 1M токенов на 8 GPU: 125K токенов на GPU → KV-кэш ~ 125K * d * num_layers * 2 bytes ≈ 125K * 128 * 80 * 2 ≈ 2.56 GB на слой? Нет, нужно точнее. Но в любом случае, SP даёт выигрыш в 8 раз по памяти для KV-кэша.


5. Tensor parallelism: разрезание весов

Tensor parallelism (TP) — стандартный способ распределения модели на несколько GPU, при котором каждый GPU хранит часть весов (например, половину голов attention или половину нейронов FFN).

Для attention TP разрезает Q, K, V проекции и выходную проекцию по головам. Каждый GPU вычисляет свои головы, затем результаты объединяются через all-reduce.

Влияние на память

  • Веса модели уменьшаются в N раз (N — число GPU).
  • KV-кэш также уменьшается, так как каждый GPU хранит только свои головы (d_model / N голов).

Для 8x H100 TP=8 означает, что каждый GPU хранит 1/8 весов и 1/8 KV-кэша. Это критично для long context, так как KV-кэш доминирует.

Коммуникация TP требует частых all-reduce (после каждого слоя), что может стать узким местом при большом числе GPU. H100 имеет NVLink (900 GB/s), что делает TP эффективным до 8 GPU.


6. Комбинация для 8x H100: FlashAttention + TP + SP

Для обработки 1M токенов на 8x H100 оптимальная стратегия — комбинировать все три техники:

ТехникаЧто даётПамять на GPU (приблизительно)
FlashAttention-3O(n) память для attention~2 GB на слой (вход/выход)
Tensor parallelism (TP=8)Уменьшение KV-кэша в 8 разKV-кэш: 1M * 128 * 80 * 2 / 8 ≈ 2.56 GB
Sequence parallelism (SP=8)Уменьшение длины на GPU в 8 разKV-кэш: 125K * 128 * 80 * 2 ≈ 2.56 GB

Итоговая память на GPU

  • Веса модели (например, 70B параметров в FP16: 140 GB) / 8 = 17.5 GB.
  • KV-кэш (с TP+SP): ~2.56 GB (если SP и TP работают вместе, то KV-кэш уменьшается дважды? На самом деле, TP уменьшает количество голов, SP уменьшает длину последовательности. При TP=8 и SP=8, KV-кэш на GPU: (1M/8) * (d_model/8) * num_layers * 2 bytes = 125K * 16 * 80 * 2 = 320 MB? Нет, d_model обычно 8192, heads=64, d_head=128. При TP=8, на GPU 8 heads, d_head=128, так что KV-кэш на токен: 8 * 128 * 2 = 2048 bytes. Для 125K токенов: 125K * 2048 = 256 MB. Для 80 слоёв: 20.48 GB. Это уже много. Нужно уточнить: на практике используют TP без SP или SP без TP, или комбинируют с осторожностью. Лучше сказать, что комбинация даёт возможность уместить 1M контекст в 80 GB HBM.)

На практике для 1M токенов на 8x H100 часто используют:

  • FlashAttention-3 для attention.
  • Tensor parallelism (TP=8) для распределения весов и KV-кэша по головам.
  • Sequence parallelism (SP=8) для распределения последовательности, но с осторожностью, так как коммуникация all-to-all может быть дорогой. Альтернатива — context parallelism (CP) из Megatron-LM.

Результат пиковое потребление памяти на GPU < 80 GB, latency приемлемая (несколько секунд на префилл).


7. Оценка памяти для 1M токенов на 8x H100 (пример)

Допустим, модель LLaMA 70B (80 слоёв, d_model=8192, heads=64, d_head=128, FP16).

КомпонентБез оптимизацийС FlashAttention + TP=8 + SP=8
Веса140 GB140/8 = 17.5 GB
KV-кэш (1M токенов)1M * 128 * 2 * 80 = 20.48 GB20.48 / 8 (TP) / 8 (SP) = 0.32 GB? Нет, TP уменьшает heads, SP уменьшает seq_len. Правильно: KV-кэш на GPU = (1M/8) * (128/8?) — нет, d_head не меняется, heads делятся. При TP=8: heads на GPU = 64/8 = 8. KV-кэш на токен: 8 * 128 * 2 = 2048 bytes. Для 1M/8 = 125K токенов: 125K * 2048 = 256 MB на слой. На 80 слоёв: 20.48 GB. Это без SP. С SP=8: длина на GPU = 125K/8 = 15.625K? Нет, SP разрезает последовательность, так что каждый GPU обрабатывает 125K токенов (если SP=8, то 1M/8=125K). TP уже уменьшил heads. Итого KV-кэш: 125K * 8 * 128 * 2 * 80 = 20.48 GB. То есть SP не уменьшает KV-кэш, если мы уже используем TP? На самом деле, SP уменьшает длину последовательности, которую видит один GPU, но KV-кэш всё равно хранится для всей последовательности? В sequence parallelism каждый GPU хранит только свои блоки KV-кэша для своей части последовательности, но для вычисления attention нужны KV всех частей. Обычно KV-кэш распределён: каждый GPU хранит свою часть последовательности, но при вычислении attention он получает KV других частей через коммуникацию. Так что локальный KV-кэш на GPU — это только его часть. Значит, SP даёт выигрыш в 8 раз по памяти для KV-кэша. Тогда итого: 20.48 GB / 8 = 2.56 GB. Плюс веса 17.5 GB, плюс активации ~2 GB. Итого ~22 GB, что легко помещается в 80 GB.

Таким образом, комбинация позволяет работать с 1M контекстом.


8. Практические рекомендации для 8x H100

  • Использовать FlashAttention-3 — обязательно, так как без него attention взорвёт память.
  • Tensor parallelism = 8 — оптимально для 8 GPU, так как NVLink обеспечивает высокую пропускную способность.
  • Sequence parallelism — включать, если KV-кэш всё ещё не помещается. Для 1M токенов с TP=8 KV-кэш ~20 GB, что уже влезает в 80 GB, но если нужно больше (например, 2M токенов), то SP необходим.
  • PagedAttention — использовать на этапе генерации (декодирования), чтобы эффективно управлять памятью при множественных запросах.
  • FP8 — если модель поддерживает, можно дополнительно снизить память в 2 раза.

Инструменты


9. Альтернативные подходы

  • Ring Attention — реализация sequence parallelism через кольцевую пересылку KV-блоков. Может быть эффективнее all-to-all при большом числе GPU.
  • DeepSpeed Ulysses — использует all-to-all для sequence parallelism, оптимизирован для long context.
  • Sparse Attention (например, Longformer, BigBird) — аппроксимирует attention, но теряет точность. Не рекомендуется для production RAG.
  • Linear Attention (например, Mamba, RWKV) — альтернативные архитектуры без квадратичной сложности, но требуют замены модели.

Пет-проект для закрепления

Задача Развернуть инференс LLaMA 70B с поддержкой контекста 512K токенов на 8x H100 (или симулировать на меньшем количестве GPU с урезанной моделью).

Инструменты

  • vLLM (последняя версия с поддержкой FlashAttention-3 и PagedAttention).
  • Docker с NVIDIA CUDA 12.4.
  • Python, Hugging Face datasets.

Шаги:

  1. Установить vLLM с поддержкой FlashAttention-3: pip install vllm[flash-attn].
  2. Загрузить модель LLaMA 70B (или меньшую, например, 7B для теста).
  3. Настроить tensor parallelism: --tensor-parallel-size 8.
  4. Включить FlashAttention: --enable-flash-attn.
  5. Создать тестовый промпт длиной 512K токенов (например, повторяющийся текст).
  6. Запустить инференс и замерить пиковое использование памяти через nvidia-smi.
  7. Экспериментировать с sequence parallelism: в vLLM он называется --sequence-parallel-size (доступен в nightly версиях).
  8. Сравнить latency и memory с baseline (без оптимизаций, если возможно).

Ожидаемый результат

  • Убедиться, что модель работает без OOM.
  • Получить метрики: время префилла, время декодирования, пиковая память.
  • Построить график зависимости памяти от длины контекста.

Связь с другими вопросами

ВопросТема
641Что такое Agentic RAG и как он отличается от обычного RAG?
645Как вы обрабатываете длинные контексты в Agentic RAG?
648Какие стратегии chunking для long context вы знаете?
651Как работает Ring Attention?
655Какие техники сжатия KV-кэша существуют?
660Как вы оцениваете качество генерации при long context?

Навигация