中文翻译暂不可用,显示俄语原文。

Что такое multi-query attention (MQA) и grouped-query attention (GQA) и зачем они?

Краткий тезис

Attention|Multi-Query Attention (MQA) и Attention|Grouped-Query Attention (GQA) — это модификации стандартного multi-head attention (MHA), которые уменьшают объём KV-кэша и ускоряют авторегрессивную генерацию за счёт разделения ключей и значений между несколькими головами внимания. MQA использует один общий набор KV для всех голов, а GQA делит головы на группы, каждая со своим общим KV — это позволяет найти компромисс между скоростью декодирования и качеством предсказаний.


1. Термины: Multi-head attention, KV-кэш и проблема декодирования

Multi-head attention (MHA) — стандартный механизм, где входная последовательность проецируется в несколько голов (heads), каждая из которых независимо вычисляет attention по своим Q, K, V. На этапе генерации (декодирования) модель работает авторегрессивно: каждый новый токен использует все предыдущие токены как контекст. Чтобы не пересчитывать K и V для уже обработанных токенов, их сохраняют в KV-кэш — это массив размера (num_layers, batch_size, num_heads, seq_len, head_dim). Для больших моделей (70B+) и длинных контекстов KV-кэш может занимать десятки гигабайт памяти, что ограничивает batch size и увеличивает latency.


2. Проблема: затраты памяти и bandwidth при инференсе

При каждом шаге декодирования:

  • Нужно загрузить весь KV-кэш из памяти GPU в кэш (high memory bandwidth)
  • Выполнить attention для всех голов параллельно
  • Размер KV-кэша растёт линейно с длиной последовательности, числом голов и batch size

Это приводит к:

  • Memory-bound режиму: скорость ограничена не вычислениями, а пропускной способностью памяти
  • Ограничению максимального batch size (нельзя обработать много запросов одновременно)
  • Увеличению time-to-first-token и per-token latency

3. Multi-Query Attention (MQA)

Multi-Query Attention предложили в статье Shazeer et al. (2019). Идея: все головы внимания используют общие ключи и значения (один набор KV), но каждая голова сохраняет собственный Query проецирование.

Архитектурно:

  • K и V проецируются один раз, а не num_heads раз
  • Размер KV-кэша уменьшается в num_heads раз (например, в LLaMA-1 — 32 раза)
  • Количество голов для K и V = 1, для Q = num_heads

Плюсы:

  • Резкое снижение памяти KV-кэша → возможность использовать больший batch
  • Ускорение декодирования (до 2–10 раз в зависимости от модели и железа)
  • Меньше bandwidth — быстрее загрузка кэша

Минусы:

  • Потеря точности, особенно на сложных задачах (например, понимание длинных контекстов, редкие факты)
  • Ограниченная способность модели различать разные контексты разными головами

MQA использовалась в LLaMA-1 65B и Falcon 40B.


4. Grouped-Query Attention (GQA)

Grouped-Query Attention — эволюция MQA, предложенная Ainslie et al. (2023) и применённая в LLaMA-2/3. Компромисс: головы делятся на группы, и внутри каждой группы используется общий набор KV.

Параметры:

  • num_groups — количество групп
  • Внутри каждой группы: num_heads_per_group = total_heads / num_groups голов используют один KV
  • Когда num_groups = 1 → это MQA
  • Когда num_groups = total_heads → это MHA

Пример: LLaMA-2 70B: 64 головы, 8 групп → каждая группа из 8 голов имеет общий KV. KV-кэш уменьшается в 8 раз по сравнению с MHA.

Плюсы:

  • Гибкость — можно настроить число групп под hardware и требования к качеству
  • Качество выше, чем у MQA, особенно при большом числе групп
  • Ускорение почти как у MQA (особенно при небольшом числе групп)

Минусы:

  • Выбор числа групп — гиперпараметр, требует экспериментов
  • Всё ещё деградация качества относительно MHA, хотя часто незаметна

5. Сравнение MHA, MQA и GQA (таблица)

ПараметрMHA (стандарт)MQAGQA
Число KV-головnum_heads1num_groups (обычно 8–64)
Размер KV-кэшаnum_heads × (seq_len × dim)1 × (seq_len × dim)num_groups × (seq_len × dim)
Скорость декодированиябазоваяв 2–10× быстреев 1.5–5× быстрее (зависит от групп)
Качество (perplexity)эталонноезаметное снижениеблизкое к MHA (отклонение < 0.1-0.3)
Применениеобучение, точные задачибыстрый инференс, small batchбаланс (LLaMA-2/3, Gemma, Mistral)
Сложность реализациистандартнаяпростая модификациячуть сложнее, чем MQA

