Что такое grouped-query attention (GQA) как компромисс для long context?

Краткий тезис

Attention|Grouped-query attention (GQA) — это модификация механизма внимания, при которой несколько query-голов (heads) разделяют одну и ту же пару key-value (KV). Это промежуточное решение между multi-head attention (MHA) (каждая голова имеет свои KV) и Attention|multi-query attention (MQA) (все головы используют одни KV). GQA позволяет значительно сократить размер KV-кэша (в 4–8 раз по сравнению с MHA), что критично для работы с длинными контекстами, при этом сохраняя качество генерации на уровне, близком к MHA (потери 2–5%). Модели вроде Llama-3 (64 heads, 8 групп → 8 KV-пар) и Falcon используют GQA как стандарт для эффективного масштабирования длины контекста.


1. Термины: MHA, MQA, GQA

1.1 Multi-Head Attention (MHA)

В стандартном трансформере (Vaswani et al., 2017) каждый слой внимания имеет H голов. Для каждой головы вычисляются свои Q, K, V через отдельные линейные проекции. Размерность каждой головы: d_k = d_model / H.
Плюс высокая выразительность, каждая голова может фокусироваться на разных паттернах.
Минус KV-кэш растёт линейно с H: 2 * H * L * d_k (L — длина последовательности). Для длинных контекстов (32k+ токенов) это становится узким местом по памяти.

1.2 Multi-Query Attention (MQA)

Предложен в работе Shazeer (2019). Все Hquery-голов используютодин общий набор K и V (одна голова для K, одна для V).
Плюс KV-кэш уменьшается в H раз — огромная экономия памяти.
Минус качество падает сильнее, так как все головы вынуждены использовать одинаковые ключи и значения, что ограничивает разнообразие внимания.

1.3 Grouped-Query Attention (GQA)

Компромисс: H query-голов делятся на G групп (обычно G < H). Каждая группа имеет свои общие K и V. Таким образом, количество KV-пар равно G.
Пример: Llama-3-70B: H=64, G=8 → 8 KV-пар. Коэффициент сжатия KV-кэша = H/G = 8.
Плюс память в 8 раз меньше, чем MHA, качество почти как у MHA (потери 2–5%).
Минус всё ещё немного хуже MHA, но значительно лучше MQA.


2. Почему GQA важен для long context?

Long context (длинный контекст) — способность модели обрабатывать последовательности из десятков или сотен тысяч токенов. Основное ограничение — KV-кэш: на каждом шаге генерации нужно хранить K и V для всех предыдущих токенов. Размер KV-кэша пропорционален длине контекста L и количеству KV-голов.

  • MHA: 2 * H * L * d_k байт (при FP16: 2 байта на элемент).
  • GQA: 2 * G * L * d_k байт.

При L=128k, H=64, d_k=128, FP16:

  • MHA: 2 * 64 * 128k * 128 * 2 байта ≈ 4.2 ГБ на слой.
  • GQA (G=8): 2 * 8 * 128k * 128 * 2 ≈ 0.52 ГБ на слой.
    Экономия в 8 раз позволяет либо увеличить длину контекста, либо уменьшить требования к GPU.

3. Как устроена GQA внутри слоя?

Рассмотрим слой с d_model=4096, H=32, G=8.

  • Размерность головы: d_k = 4096 / 32 = 128.
  • Query-проекция: W_Q размера (d_model, H*d_k) = (4096, 4096).
  • Key-проекция: W_K размера (d_model, G*d_k) = (4096, 1024).
  • Value-проекция: W_V размера (d_model, G*d_k) = (4096, 1024).

На выходе K и V имеют форму (batch, G, seq_len, d_k).
Query — (batch, H, seq_len, d_k).
Внимание вычисляется так: для каждой группы g (0..G-1) берутся соответствующие query-головы (их H/G штук) и общие K_g, V_g.
Результаты конкатенируются по голове.

Псевдокод

# shapes: Q: (B, H, L, d_k), K: (B, G, L, d_k), V: (B, G, L, d_k)
# groups: heads_per_group = H // G
output = []
for g in range(G):
    q_group = Q[:, g*heads_per_group : (g+1)*heads_per_group, :, :]  # (B, hpg, L, d_k)
    k = K[:, g, :, :].unsqueeze(1)   # (B, 1, L, d_k)
    v = V[:, g, :, :].unsqueeze(1)   # (B, 1, L, d_k)
    attn = softmax(q_group @ k.transpose(-2,-1) / sqrt(d_k))  # (B, hpg, L, L)
    out = attn @ v                     # (B, hpg, L, d_k)
    output.append(out)
output = torch.cat(output, dim=1)     # (B, H, L, d_k)

4. Сравнение MHA, MQA, GQA

