Что такое grouped-query attention (GQA) как компромисс для long context?
Краткий тезис
Attention|Grouped-query attention (GQA) — это модификация механизма внимания, при которой несколько query-голов (heads) разделяют одну и ту же пару key-value (KV). Это промежуточное решение между multi-head attention (MHA) (каждая голова имеет свои KV) и Attention|multi-query attention (MQA) (все головы используют одни KV). GQA позволяет значительно сократить размер KV-кэша (в 4–8 раз по сравнению с MHA), что критично для работы с длинными контекстами, при этом сохраняя качество генерации на уровне, близком к MHA (потери 2–5%). Модели вроде Llama-3 (64 heads, 8 групп → 8 KV-пар) и Falcon используют GQA как стандарт для эффективного масштабирования длины контекста.
1. Термины: MHA, MQA, GQA
1.1 Multi-Head Attention (MHA)
В стандартном трансформере (Vaswani et al., 2017) каждый слой внимания имеет H голов. Для каждой головы вычисляются свои Q, K, V через отдельные линейные проекции. Размерность каждой головы: d_k = d_model / H.
Плюс высокая выразительность, каждая голова может фокусироваться на разных паттернах.
Минус KV-кэш растёт линейно с H: 2 * H * L * d_k (L — длина последовательности). Для длинных контекстов (32k+ токенов) это становится узким местом по памяти.
1.2 Multi-Query Attention (MQA)
Предложен в работе Shazeer (2019). Все Hquery-голов используютодин общий набор K и V (одна голова для K, одна для V).
Плюс KV-кэш уменьшается в H раз — огромная экономия памяти.
Минус качество падает сильнее, так как все головы вынуждены использовать одинаковые ключи и значения, что ограничивает разнообразие внимания.
1.3 Grouped-Query Attention (GQA)
Компромисс: H query-голов делятся на G групп (обычно G < H). Каждая группа имеет свои общие K и V. Таким образом, количество KV-пар равно G.
Пример: Llama-3-70B: H=64, G=8 → 8 KV-пар. Коэффициент сжатия KV-кэша = H/G = 8.
Плюс память в 8 раз меньше, чем MHA, качество почти как у MHA (потери 2–5%).
Минус всё ещё немного хуже MHA, но значительно лучше MQA.
2. Почему GQA важен для long context?
Long context (длинный контекст) — способность модели обрабатывать последовательности из десятков или сотен тысяч токенов. Основное ограничение — KV-кэш: на каждом шаге генерации нужно хранить K и V для всех предыдущих токенов. Размер KV-кэша пропорционален длине контекста L и количеству KV-голов.
При L=128k, H=64, d_k=128, FP16:
- MHA: 2 * 64 * 128k * 128 * 2 байта ≈ 4.2 ГБ на слой.
- GQA (G=8): 2 * 8 * 128k * 128 * 2 ≈ 0.52 ГБ на слой.
Экономия в 8 раз позволяет либо увеличить длину контекста, либо уменьшить требования к GPU.
3. Как устроена GQA внутри слоя?
Рассмотрим слой с d_model=4096, H=32, G=8.
- Размерность головы: d_k = 4096 / 32 = 128.
- Query-проекция:
W_Qразмера (d_model, H*d_k) = (4096, 4096). - Key-проекция:
W_Kразмера (d_model, G*d_k) = (4096, 1024). - Value-проекция:
W_Vразмера (d_model, G*d_k) = (4096, 1024).
На выходе K и V имеют форму (batch, G, seq_len, d_k).
Query — (batch, H, seq_len, d_k).
Внимание вычисляется так: для каждой группы g (0..G-1) берутся соответствующие query-головы (их H/G штук) и общие K_g, V_g.
Результаты конкатенируются по голове.
Псевдокод
# shapes: Q: (B, H, L, d_k), K: (B, G, L, d_k), V: (B, G, L, d_k)
# groups: heads_per_group = H // G
output = []
for g in range(G):
q_group = Q[:, g*heads_per_group : (g+1)*heads_per_group, :, :] # (B, hpg, L, d_k)
k = K[:, g, :, :].unsqueeze(1) # (B, 1, L, d_k)
v = V[:, g, :, :].unsqueeze(1) # (B, 1, L, d_k)
attn = softmax(q_group @ k.transpose(-2,-1) / sqrt(d_k)) # (B, hpg, L, L)
out = attn @ v # (B, hpg, L, d_k)
output.append(out)
output = torch.cat(output, dim=1) # (B, H, L, d_k)
4. Сравнение MHA, MQA, GQA
| Характеристика | MHA | MQA | GQA |
|---|---|---|---|
| Количество KV-пар | H | 1 | G (1 < G < H) |
| Размер KV-кэша (отн. MHA) | 1x | 1/H | G/H |
| Качество (perplexity) | эталон | хуже на 5–10% | хуже на 2–5% |
| Скорость инференса | медленнее (больше памяти) | быстрее | компромисс |
| Примеры моделей | GPT-2, BERT | PaLM, Gemini (ранние) | Llama-2/3, Falcon, Mistral |
5. Влияние на качество: почему GQA почти не уступает MHA?
Исследования (Ainslie et al., 2023) показывают, что при достаточном количестве групп (G >= 8) GQA практически не теряет в качестве на задачах понимания языка и генерации. Причина: избыточность в MHA — многие головы учатся похожим паттернам внимания. GQA заставляет группы голов делить KV, что действует как регуляризация и даже может улучшить обобщение.
На практике Llama-3-70B с GQA (G=8) показывает perplexity лишь на 0.1–0.2 хуже, чем гипотетическая MHA-версия, но при этом использует в 8 раз меньше памяти.
6. GQA и KV-кэш при инференсе
При авторегрессивной генерации KV-кэш обновляется на каждом шаге. Для GQA:
- На шаге t добавляем один токен: K_new = (B, G, 1, d_k), V_new = (B, G, 1, d_k).
- Кэш конкатенируется: K_cache = (B, G, t, d_k).
- Размер кэша растёт как 2 * G * t * d_k — линейно, но с малым коэффициентом G.
Это позволяет держать в памяти контексты длиной 128k+ токенов даже на consumer GPU (24 ГБ VRAM).
7. GQA + Quantization: ещё большая экономия
Quantization (квантование) — снижение точности весов и кэша (например, FP16 → INT8 или FP4).
GQA уже уменьшает KV-кэш в H/G раз. Если дополнительно применить KV-кэш quantization (например, до INT8), можно получить ещё 2x экономию.
Итоговое сжатие: (H/G) * 2 (при INT8). Для Llama-3-70B: 8 * 2 = 16x относительно MHA в FP16.
Это делает возможным запуск моделей с контекстом 128k на одной A100 (80 ГБ).
8. Альтернативы GQA для long context
- Sliding window attention (Mistral, Longformer) — ограничивает окно внимания, но теряет глобальный контекст.
- Sparse attention (BigBird, LongNet) — разреженные паттерны, сложнее реализовать.
- FlashAttention — оптимизирует вычисления, но не уменьшает KV-кэш.
- Multi-Query Attention — более агрессивное сжатие, но хуже качество.
GQA остаётся лучшим балансом «качество-память» для long context в современных LLM.
9. Примеры моделей, использующих GQA
| Модель | H | G | Сжатие KV-кэша |
|---|---|---|---|
| Llama-2-70B | 64 | 8 | 8x |
| Llama-3-70B | 64 | 8 | 8x |
| Falcon-40B | 64 | 8 | 8x |
| Mistral-7B | 32 | 8 | 4x |
| Gemma-7B | 16 | 8 | 2x |
10. Реализация GQA в PyTorch (упрощённый слой)
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, n_heads, n_groups):
super().__init__()
assert n_heads % n_groups == 0
self.d_model = d_model
self.n_heads = n_heads
self.n_groups = n_groups
self.head_dim = d_model // n_heads
self.heads_per_group = n_heads // n_groups
self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(d_model, n_groups * self.head_dim, bias=False)
self.wv = nn.Linear(d_model, n_groups * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)
def forward(self, x):
B, L, _ = x.shape
Q = self.wq(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
K = self.wk(x).view(B, L, self.n_groups, self.head_dim).transpose(1, 2)
V = self.wv(x).view(B, L, self.n_groups, self.head_dim).transpose(1, 2)
# Expand K,V to match number of heads
K = K[:, :, None, :, :].expand(-1, -1, self.heads_per_group, -1, -1).reshape(B, self.n_heads, L, self.head_dim)
V = V[:, :, None, :, :].expand(-1, -1, self.heads_per_group, -1, -1).reshape(B, self.n_heads, L, self.head_dim)
attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, V)
out = out.transpose(1, 2).contiguous().view(B, L, -1)
return self.wo(out)
Пет-проект для закрепления
Задача Реализовать сравнение MHA, MQA и GQA на синтетических данных с длинными последовательностями (4096 токенов). Измерить размер KV-кэша, скорость инференса и качество (perplexity на валидационном наборе).
Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (логирование).
Шаги:
- Написать три класса внимания: MHA, MQA, GQA (с параметром
n_groups). - Собрать небольшой трансформер (2–4 слоя, d_model=512, H=16).
- Обучить на текстовом датасете (например, WikiText-2) с одинаковыми гиперпараметрами.
- Для каждой архитектуры замерить:
- Размер KV-кэша при генерации (в байтах).
- Время генерации 100 токенов.
- Perplexity на тестовом наборе.
- Построить графики: зависимость perplexity от размера кэша.
Ожидаемый результат
- MHA — лучшая perplexity, но самый большой кэш.
- MQA — худшая perplexity, минимальный кэш.
- GQA (G=4 или G=8) — perplexity близка к MHA, кэш в 2–4 раза меньше.
- Вывод: GQA — оптимальный выбор для long context.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 640 | Long context: проблемы и решения |
| 642 | KV-кэш: устройство и оптимизация |
| 643 | Quantization для LLM |
| 644 | Sliding window attention |
| 645 | Sparse attention |
| 646 | FlashAttention |
Навигация
- Предыдущий: 640
- Следующий: 642
- Индекс: 00. Индекс разборов