6. Практические реализации в современных моделях

  • LLaMA-1: 65B — MQA (32 Q-головы, 1 KV-голова). Более маленькие версии (7B, 13B) остались MHA.
  • LLaMA-2: 70B — GQA с 8 группами. 7B и 13B — MHA (позже LLaMA-3 8B тоже GQA).
  • LLaMA-3: все версии (8B, 70B) — GQA с 8 группами (как и LLaMA-2 70B).
  • Mistral 7B / Mixtral: GQA с 8 группами (32 Q-головы, 8 KV-голов).
  • Falcon 180B: Multi-Query Attention (1 группа).
  • Gemma: GQA с 4 группами (от Google).

Тенденция: большинство современных моделей переходят на GQA как стандарт, потому что он даёт почти бесплатное ускорение с минимальной потерей качества.


7. Влияние на инференс и throughput

MQA/GQA особенно выгодны при batch inference (обработка нескольких запросов одновременно). Поскольку KV-кэш уменьшен, можно увеличить batch size в несколько раз, не выходя за лимит памяти GPU. Это прямо повышает throughput (запросов/сек). Для моделей, развёрнутых как API (агенты, chat-боты), это снижает себестоимость.

Пример: для LLaMA-70B с MHA максимальный batch size = 16 на A100 80GB. С GQA (8 групп) — batch size = 64. Пропускная способность растёт ~4× при том же бюджете памяти.


8. Связь с agentic RAG

В сценариях agentic RAG LLM вызывается многократно в рамках одного ответа (планирование, итеративные запросы к инструментам, рефлексия). Каждый вызов генерирует новые токены, накапливая KV-кэш. Если длина контекста большая (сумма всех сообщений истории, ретривальных документов), KV-кэш быстро растёт. Использование GQA критично:

  • снижает latency каждого шага (агент быстрее отвечает)
  • позволяет держать более длинную историю в кэше (важно для многошаговых рассуждений)
  • даёт возможность агрегировать результаты нескольких агентов в одном батче (batch inferences).

9. Когда выбирать MQA, GQA или MHA

СценарийРекомендованный тип
Точность критична (исследования, медицина)MHA (или GQA с max групп)
Высоконагруженный API, много запросовGQA (8–16 групп)
Ограниченная память GPU (мобильные)MQA
Длинный контекст (32k+)GQA с малым числом групп
Обучение модели с нуляMHA (потом можно fine-tune с GQA)

На практике GQA — разумный выбор по умолчанию: он не требует тонкой настройки архитектуры обучения (можно обучить с MHA, потом заменить KV-проекции, хотя лучше обучать сразу с GQA).


10. Альтернативы и комплементарные техники

  • KV cache quantization — сжатие KV-кэша до 4 бит/2 бит (дополнительная экономия памяти).
  • Sliding window attention — ограничение контекста окном (как в Mistral).
  • FlashAttention — более эффективная реализация attention, но не уменьшает размер кэша.
  • Multi-Latent Attention (MLA) — из DeepSeek-V2: сжимает KV в латентное пространство, ещё лучше, чем GQA, но сложнее.

MQA/GQA можно комбинировать со всеми этими техниками.


Пет-проект для закрепления

Задача: Реализовать упрощённый трансформер для генерации текста с возможностью переключения между MHA, MQA и GQA. Сравнить скорость декодирования и память при разных длинах контекста и batch size.

Инструменты: PyTorch, Hugging Face Transformers (для baseline), CUDA профилирование (nvcc, torch.cuda).
Шаги:

  1. Возьмите готовую малую модель (например, GPT-2) и замените nn.MultiheadAttention на кастомную реализацию с параметром num_kv_heads.
  2. Напишите классы:
    • MultiHeadAttention — по умолчанию (num_kv_heads = num_heads)
    • MultiQueryAttention — num_kv_heads = 1
    • GroupedQueryAttention — num_kv_heads = 4,8 (группировка)
  3. Измерьте:
    • размер KV-кэша (по element_size())
    • время генерации 100 токенов для batch size = 1, 8, 32
    • perplexity на тестовом датасете (WikiText-2) для каждого варианта
  4. Постройте графики: latency vs batch size, perplexity vs количество групп.

Ожидаемый результат: Вы увидите, что MQA и GQA дают выигрыш в скорости при batch > 1, а GQA с 4–8 группами почти не уступает MHA по перплексии.


Связь с другими вопросами

ВопросТема
274Что такое attention и multi-head attention?
275Что такое FlashAttention и как оно ускоряет вычисления?
276Чем causal attention отличается от bidirectional?
278Что такое speculative decoding и как оно ускоряет генерацию?
279Как работает multi-step reasoning в LLM?
280Что такое KV-кэш и как его оптимизировать?

Навигация