English translation is not available yet. Showing Russian content.

Как работает grouped-query attention (GQA) и как trade-off speed/quality?

Краткий тезис

Attention|Grouped-query attention (GQA) — это компромисс между multi-head attention (MHA) (максимум качества, много KV-памяти) и Attention|multi-query attention (MQA) (максимум скорости, но возможна потеря качества). В GQA головы делятся на группы, и каждая группа использует одну общую пару ключей-значений (KV). Параметр n_groups регулирует баланс: при 4 группах качество практически не уступает MHA, а скорость инференса растёт в 3‑5 раз; при 16 группах экономия памяти максимальна, но на сложных задачах возможно снижение качества.


1. Термин: Multi-Head Attention (MHA) — традиционный механизм

Multi-head attention — стандартное внимание в трансформерах. Входные запросы (Q), ключи (K) и значения (V) проецируются несколько раз (каждая голова) с разными весовыми матрицами. Каждая голова независимо вычисляет внимание, результаты конкатенируются и снова проецируются.

Формула одной головы:

head_i = Attention(Q·W_Q_i, K·W_K_i, V·W_V_i)
  • W_Q_i, W_K_i, W_V_i — отдельные матрицы проекций для каждой головы.
  • Количество голов — гиперпараметр (обычно 8, 16, 32, 64).

Особенность: каждая голова хранит свои KV-кэш (состояния ключей и значений для генерации). При большом числе голов (например, 64) кэш занимает много памяти, что критично при длинных контекстах и batch-инференсе.


2. Проблема: KV-кэш и узкое место памяти

При авторегрессионной генерации (особенно в чат-ботах) Transformer хранит все предыдущие ключи и значения в KV-кэше. Для MHA размер кэша пропорционален num_heads × d_k × seq_len. Например:

  • num_heads = 64, d_k = 128, seq_len = 4096, batch_size = 4 → один слой даёт ≈ 134 МБ.
  • При 32 слоях → около 4,3 ГБ на один запрос.

С ростом числа голов память растёт линейно, а для больших моделей (Llama-3 70B, 64 головы) кэш становится главным лимитирующим фактором.


3. Multi-Query Attention (MQA) — радикальное решение

Multi-query attention (MQA) — предложен в работе “Fast Transformer Decoding” (Shazeer, 2019). Все головы запросов (Q) используют один и тот же набор ключей (K) и значений (V). То есть проекции K и V — общие для всех голов.

  • Плюс: KV-кэш уменьшается в num_heads раз (например, в 64 раза).
  • Минус: из-за одного K/V качество на задачах, требующих тонкого различного внимания к разным аспектам, может снижаться. MQA иногда “сглаживает” контекстную информацию.

4. Grouped-Query Attention (GQA) — компромисс

Grouped-query attention (GQA) — развитие MQA, предложенное в статье “GQA: Training Generalized Multi-Query Transformer Models” (Ainslie et al., 2023). Идея: разделить головы на группы, внутри группы — один общий K/V, между группами — разные K/V.

  • Пусть всего голов запросов num_heads, а число групп — g. Тогда число уникальных KV-пар равно g, каждая обслуживает num_heads / g голов.
  • Для Llama-3: 64 головы, 8 групп → 8 KV-пар (вместо 64).

5. Как работает GQA: механизм и пример

Схема работы

  1. Проекция запросов: выполняется для каждой головы отдельно (как в MHA).
  2. Проекция ключей и значений: выполняется только для g общих пар (как в MQA).
  3. Разделение на группы: головы распределяются по группам равномерно. Внутри группы все головы используют один и тот же K, V.
  4. Attention: каждая голова вычисляет attention, используя свой Q и общий для группы K, V.
  5. Объединение: результаты голов конкатенируются и проецируются.

Пример расчёта:

  • num_heads = 8, g = 4 → 2 головы на группу, 4 KV-пары (экономия в 2 раза по сравнению с MHA).
  • num_heads = 64, g = 16 → 4 головы на группу, 16 KV-пар (экономия в 4 раза).

6. Trade-off speed/quality — таблица и анализ

Основные метрики:

ПараметрMHAMQAGQA (g=4)GQA (g=16)
KV-кэш (отн.)1x1/num_heads1/g (g=4 → 1/4)1/16
Скорость инференса (batch)медленный~3-5x быстрее MHA~2-4x быстрее MHA~2-3x быстрее MHA
Качество (BLEU/ROUGE)эталоннебольшое падениепрактически идентично MHAвозможна потеря на сложных задачах

Эмпирические результаты (из статей и практики):

  • g=4 (например, Falcon-7B, 64 головы, 4 группы) — почти неотличимо от MHA на большинстве бенчмарков (MMLU, TriviaQA, ROUGE-L).
  • g=16 (как в Llama-3) — скорость растёт меньше, но память экономится существенно. При этом на задачах, требующих тонкого детектирования аспектов (например, распознавание сущностей, многозначность), может наблюдаться падение качества на 1-3%.
  • g=1 (MQA) — сильная экономия, но потеря качества заметна на задачах с длинными контекстами и сложной логикой.

Торговля между скоростью и качеством регулируется именно параметром g. В больших моделях часто выбирают g = 4 или g = 8 как сбалансированный вариант.