ХарактеристикаMHAMQAGQA
Количество KV-парH1G (1 < G < H)
Размер KV-кэша (отн. MHA)1x1/HG/H
Качество (perplexity)эталонхуже на 5–10%хуже на 2–5%
Скорость инференсамедленнее (больше памяти)быстреекомпромисс
Примеры моделейGPT-2, BERTPaLM, Gemini (ранние)Llama-2/3, Falcon, Mistral

5. Влияние на качество: почему GQA почти не уступает MHA?

Исследования (Ainslie et al., 2023) показывают, что при достаточном количестве групп (G >= 8) GQA практически не теряет в качестве на задачах понимания языка и генерации. Причина: избыточность в MHA — многие головы учатся похожим паттернам внимания. GQA заставляет группы голов делить KV, что действует как регуляризация и даже может улучшить обобщение.
На практике Llama-3-70B с GQA (G=8) показывает perplexity лишь на 0.1–0.2 хуже, чем гипотетическая MHA-версия, но при этом использует в 8 раз меньше памяти.


6. GQA и KV-кэш при инференсе

При авторегрессивной генерации KV-кэш обновляется на каждом шаге. Для GQA:

  • На шаге t добавляем один токен: K_new = (B, G, 1, d_k), V_new = (B, G, 1, d_k).
  • Кэш конкатенируется: K_cache = (B, G, t, d_k).
  • Размер кэша растёт как 2 * G * t * d_k — линейно, но с малым коэффициентом G.

Это позволяет держать в памяти контексты длиной 128k+ токенов даже на consumer GPU (24 ГБ VRAM).


7. GQA + Quantization: ещё большая экономия

Quantization (квантование) — снижение точности весов и кэша (например, FP16INT8 или FP4).
GQA уже уменьшает KV-кэш в H/G раз. Если дополнительно применить KV-кэш quantization (например, до INT8), можно получить ещё 2x экономию.
Итоговое сжатие: (H/G) * 2 (при INT8). Для Llama-3-70B: 8 * 2 = 16x относительно MHA в FP16.
Это делает возможным запуск моделей с контекстом 128k на одной A100 (80 ГБ).


8. Альтернативы GQA для long context

  • Sliding window attention (Mistral, Longformer) — ограничивает окно внимания, но теряет глобальный контекст.
  • Sparse attention (BigBird, LongNet) — разреженные паттерны, сложнее реализовать.
  • FlashAttention — оптимизирует вычисления, но не уменьшает KV-кэш.
  • Multi-Query Attention — более агрессивное сжатие, но хуже качество.

GQA остаётся лучшим балансом «качество-память» для long context в современных LLM.


9. Примеры моделей, использующих GQA

МодельHGСжатие KV-кэша
Llama-2-70B6488x
Llama-3-70B6488x
Falcon-40B6488x
Mistral-7B3284x
Gemma-7B1682x

10. Реализация GQA в PyTorch (упрощённый слой)

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_groups):
        super().__init__()
        assert n_heads % n_groups == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.head_dim = d_model // n_heads
        self.heads_per_group = n_heads // n_groups

        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_groups * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_groups * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        B, L, _ = x.shape
        Q = self.wq(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.wk(x).view(B, L, self.n_groups, self.head_dim).transpose(1, 2)
        V = self.wv(x).view(B, L, self.n_groups, self.head_dim).transpose(1, 2)

        # Expand K,V to match number of heads
        K = K[:, :, None, :, :].expand(-1, -1, self.heads_per_group, -1, -1).reshape(B, self.n_heads, L, self.head_dim)
        V = V[:, :, None, :, :].expand(-1, -1, self.heads_per_group, -1, -1).reshape(B, self.n_heads, L, self.head_dim)

        attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)
        return self.wo(out)

Пет-проект для закрепления

Задача Реализовать сравнение MHA, MQA и GQA на синтетических данных с длинными последовательностями (4096 токенов). Измерить размер KV-кэша, скорость инференса и качество (perplexity на валидационном наборе).

Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (логирование).

Шаги:

  1. Написать три класса внимания: MHA, MQA, GQA (с параметром n_groups).
  2. Собрать небольшой трансформер (2–4 слоя, d_model=512, H=16).
  3. Обучить на текстовом датасете (например, WikiText-2) с одинаковыми гиперпараметрами.
  4. Для каждой архитектуры замерить:
    • Размер KV-кэша при генерации (в байтах).
    • Время генерации 100 токенов.
    • Perplexity на тестовом наборе.
  5. Построить графики: зависимость perplexity от размера кэша.

Ожидаемый результат

  • MHA — лучшая perplexity, но самый большой кэш.
  • MQA — худшая perplexity, минимальный кэш.
  • GQA (G=4 или G=8) — perplexity близка к MHA, кэш в 2–4 раза меньше.
  • Вывод: GQA — оптимальный выбор для long context.

Связь с другими вопросами

ВопросТема
640Long context: проблемы и решения
642KV-кэш: устройство и оптимизация
643Quantization для LLM
644Sliding window attention
645Sparse attention
646FlashAttention

Навигация