Как работает Multi-query attention (MQA) для long context?
Краткий тезис
Attention|Multi-Query Attention (MQA) — это модификация механизма внимания, в которой все attention heads используют общие Key (K) и Value (V) пары, а Query (Q) остаётся уникальным для каждой головы. Это радикально сокращает размер KV cache (в 8 раз для 8 голов) и ускоряет инференс на длинных контекстах, но ценой небольшого падения качества (5–10% на сложных задачах). MQA применяется в моделях, где критична скорость и память, например в PaLM и Falcon.
1. Термин: Multi-Query Attention (MQA)
Multi-Query Attention — это вариант scaled dot-product attention, предложенный в статье “Fast Transformer Decoding: One Write-Head is All You Need” (2019). В отличие от стандартного Multi-Head Attention (MHA), где каждый head имеет собственные проекции Q, K, V, в MQA все heads разделяют одну и ту же проекцию K и V аQ остаётся индивидуальным.
Ключевые понятия
- Attention head — одна «голова» внимания, которая вычисляет взвешенную сумму значений.
- KV cache — кэш ключей и значений, хранящийся для каждого слоя декодера при авторегрессивной генерации. Размер KV cache растёт линейно с длиной контекста и числом голов.
- Long context — контекст длиной более 4K–8K токенов, где KV cache становится узким местом по памяти.
2. Как работает MQA: общие KV пары
В стандартном MHA для каждого head h из H голов:
Q_h = X * W_q_hK_h = X * W_k_hV_h = X * W_v_h
Где X — входная последовательность, W_*_h — матрицы проекций.
В MQA:
Q_h = X * W_q_h(уникальный для каждого head)K = X * W_k(один на все heads)V = X * W_v(один на все heads)
Таким образом, для каждого слоя хранится только одна пара K и V вместо H пар. Размер KV cache уменьшается в H раз.
Формула внимания для одного head
Attention(Q_h, K, V) = softmax(Q_h * K^T / sqrt(d_k)) * V
Где d_k — размерность ключа (обычно d_model / H).
Визуализация (текстовая):
MHA: [head1: K1,V1] [head2: K2,V2] ... [headH: KH,VH] → H пар KV
MQA: [head1: K,V] [head2: K,V] ... [headH: K,V] → 1 пара KV (общая)
3. Сравнение MHA vs MQA vs GQA
Grouped-Query Attention (GQA) — компромисс: heads делятся на группы, каждая группа имеет общие KV. MQA — частный случай GQA с одной группой.
| Характеристика | MHA | MQA | GQA |
|---|---|---|---|
| Число KV пар | H (по числу heads) | 1 | G (число групп) |
| Размер KV cache | H * L * d_k | 1 * L * d_k | G * L * d_k |
| Экономия памяти (отн. MHA) | 1x | H раз | H/G раз |
| Качество (perplexity) | Эталон | −5..−10% | −1..−3% |
| Скорость инференса | Базовая | Высокая | Средняя |
| Примеры моделей | GPT-3, LLaMA | PaLM, Falcon | LLaMA 2 70B, Mistral |
Когда какая нужна
- MHA — максимальное качество, допустимо много памяти.
- MQA — экстремальная экономия памяти, long context, мобильные устройства.
- GQA — баланс, часто используется в больших моделях (70B+).
4. Преимущества MQA для long context
4.1. Экономия памяти KV cache
При длине контекста L=32K, d_k=128, H=8:
- MHA: 8 * 32K * 128 * 2 байта (float16) ≈ 64 MB на слой.
- MQA:
1 * 32K * 128 * 2 ≈ 8 MBна слой. Для 32 слоёв разница — 2 GB vs 256 MB. Это позволяет обрабатывать контексты в 4–8 раз длиннее на том же GPU.
4.2. Ускорение инференса
- Меньше операций записи/чтения KV cache (bandwidth-bound).
- Возможность использовать batch size больше.
- В autoregressive decoding bottleneck часто — memory bandwidth, а не compute. MQA снижает нагрузку на память.
4.3. Масштабирование на длинные контексты MQA — ключевой компонент в моделях с контекстом 100K+ токенов (например, Falcon-180B использует MQA для эффективности).
5. Недостатки и ухудшение качества
5.1. Потеря выразительности Разделение KV означает, что все heads вынуждены «смотреть» на одни и те же ключи и значения. Это ограничивает способность модели захватывать разные типы отношений (синтаксические, семантические, позиционные) независимо.
5.2. Падение качества на сложных задачах
- Perplexity растёт на 5–10% по сравнению с MHA.
- Особенно заметно на задачах, требующих точного извлечения фактов (например, Multi-document QA, long-range reasoning).
- В некоторых работах (например, “GQA: Training Generalized Multi-Query Transformer Models”) показано, что MQA хуже справляется с long-range dependencies и compositional reasoning.
5.3. Компенсация
- Увеличение числа heads (при фиксированном d_model) может частично восстановить качество, но снижает экономию.
- Использование Grouped-Query Attention (GQA) даёт лучший trade-off.
6. Примеры моделей, использующих MQA
| Модель | Размер | Контекст | Примечание |
|---|---|---|---|
| PaLM (Google, 2022) | 540B | 2048 | Использует MQA для экономии памяти при обучении и инференсе. |
| Falcon (TII, 2023) | 40B, 180B | 2048 (расширяется до 8K) | Falcon-180B — одна из первых открытых моделей с MQA. |
| Gemma (Google, 2024) | 2B, 7B | 8192 | Использует MQA для эффективности на мобильных устройствах. |
| Mamba (альтернатива Transformer) | — | — | Не attention, но тоже решает проблему long context через state space models. |
Почему PaLM и Falcon выбрали MQA?
- PaLM обучалась на 780B токенов; экономия памяти позволила ускорить обучение на 15–20%.
- Falcon-180B — самая большая открытая модель на момент выхода; MQA сделала её инференс возможным на 8×A100 (80GB).
7. Реализация MQA в коде (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# Q — отдельная проекция для каждого head (но можно сделать одну и разделить)
self.W_q = nn.Linear(d_model, d_model) # d_model = num_heads * head_dim
# K и V — общие проекции
self.W_k = nn.Linear(d_model, self.head_dim) # только head_dim, не d_model
self.W_v = nn.Linear(d_model, self.head_dim)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x) # (B, L, d_model)
K = self.W_k(x) # (B, L, head_dim)
V = self.W_v(x) # (B, L, head_dim)
# Разделяем Q на heads
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# K и V — без heads, добавляем размерность heads = 1 для broadcasting
K = K.unsqueeze(1) # (B, 1, L, head_dim)
V = V.unsqueeze(1) # (B, 1, L, head_dim)
# Scaled dot-product attention
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)
out = torch.matmul(attn_weights, V) # (B, H, L, head_dim)
# Объединяем heads
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.out_proj(out)
Ключевые моменты
W_kиW_vпроецируют вhead_dim, а не вd_model. Это и даёт экономию.- При инференсе KV cache хранит только один тензор
KиVна слой, а неHтензоров.
8. Когда выбирать MQA
Сценарии, где MQA оправдан:
- Long context generation (суммаризация документов, диалоги с историей).
- Мобильные и edge-устройства (ограниченная память).
- Очень большие модели (70B+), где KV cache становится bottleneck.
- Batch inference с большим batch size (MQA уменьшает memory bandwidth).
Сценарии, где лучше MHA или GQA:
- Высокое качество критично (медицинские, юридические задачи).
- Короткие контексты (до 1K токенов) — экономия незначительна.
- Задачи с точным извлечением (например, ответ на вопрос по одному документу).
Практический совет Если вы обучаете свою модель, начните с MHA, затем попробуйте GQA (4–8 групп). MQA используйте только если memory budget жёстко ограничен.
Пет-проект для закрепления
Задача Реализовать сравнение MHA и MQA на задаче генерации текста с длинным контекстом (например, суммаризация статьи на 10K токенов).
Инструменты
- Python, PyTorch, Hugging Face Transformers (можно взять маленькую модель, например GPT-2, и модифицировать attention).
- Датасет: CNN/DailyMail (суммаризация новостей) или LongBench (бенчмарк для long context).
Шаги:
- Загрузите предобученную модель GPT-2 (она использует MHA).
- Замените её attention слои на MQA (используя код выше).
- Догрузите веса проекций Q (оставьте как есть), а для K и V усредните веса всех heads (или обучите с нуля на маленьком датасете).
- Сравните:
- Perplexity на тестовом наборе.
- Время инференса и пиковое потребление памяти для контекстов разной длины (1K, 4K, 8K).
- Качество суммаризации (ROUGE-1/2/L).
- Постройте графики зависимости времени/памяти от длины контекста.
Ожидаемый результат
- MQA даст в 2–4 раза меньший KV cache и ускорение на 20–40% на длинных контекстах.
- Perplexity вырастет на 3–8%, ROUGE упадёт на 1–3 пункта.
- Вы сможете аргументированно ответить на собеседовании, когда стоит жертвовать качеством ради скорости.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 639 | Как работает KV cache и почему он важен для long context? |
| 641 | Что такое Grouped-Query Attention (GQA) и чем отличается от MQA? |
| 642 | Как работает Sliding Window Attention для long context? |
| 643 | Как работает Flash Attention и как он ускоряет long context? |
| 644 | Какие архитектурные решения (RoPE, ALiBi) улучшают работу с длинными контекстами? |
| 645 | Как вы оцениваете качество модели на длинных контекстах? |
Навигация
- Предыдущий: 639
- Следующий: 641
- Индекс: 00. Индекс разборов