English translation is not available yet. Showing Russian content.
Как работает grouped-query attention (GQA) и как trade-off speed/quality?
Краткий тезис
Attention|Grouped-query attention (GQA) — это компромисс между multi-head attention (MHA) (максимум качества, много KV-памяти) и Attention|multi-query attention (MQA) (максимум скорости, но возможна потеря качества). В GQA головы делятся на группы, и каждая группа использует одну общую пару ключей-значений (KV). Параметр n_groups регулирует баланс: при 4 группах качество практически не уступает MHA, а скорость инференса растёт в 3‑5 раз; при 16 группах экономия памяти максимальна, но на сложных задачах возможно снижение качества.
1. Термин: Multi-Head Attention (MHA) — традиционный механизм
Multi-head attention — стандартное внимание в трансформерах. Входные запросы (Q), ключи (K) и значения (V) проецируются несколько раз (каждая голова) с разными весовыми матрицами. Каждая голова независимо вычисляет внимание, результаты конкатенируются и снова проецируются.
Формула одной головы:
head_i = Attention(Q·W_Q_i, K·W_K_i, V·W_V_i)
- W_Q_i, W_K_i, W_V_i — отдельные матрицы проекций для каждой головы.
- Количество голов — гиперпараметр (обычно 8, 16, 32, 64).
Особенность: каждая голова хранит свои KV-кэш (состояния ключей и значений для генерации). При большом числе голов (например, 64) кэш занимает много памяти, что критично при длинных контекстах и batch-инференсе.
2. Проблема: KV-кэш и узкое место памяти
При авторегрессионной генерации (особенно в чат-ботах) Transformer хранит все предыдущие ключи и значения в KV-кэше. Для MHA размер кэша пропорционален num_heads × d_k × seq_len. Например:
- num_heads = 64, d_k = 128,
seq_len = 4096, batch_size = 4 → один слой даёт ≈ 134 МБ. - При 32 слоях → около 4,3 ГБ на один запрос.
С ростом числа голов память растёт линейно, а для больших моделей (Llama-3 70B, 64 головы) кэш становится главным лимитирующим фактором.
3. Multi-Query Attention (MQA) — радикальное решение
Multi-query attention (MQA) — предложен в работе “Fast Transformer Decoding” (Shazeer, 2019). Все головы запросов (Q) используют один и тот же набор ключей (K) и значений (V). То есть проекции K и V — общие для всех голов.
- Плюс: KV-кэш уменьшается в num_heads раз (например, в 64 раза).
- Минус: из-за одного K/V качество на задачах, требующих тонкого различного внимания к разным аспектам, может снижаться. MQA иногда “сглаживает” контекстную информацию.
4. Grouped-Query Attention (GQA) — компромисс
Grouped-query attention (GQA) — развитие MQA, предложенное в статье “GQA: Training Generalized Multi-Query Transformer Models” (Ainslie et al., 2023). Идея: разделить головы на группы, внутри группы — один общий K/V, между группами — разные K/V.
- Пусть всего голов запросов num_heads, а число групп —
g. Тогда число уникальных KV-пар равноg, каждая обслуживает num_heads / g голов. - Для Llama-3: 64 головы, 8 групп → 8 KV-пар (вместо 64).
5. Как работает GQA: механизм и пример
Схема работы
- Проекция запросов: выполняется для каждой головы отдельно (как в MHA).
- Проекция ключей и значений: выполняется только для
gобщих пар (как в MQA). - Разделение на группы: головы распределяются по группам равномерно. Внутри группы все головы используют один и тот же K, V.
- Attention: каждая голова вычисляет attention, используя свой Q и общий для группы K, V.
- Объединение: результаты голов конкатенируются и проецируются.
Пример расчёта:
- num_heads = 8,
g = 4→ 2 головы на группу, 4 KV-пары (экономия в 2 раза по сравнению с MHA). - num_heads = 64,
g = 16→ 4 головы на группу, 16 KV-пар (экономия в 4 раза).
6. Trade-off speed/quality — таблица и анализ
Основные метрики:
| Параметр | MHA | MQA | GQA (g=4) | GQA (g=16) |
|---|---|---|---|---|
| KV-кэш (отн.) | 1x | 1/num_heads | 1/g (g=4 → 1/4) | 1/16 |
| Скорость инференса (batch) | медленный | ~3-5x быстрее MHA | ~2-4x быстрее MHA | ~2-3x быстрее MHA |
| Качество (BLEU/ROUGE) | эталон | небольшое падение | практически идентично MHA | возможна потеря на сложных задачах |
Эмпирические результаты (из статей и практики):
- g=4 (например, Falcon-7B, 64 головы, 4 группы) — почти неотличимо от MHA на большинстве бенчмарков (MMLU, TriviaQA, ROUGE-L).
- g=16 (как в Llama-3) — скорость растёт меньше, но память экономится существенно. При этом на задачах, требующих тонкого детектирования аспектов (например, распознавание сущностей, многозначность), может наблюдаться падение качества на 1-3%.
- g=1 (MQA) — сильная экономия, но потеря качества заметна на задачах с длинными контекстами и сложной логикой.
Торговля между скоростью и качеством регулируется именно параметром g. В больших моделях часто выбирают g = 4 или g = 8 как сбалансированный вариант.
7. Практическая реализация GQA (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_groups, dropout=0.0):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.num_groups = num_groups
self.head_dim = d_model // num_heads
assert num_heads % num_groups == 0, "heads must be divisible by groups"
self.heads_per_group = num_heads // num_groups
# Проекции Q — для каждой головы (num_heads матриц)
self.W_Q = nn.Linear(d_model, d_model, bias=False)
# Проекции K, V — только для num_groups групп
self.W_K = nn.Linear(d_model, d_model // (num_heads // num_groups), bias=False)
self.W_V = nn.Linear(d_model, d_model // (num_heads // num_groups), bias=False)
# Выходная проекция
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch, seq_len, _ = x.shape
# Размерность общих K,V: d_kv = d_model // (num_heads // num_groups)
d_kv = self.head_dim * self.heads_per_group
# Q: [batch, seq_len, num_heads, head_dim]
q = self.W_Q(x).view(batch, seq_len, self.num_heads, self.head_dim)
# K, V: [batch, seq_len, num_groups, d_kv]
k = self.W_K(x).view(batch, seq_len, self.num_groups, d_kv)
v = self.W_V(x).view(batch, seq_len, self.num_groups, d_kv)
# Расширяем K,V для каждой группы до размеров голов
# [batch, seq_len, num_groups, 1, d_kv] -> [batch, seq_len, num_groups, heads_per_group, d_kv]
k = k.unsqueeze(3).expand(-1, -1, -1, self.heads_per_group, -1)
v = v.unsqueeze(3).expand(-1, -1, -1, self.heads_per_group, -1)
# Затем объединяем группы и головы: [batch, seq_len, num_heads, head_dim]
k = k.reshape(batch, seq_len, self.num_heads, self.head_dim)
v = v.reshape(batch, seq_len, self.num_heads, self.head_dim)
# Применяем attention для каждой головы
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, v)
# Конкатенация и выходной проекция
attn_output = attn_output.contiguous().view(batch, seq_len, -1)
return self.W_O(attn_output)
В этом коде:
- Размерность K,V для каждой группы равна head_dim * heads_per_group. После расширения и reshape каждая голова получает свою часть K,V.
- При batch-инференсе KV-кэш хранит только
num_groupsпар, а не num_heads.
8. Влияние на обучение и инференс
Обучение: GQA не даёт значительного ускорения, так как при обучении важно качество, а кэш не ограничен. Но есть техники up-training — дообучение MHA модели с преобразованием в GQA (например, через проекции K,V).
Инференс: GQA выигрывает за счёт меньшего объёма KV-кэша. Это особенно важно при длинных контекстах (например, 32768 токенов в Llama-3). Уменьшение занимаемой памяти позволяет увеличить batch size на GPU или снизить latency.
Торговля если выбрать слишком мало групп (например, 1 — MQA), качество может упасть; если много групп (близко к MHA) — экономия памяти минимальна. Оптимальный range — от 4 до 8 групп при 32‑64 головах.
9. Эмпирические результаты в известных моделях
- Falcon-7B: 64 головы, 4 группы (GQA-4). Качество на MMLU соответствует MHA.
- Falcon-40B: 128 голов, 8 групп (GQA-8). Скорость инференса в 3-4 раза выше MHA.
- Llama-2: 32 головы, 0 групп (MHA).
- Llama-3 (8B/70B): 64/128 голов, 8 групп (GQA-8). При этом Llama-3 использует GQA для экономии памяти при длинном контексте (до 8K/32K).
- Gemma-2B: 8 голов, 2 группы (GQA-2).
10. Связь с другими механизмами внимания
GQA — не единственная оптимизация. Есть также:
- FlashAttention (алгоритмическая оптимизация IO для attention).
- PagedAttention (vLLM) — управление KV-кэшем с помощью виртуальных страниц.
- Sliding window attention (Mistral) — ограничение окна внимания для снижения памяти.
GQA ортогональна этим методам и может комбинироваться с ними.
11. Пет-проект для закрепления
Задача: Реализовать сравнение MHA, MQA и GQA на простой задаче — обучение трансформера для генерации текста (например, на датасете Tiny Shakespeare) с последующей оценкой качества (perplexity) и замером скорости инференса.
Инструменты:
- PyTorch, Transformers (для baseline).
- Библиотека
timeit,psutilдля измерения памяти. - Датасет:
Tiny Shakespeare(1 МБ текста).
Шаги:
- Реализовать три версии attention (MHA, MQA, GQA).
- Обучить декодер-трансформер с одинаковой размерностью (d_model=256, num_heads=8, для GQA — 4 группы).
- Измерить:
- Perplexity на валидации (качество).
- Время инференса для batch_size=1 и batch_size=16.
- Размер KV-кэша при seq_len=1024.
- Построить таблицу и графики.
Ожидаемый результат:
- MHA: низкая перплексия (~70), медленный инференс, большой кэш.
- MQA: перплексия немного выше (~72), максимальная скорость, минимальный кэш.
- GQA (g=2): перплексия ~71, скорость на 2-3x больше MHA, кэш в 4 раза меньше. Подтверждение trade-off.
12. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 432 | Как работает multi-head attention (MHA) и зачем он нужен? |
| 433 | Что такое Multi-Query Attention (MQA) и чем он отличается от MHA? |
| 435 | Оптимизация KV-кеша (PagedAttention, FlashAttention) |
| 436 | Как разбивается внимание на группы в Llama-3 и Falcon? |
| 437 | Up-training: как преобразовать MHA в GQA без потери качества? |
| 430 | Архитектура Decoder-only трансформера (Llama-3) |
13. Навигация
- Предыдущий: 433
- Следующий: 435
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 433
- Следующий: 435
- Индекс: 00. Индекс разборов