Как вы реализуете KV cache для 1M токенов на 8x H100?

Краткий тезис

KV cache для 1M токенов требует ~320 GB в FP16 даже с GQA, что не помещается в HBM одного H100 (80 GB). Решение — комбинация методов: тензорный параллелизм (8-way) для шардирования кэша, INT4-квантование для сжатия в 4 раза, sliding window (окно 128k) для ограничения хранимой истории и KV cache compression (например, H2O или SnapKV) для агрегации старых токенов. batching|Continuous batching с PagedAttention обеспечивает эффективное управление памятью. Итоговый расход на GPU|один GPU — около 10 GB под KV cache, что вписывается в лимиты H100.


1. Что такое KV cache и почему он проблема для 1M токенов

KV cache — это кэш ключей (K) и значений (V) из слоёв self-attention, сохраняемый для ранее сгенерированных токенов. При авторегрессивной генерации каждый новый токен вычисляет attention ко всем предыдущим токенам. Без кэша пришлось бы пересчитывать K и V для всей последовательности на каждом шаге, что даёт квадратичную сложность O(n²). KV cache снижает её до O(n) по памяти и O(1) по времени на шаг (с учётом кэша).

Для 1M токенов размер KV cache становится доминирующим фактором. Например, для Llama-3-70B (80 слоёв, 64 головы запросов, 8 KV-голов благодаря GQA, размер головы 128) один токен требует:

2 (K и V) × 80 слоёв × 8 KV-голов × 128 = 163840 чисел FP16 = 327680 байт ≈ 0.33 MB.

Для 1M токенов: 0.33 MB × 1 000 000 = 330 GB. Это значительно превышает HBM одного H100 (80 GB). Даже на 8 GPU суммарно 640 GB, но без шардирования каждый GPU должен хранить весь кэш — неэффективно.


2. Расчёт размера KV cache для Llama-3-70B

Параметры модели:

ПараметрЗначение
Число слоёв (L)80
Число голов запросов (Q heads)64
Число KV-голов (KV heads)8 (GQA с 8 группами)
Размер головы (d_head)128
Тип данныхFP16 (2 байта)

Размер на один токен:

2 × L × KV_heads × d_head = 2 × 80 × 8 × 128 = 163840 чисел = 327680 байт = 0.32768 MB.

Для 1M токенов: 0.32768 MB × 1 000 000 = 327.68 GB.

Если использовать INT4-квантование (0.5 байта на число), размер уменьшается в 4 раза: 327.68 / 4 = 81.92 GB. Это всё ещё больше одного H100 (80 GB), но уже близко.


3. Метод 1: Grouped Query Attention (GQA)

GQA уменьшает число KV-голов относительно голов запросов. В Llama-3-70B используется 8 KV-голов на 64 головы запросов (группы по 8). Это снижает размер KV cache в 64 / 8 = 8 раз по сравнению с полным multi-head attention.

Без GQA (64 KV-головы) размер для 1M токенов был бы 327.68 × 8 = 2.6 TB — непрактично. GQA — обязательный базовый выбор для длинных контекстов.


4. Метод 2: Квантование KV cache (INT4, FP8)

Квантование снижает точность хранения K и V. INT4 даёт сжатие в 4 раза относительно FP16, FP8 — в 2 раза.

INT4-квантование:

  • Каждое число занимает 4 бита (0.5 байта).
  • Размер на токен: 0.32768 MB / 4 = 0.08192 MB.
  • Для 1M токенов: 81.92 GB.

На практике применяют per-channel или per-token квантование с калибровкой, чтобы минимизировать потерю качества. Современные реализации (например, в vLLM, TensorRT-LLM) поддерживают INT4 KV cache с незначительным падением perplexity (<0.5).

FP8 (1 байт) — компромисс: 163.84 GB для 1M токенов, всё ещё много.


5. Метод 3: Тензорный параллелизм (8-way)

Тензорный параллелизм (TP) распределяет слои и головы между GPU. При TP=8 каждый GPU отвечает за 1/8 KV-голов. Поскольку KV-голов 8, каждый GPU получает ровно 1 KV-голову.

Размер на GPU для одного токена:

2 × 80 × 1 × 128 = 20480 чисел FP16 = 40960 байт = 0.04096 MB.

Для 1M токенов в FP16: 0.04096 MB × 1 000 000 = 40.96 GB.

С INT4: 40.96 / 4 = 10.24 GB.

Это уже помещается в HBM H100 (80 GB) с запасом для параметров модели (70B в FP16 ~140 GB, при TP=8 ~17.5 GB на GPU) и активаций.

Важно: TP требует синхронизации all-reduce между GPU, но для длинных контекстов это оправдано.


6. Метод 4: Sliding Window Attention

