English translation is not available yet. Showing Russian content.
Что такое multi-query attention (MQA) и grouped-query attention (GQA) и зачем они?
Краткий тезис
Attention|Multi-Query Attention (MQA) и Attention|Grouped-Query Attention (GQA) — это модификации стандартного multi-head attention (MHA), которые уменьшают объём KV-кэша и ускоряют авторегрессивную генерацию за счёт разделения ключей и значений между несколькими головами внимания. MQA использует один общий набор KV для всех голов, а GQA делит головы на группы, каждая со своим общим KV — это позволяет найти компромисс между скоростью декодирования и качеством предсказаний.
1. Термины: Multi-head attention, KV-кэш и проблема декодирования
Multi-head attention (MHA) — стандартный механизм, где входная последовательность проецируется в несколько голов (heads), каждая из которых независимо вычисляет attention по своим Q, K, V. На этапе генерации (декодирования) модель работает авторегрессивно: каждый новый токен использует все предыдущие токены как контекст. Чтобы не пересчитывать K и V для уже обработанных токенов, их сохраняют в KV-кэш — это массив размера (num_layers, batch_size, num_heads, seq_len, head_dim). Для больших моделей (70B+) и длинных контекстов KV-кэш может занимать десятки гигабайт памяти, что ограничивает batch size и увеличивает latency.
2. Проблема: затраты памяти и bandwidth при инференсе
При каждом шаге декодирования:
- Нужно загрузить весь KV-кэш из памяти GPU в кэш (high memory bandwidth)
- Выполнить attention для всех голов параллельно
- Размер KV-кэша растёт линейно с длиной последовательности, числом голов и batch size
Это приводит к:
- Memory-bound режиму: скорость ограничена не вычислениями, а пропускной способностью памяти
- Ограничению максимального batch size (нельзя обработать много запросов одновременно)
- Увеличению time-to-first-token и per-token latency
3. Multi-Query Attention (MQA)
Multi-Query Attention предложили в статье Shazeer et al. (2019). Идея: все головы внимания используют общие ключи и значения (один набор KV), но каждая голова сохраняет собственный Query проецирование.
Архитектурно:
- K и V проецируются один раз, а не num_heads раз
- Размер KV-кэша уменьшается в num_heads раз (например, в LLaMA-1 — 32 раза)
- Количество голов для K и V = 1, для Q = num_heads
Плюсы:
- Резкое снижение памяти KV-кэша → возможность использовать больший batch
- Ускорение декодирования (до 2–10 раз в зависимости от модели и железа)
- Меньше bandwidth — быстрее загрузка кэша
Минусы:
- Потеря точности, особенно на сложных задачах (например, понимание длинных контекстов, редкие факты)
- Ограниченная способность модели различать разные контексты разными головами
MQA использовалась в LLaMA-1 65B и Falcon 40B.
4. Grouped-Query Attention (GQA)
Grouped-Query Attention — эволюция MQA, предложенная Ainslie et al. (2023) и применённая в LLaMA-2/3. Компромисс: головы делятся на группы, и внутри каждой группы используется общий набор KV.
Параметры:
num_groups— количество групп- Внутри каждой группы:
num_heads_per_group = total_heads / num_groupsголов используют один KV - Когда
num_groups = 1→ это MQA - Когда
num_groups = total_heads→ это MHA
Пример: LLaMA-2 70B: 64 головы, 8 групп → каждая группа из 8 голов имеет общий KV. KV-кэш уменьшается в 8 раз по сравнению с MHA.
Плюсы:
- Гибкость — можно настроить число групп под hardware и требования к качеству
- Качество выше, чем у MQA, особенно при большом числе групп
- Ускорение почти как у MQA (особенно при небольшом числе групп)
Минусы:
- Выбор числа групп — гиперпараметр, требует экспериментов
- Всё ещё деградация качества относительно MHA, хотя часто незаметна
5. Сравнение MHA, MQA и GQA (таблица)
| Параметр | MHA (стандарт) | MQA | GQA |
|---|---|---|---|
| Число KV-голов | num_heads | 1 | num_groups (обычно 8–64) |
| Размер KV-кэша | num_heads × (seq_len × dim) | 1 × (seq_len × dim) | num_groups × (seq_len × dim) |
| Скорость декодирования | базовая | в 2–10× быстрее | в 1.5–5× быстрее (зависит от групп) |
| Качество (perplexity) | эталонное | заметное снижение | близкое к MHA (отклонение < 0.1-0.3) |
| Применение | обучение, точные задачи | быстрый инференс, small batch | баланс (LLaMA-2/3, Gemma, Mistral) |
| Сложность реализации | стандартная | простая модификация | чуть сложнее, чем MQA |
6. Практические реализации в современных моделях
- LLaMA-1: 65B — MQA (32 Q-головы, 1 KV-голова). Более маленькие версии (7B, 13B) остались MHA.
- LLaMA-2: 70B — GQA с 8 группами. 7B и 13B — MHA (позже LLaMA-3 8B тоже GQA).
- LLaMA-3: все версии (8B, 70B) — GQA с 8 группами (как и LLaMA-2 70B).
- Mistral 7B / Mixtral: GQA с 8 группами (32 Q-головы, 8 KV-голов).
- Falcon 180B: Multi-Query Attention (1 группа).
- Gemma: GQA с 4 группами (от Google).
Тенденция: большинство современных моделей переходят на GQA как стандарт, потому что он даёт почти бесплатное ускорение с минимальной потерей качества.
7. Влияние на инференс и throughput
MQA/GQA особенно выгодны при batch inference (обработка нескольких запросов одновременно). Поскольку KV-кэш уменьшен, можно увеличить batch size в несколько раз, не выходя за лимит памяти GPU. Это прямо повышает throughput (запросов/сек). Для моделей, развёрнутых как API (агенты, chat-боты), это снижает себестоимость.
Пример: для LLaMA-70B с MHA максимальный batch size = 16 на A100 80GB. С GQA (8 групп) — batch size = 64. Пропускная способность растёт ~4× при том же бюджете памяти.
8. Связь с agentic RAG
В сценариях agentic RAG LLM вызывается многократно в рамках одного ответа (планирование, итеративные запросы к инструментам, рефлексия). Каждый вызов генерирует новые токены, накапливая KV-кэш. Если длина контекста большая (сумма всех сообщений истории, ретривальных документов), KV-кэш быстро растёт. Использование GQA критично:
- снижает latency каждого шага (агент быстрее отвечает)
- позволяет держать более длинную историю в кэше (важно для многошаговых рассуждений)
- даёт возможность агрегировать результаты нескольких агентов в одном батче (batch inferences).
9. Когда выбирать MQA, GQA или MHA
| Сценарий | Рекомендованный тип |
|---|---|
| Точность критична (исследования, медицина) | MHA (или GQA с max групп) |
| Высоконагруженный API, много запросов | GQA (8–16 групп) |
| Ограниченная память GPU (мобильные) | MQA |
| Длинный контекст (32k+) | GQA с малым числом групп |
| Обучение модели с нуля | MHA (потом можно fine-tune с GQA) |
На практике GQA — разумный выбор по умолчанию: он не требует тонкой настройки архитектуры обучения (можно обучить с MHA, потом заменить KV-проекции, хотя лучше обучать сразу с GQA).
10. Альтернативы и комплементарные техники
- KV cache quantization — сжатие KV-кэша до 4 бит/2 бит (дополнительная экономия памяти).
- Sliding window attention — ограничение контекста окном (как в Mistral).
- FlashAttention — более эффективная реализация attention, но не уменьшает размер кэша.
- Multi-Latent Attention (MLA) — из DeepSeek-V2: сжимает KV в латентное пространство, ещё лучше, чем GQA, но сложнее.
MQA/GQA можно комбинировать со всеми этими техниками.
Пет-проект для закрепления
Задача: Реализовать упрощённый трансформер для генерации текста с возможностью переключения между MHA, MQA и GQA. Сравнить скорость декодирования и память при разных длинах контекста и batch size.
Инструменты: PyTorch, Hugging Face Transformers (для baseline), CUDA профилирование (nvcc, torch.cuda).
Шаги:
- Возьмите готовую малую модель (например, GPT-2) и замените nn.MultiheadAttention на кастомную реализацию с параметром
num_kv_heads. - Напишите классы:
MultiHeadAttention— по умолчанию (num_kv_heads = num_heads)MultiQueryAttention— num_kv_heads = 1GroupedQueryAttention— num_kv_heads = 4,8 (группировка)
- Измерьте:
- размер KV-кэша (по
element_size()) - время генерации 100 токенов для batch size = 1, 8, 32
- perplexity на тестовом датасете (WikiText-2) для каждого варианта
- размер KV-кэша (по
- Постройте графики: latency vs batch size, perplexity vs количество групп.
Ожидаемый результат: Вы увидите, что MQA и GQA дают выигрыш в скорости при batch > 1, а GQA с 4–8 группами почти не уступает MHA по перплексии.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 274 | Что такое attention и multi-head attention? |
| 275 | Что такое FlashAttention и как оно ускоряет вычисления? |
| 276 | Чем causal attention отличается от bidirectional? |
| 278 | Что такое speculative decoding и как оно ускоряет генерацию? |
| 279 | Как работает multi-step reasoning в LLM? |
| 280 | Что такое KV-кэш и как его оптимизировать? |
Навигация
- Предыдущий: 276
- Следующий: 278
- Индекс: 00. Индекс разборов