7. Практическая реализация GQA (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = d_model // num_heads
        assert num_heads % num_groups == 0, "heads must be divisible by groups"
        self.heads_per_group = num_heads // num_groups

        # Проекции Q — для каждой головы (num_heads матриц)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        # Проекции K, V — только для num_groups групп
        self.W_K = nn.Linear(d_model, d_model // (num_heads // num_groups), bias=False)
        self.W_V = nn.Linear(d_model, d_model // (num_heads // num_groups), bias=False)
        # Выходная проекция
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape
        # Размерность общих K,V: d_kv = d_model // (num_heads // num_groups)
        d_kv = self.head_dim * self.heads_per_group

        # Q: [batch, seq_len, num_heads, head_dim]
        q = self.W_Q(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # K, V: [batch, seq_len, num_groups, d_kv]
        k = self.W_K(x).view(batch, seq_len, self.num_groups, d_kv)
        v = self.W_V(x).view(batch, seq_len, self.num_groups, d_kv)

        # Расширяем K,V для каждой группы до размеров голов
        # [batch, seq_len, num_groups, 1, d_kv] -> [batch, seq_len, num_groups, heads_per_group, d_kv]
        k = k.unsqueeze(3).expand(-1, -1, -1, self.heads_per_group, -1)
        v = v.unsqueeze(3).expand(-1, -1, -1, self.heads_per_group, -1)
        # Затем объединяем группы и головы: [batch, seq_len, num_heads, head_dim]
        k = k.reshape(batch, seq_len, self.num_heads, self.head_dim)
        v = v.reshape(batch, seq_len, self.num_heads, self.head_dim)

        # Применяем attention для каждой головы
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)

        # Конкатенация и выходной проекция
        attn_output = attn_output.contiguous().view(batch, seq_len, -1)
        return self.W_O(attn_output)

В этом коде:

  • Размерность K,V для каждой группы равна head_dim * heads_per_group. После расширения и reshape каждая голова получает свою часть K,V.
  • При batch-инференсе KV-кэш хранит только num_groups пар, а не num_heads.

8. Влияние на обучение и инференс

Обучение: GQA не даёт значительного ускорения, так как при обучении важно качество, а кэш не ограничен. Но есть техники up-training — дообучение MHA модели с преобразованием в GQA (например, через проекции K,V).

Инференс: GQA выигрывает за счёт меньшего объёма KV-кэша. Это особенно важно при длинных контекстах (например, 32768 токенов в Llama-3). Уменьшение занимаемой памяти позволяет увеличить batch size на GPU или снизить latency.

Торговля если выбрать слишком мало групп (например, 1 — MQA), качество может упасть; если много групп (близко к MHA) — экономия памяти минимальна. Оптимальный range — от 4 до 8 групп при 32‑64 головах.


9. Эмпирические результаты в известных моделях

  • Falcon-7B: 64 головы, 4 группы (GQA-4). Качество на MMLU соответствует MHA.
  • Falcon-40B: 128 голов, 8 групп (GQA-8). Скорость инференса в 3-4 раза выше MHA.
  • Llama-2: 32 головы, 0 групп (MHA).
  • Llama-3 (8B/70B): 64/128 голов, 8 групп (GQA-8). При этом Llama-3 использует GQA для экономии памяти при длинном контексте (до 8K/32K).
  • Gemma-2B: 8 голов, 2 группы (GQA-2).

10. Связь с другими механизмами внимания

GQA — не единственная оптимизация. Есть также:

  • FlashAttention (алгоритмическая оптимизация IO для attention).
  • PagedAttention (vLLM) — управление KV-кэшем с помощью виртуальных страниц.
  • Sliding window attention (Mistral) — ограничение окна внимания для снижения памяти.

GQA ортогональна этим методам и может комбинироваться с ними.


11. Пет-проект для закрепления

Задача: Реализовать сравнение MHA, MQA и GQA на простой задаче — обучение трансформера для генерации текста (например, на датасете Tiny Shakespeare) с последующей оценкой качества (perplexity) и замером скорости инференса.

Инструменты:

  • PyTorch, Transformers (для baseline).
  • Библиотека timeit, psutil для измерения памяти.
  • Датасет: Tiny Shakespeare (1 МБ текста).

Шаги:

  1. Реализовать три версии attention (MHA, MQA, GQA).
  2. Обучить декодер-трансформер с одинаковой размерностью (d_model=256, num_heads=8, для GQA — 4 группы).
  3. Измерить:
    • Perplexity на валидации (качество).
    • Время инференса для batch_size=1 и batch_size=16.
    • Размер KV-кэша при seq_len=1024.
  4. Построить таблицу и графики.

Ожидаемый результат:

  • MHA: низкая перплексия (~70), медленный инференс, большой кэш.
  • MQA: перплексия немного выше (~72), максимальная скорость, минимальный кэш.
  • GQA (g=2): перплексия ~71, скорость на 2-3x больше MHA, кэш в 4 раза меньше. Подтверждение trade-off.

12. Связь с другими вопросами

ВопросТема
432Как работает multi-head attention (MHA) и зачем он нужен?
433Что такое Multi-Query Attention (MQA) и чем он отличается от MHA?
435Оптимизация KV-кеша (PagedAttention, FlashAttention)
436Как разбивается внимание на группы в Llama-3 и Falcon?
437Up-training: как преобразовать MHA в GQA без потери качества?
430Архитектура Decoder-only трансформера (Llama-3)

13. Навигация


Навигация