Даже с TP и INT4 хранить весь 1M токенов может быть избыточно. Sliding window attention ограничивает контекст фиксированным окном (например, 128k токенов). Старые токены за пределами окна отбрасываются.

Для 1M токенов с окном 128k:

  • Храним только последние 128k токенов.
  • Размер KV cache на GPU (INT4, TP=8): 128000 × 0.04096 MB / 4 = 1.31 GB.

Это очень мало. Однако модель теряет доступ к ранним токенам. Для задач, где важна вся история (например, анализ всего документа), sliding window не подходит. Тогда применяют KV cache compression.


7. Метод 5: KV cache compression (H2O, SnapKV, StreamingLLM)

Методы сжатия KV cache позволяют хранить не все токены, а только важные. Примеры:

  • H2O (Heavy Hitter Oracle): сохраняет токены с наибольшими attention scores (heavy hitters) и случайные для разнообразия. Позволяет сжать кэш в 2–5 раз с малым падением качества.
  • SnapKV: выбирает ключевые токены на основе паттернов attention в первых слоях.
  • StreamingLLM: сохраняет несколько начальных токенов (sink tokens) и последнее окно, остальные отбрасывает.

Для 1M токенов можно комбинировать sliding window (последние 128k) и compression для токенов внутри окна (например, H2O с retention rate 20%). Тогда на GPU (INT4, TP=8) размер будет 128000 × 0.04096 MB / 4 × 0.2 = 0.26 GB — ещё меньше.


8. Continuous batching и PagedAttention

Continuous batching позволяет обрабатывать несколько запросов одновременно, динамически распределяя память под KV cache. PagedAttention (используется в vLLM) управляет KV cache как страницами виртуальной памяти: блоки фиксированного размера (page) выделяются по мере необходимости, что устраняет фрагментацию и позволяет эффективно использовать HBM.

Для 1M токенов на 8 GPU PagedAttention даёт:

  • Гибкое выделение памяти под разные последовательности.
  • Возможность дедупликации кэша для общих префиксов (prefix caching).
  • Поддержка sliding window и compression на уровне страниц.

9. Итоговая архитектура для 1M токенов на 8x H100

Комбинируем все методы:

  1. Модель: Llama-3-70B с GQA (8 KV-голов).
  2. Тензорный параллелизм: 8-way, каждый GPU отвечает за 1 KV-голову.
  3. Квантование KV cache: INT4 (4 бита на число).
  4. Sliding window: окно 128k токенов (последние).
  5. KV cache compression: H2O с retention 20% внутри окна.
  6. Continuous batching + PagedAttention: управление памятью.

Оценка памяти на один GPU:

  • Параметры модели (FP16): 70B / 8 ≈ 8.75B параметров → 17.5 GB.
  • KV cache (INT4, сжатый): 128k × 0.04096 MB / 4 × 0.2 ≈ 0.26 GB.
  • Активации (зависит от batch size, допустим 2 GB).
  • Итого: ~20 GB. HBM H100 (80 GB) — огромный запас.

Производительность: генерация 1M токенов возможна с latency, ограниченной attention over окна 128k (O(n²) внутри окна, но с FlashAttention-2 это ~0.5 сек на шаг). Continuous batching позволяет обслуживать несколько запросов параллельно.


10. Пет-проект для закрепления

Задача: Реализовать прототип KV cache с sliding window и INT4 квантованием для небольшой модели (например, GPT-2 или Llama-3-8B) на одной H100, симулируя контекст 1M токенов (используя случайные данные).

Инструменты: Python, PyTorch, bitsandbytes (для INT4), Hugging Face Transformers.

Шаги:

  1. Загрузить модель с GQA (например, meta-llama/Llama-3.2-1B — у неё 8 KV-голов).
  2. Реализовать класс SlidingWindowKVCache:
    • Хранить K и V в формате INT4 (использовать torch.quantize_per_tensor или bitsandbytes).
    • При добавлении нового токена удалять самый старый, если размер превышает окно.
  3. Написать функцию генерации с пошаговым обновлением кэша.
  4. Замерить потребление памяти при окне 128k и сравнить с теоретическим расчётом.
  5. Добавить H2O-подобную эвристику: сохранять только top-20% токенов по attention score.

Ожидаемый результат: Рабочий генератор, способный обрабатывать последовательности до 1M токенов (симулированных) с памятью < 10 GB на GPU. Код выложить на GitHub с бенчмарками.


11. Связь с другими вопросами

ВопросТема
640Как спроектировать Agentic RAG с длинным контекстом
641Оптимизация внимания для длинных последовательностей (FlashAttention)
643Как реализовать prefix caching для ускорения
644Сравнение методов сжатия контекста (H2O, SnapKV, StreamingLLM)
645Как работает PagedAttention в vLLM
646Балансировка памяти между KV cache и параметрами при TP

12. Навигация


Навигация