中文翻译暂不可用,显示俄语原文。

Как работает Multi-query attention (MQA) для long context?

Краткий тезис

Attention|Multi-Query Attention (MQA) — это модификация механизма внимания, в которой все attention heads используют общие Key (K) и Value (V) пары, а Query (Q) остаётся уникальным для каждой головы. Это радикально сокращает размер KV cache (в 8 раз для 8 голов) и ускоряет инференс на длинных контекстах, но ценой небольшого падения качества (5–10% на сложных задачах). MQA применяется в моделях, где критична скорость и память, например в PaLM и Falcon.


1. Термин: Multi-Query Attention (MQA)

Multi-Query Attention — это вариант scaled dot-product attention, предложенный в статье “Fast Transformer Decoding: One Write-Head is All You Need” (2019). В отличие от стандартного Multi-Head Attention (MHA), где каждый head имеет собственные проекции Q, K, V, в MQA все heads разделяют одну и ту же проекцию K и V аQ остаётся индивидуальным.

Ключевые понятия

  • Attention head — одна «голова» внимания, которая вычисляет взвешенную сумму значений.
  • KV cache — кэш ключей и значений, хранящийся для каждого слоя декодера при авторегрессивной генерации. Размер KV cache растёт линейно с длиной контекста и числом голов.
  • Long context — контекст длиной более 4K–8K токенов, где KV cache становится узким местом по памяти.

2. Как работает MQA: общие KV пары

В стандартном MHA для каждого head h из H голов:

  • Q_h = X * W_q_h
  • K_h = X * W_k_h
  • V_h = X * W_v_h

Где X — входная последовательность, W_*_h — матрицы проекций.

В MQA:

  • Q_h = X * W_q_h (уникальный для каждого head)
  • K = X * W_k (один на все heads)
  • V = X * W_v (один на все heads)

Таким образом, для каждого слоя хранится только одна пара K и V вместо H пар. Размер KV cache уменьшается в H раз.

Формула внимания для одного head

Attention(Q_h, K, V) = softmax(Q_h * K^T / sqrt(d_k)) * V

Где d_k — размерность ключа (обычно d_model / H).

Визуализация (текстовая):

MHA: [head1: K1,V1] [head2: K2,V2] ... [headH: KH,VH] → H пар KV
MQA: [head1: K,V] [head2: K,V] ... [headH: K,V] → 1 пара KV (общая)

3. Сравнение MHA vs MQA vs GQA

Grouped-Query Attention (GQA) — компромисс: heads делятся на группы, каждая группа имеет общие KV. MQA — частный случай GQA с одной группой.

ХарактеристикаMHAMQAGQA
Число KV парH (по числу heads)1G (число групп)
Размер KV cacheH * L * d_k1 * L * d_kG * L * d_k
Экономия памяти (отн. MHA)1xH разH/G раз
Качество (perplexity)Эталон−5..−10%−1..−3%
Скорость инференсаБазоваяВысокаяСредняя
Примеры моделейGPT-3, LLaMAPaLM, FalconLLaMA 2 70B, Mistral

Когда какая нужна

  • MHA — максимальное качество, допустимо много памяти.
  • MQA — экстремальная экономия памяти, long context, мобильные устройства.
  • GQA — баланс, часто используется в больших моделях (70B+).

4. Преимущества MQA для long context

4.1. Экономия памяти KV cache При длине контекста L=32K, d_k=128, H=8:

  • MHA: 8 * 32K * 128 * 2 байта (float16) ≈ 64 MB на слой.
  • MQA: 1 * 32K * 128 * 2 ≈ 8 MB на слой. Для 32 слоёв разница — 2 GB vs 256 MB. Это позволяет обрабатывать контексты в 4–8 раз длиннее на том же GPU.

4.2. Ускорение инференса

  • Меньше операций записи/чтения KV cache (bandwidth-bound).
  • Возможность использовать batch size больше.
  • В autoregressive decoding bottleneck часто — memory bandwidth, а не compute. MQA снижает нагрузку на память.

4.3. Масштабирование на длинные контексты MQA — ключевой компонент в моделях с контекстом 100K+ токенов (например, Falcon-180B использует MQA для эффективности).


5. Недостатки и ухудшение качества

5.1. Потеря выразительности Разделение KV означает, что все heads вынуждены «смотреть» на одни и те же ключи и значения. Это ограничивает способность модели захватывать разные типы отношений (синтаксические, семантические, позиционные) независимо.

5.2. Падение качества на сложных задачах

  • Perplexity растёт на 5–10% по сравнению с MHA.
  • Особенно заметно на задачах, требующих точного извлечения фактов (например, Multi-document QA, long-range reasoning).
  • В некоторых работах (например, “GQA: Training Generalized Multi-Query Transformer Models”) показано, что MQA хуже справляется с long-range dependencies и compositional reasoning.

5.3. Компенсация

  • Увеличение числа heads (при фиксированном d_model) может частично восстановить качество, но снижает экономию.
  • Использование Grouped-Query Attention (GQA) даёт лучший trade-off.

6. Примеры моделей, использующих MQA

МодельРазмерКонтекстПримечание
PaLM (Google, 2022)540B2048Использует MQA для экономии памяти при обучении и инференсе.
Falcon (TII, 2023)40B, 180B2048 (расширяется до 8K)Falcon-180B — одна из первых открытых моделей с MQA.
Gemma (Google, 2024)2B, 7B8192Использует MQA для эффективности на мобильных устройствах.
Mamba (альтернатива Transformer)Не attention, но тоже решает проблему long context через state space models.

Почему PaLM и Falcon выбрали MQA?

  • PaLM обучалась на 780B токенов; экономия памяти позволила ускорить обучение на 15–20%.
  • Falcon-180B — самая большая открытая модель на момент выхода; MQA сделала её инференс возможным на 8×A100 (80GB).

7. Реализация MQA в коде (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Q — отдельная проекция для каждого head (но можно сделать одну и разделить)
        self.W_q = nn.Linear(d_model, d_model)  # d_model = num_heads * head_dim
        # K и V — общие проекции
        self.W_k = nn.Linear(d_model, self.head_dim)  # только head_dim, не d_model
        self.W_v = nn.Linear(d_model, self.head_dim)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x)  # (B, L, d_model)
        K = self.W_k(x)  # (B, L, head_dim)
        V = self.W_v(x)  # (B, L, head_dim)

        # Разделяем Q на heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K и V — без heads, добавляем размерность heads = 1 для broadcasting
        K = K.unsqueeze(1)  # (B, 1, L, head_dim)
        V = V.unsqueeze(1)  # (B, 1, L, head_dim)

        # Scaled dot-product attention
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, V)  # (B, H, L, head_dim)

        # Объединяем heads
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(out)

Ключевые моменты

  • W_k и W_v проецируют в head_dim, а не в d_model. Это и даёт экономию.
  • При инференсе KV cache хранит только один тензор K и V на слой, а не H тензоров.

8. Когда выбирать MQA

Сценарии, где MQA оправдан:

  • Long context generation (суммаризация документов, диалоги с историей).
  • Мобильные и edge-устройства (ограниченная память).
  • Очень большие модели (70B+), где KV cache становится bottleneck.
  • Batch inference с большим batch size (MQA уменьшает memory bandwidth).

Сценарии, где лучше MHA или GQA:

  • Высокое качество критично (медицинские, юридические задачи).
  • Короткие контексты (до 1K токенов) — экономия незначительна.
  • Задачи с точным извлечением (например, ответ на вопрос по одному документу).

Практический совет Если вы обучаете свою модель, начните с MHA, затем попробуйте GQA (4–8 групп). MQA используйте только если memory budget жёстко ограничен.


Пет-проект для закрепления

Задача Реализовать сравнение MHA и MQA на задаче генерации текста с длинным контекстом (например, суммаризация статьи на 10K токенов).

Инструменты

  • Python, PyTorch, Hugging Face Transformers (можно взять маленькую модель, например GPT-2, и модифицировать attention).
  • Датасет: CNN/DailyMail (суммаризация новостей) или LongBench (бенчмарк для long context).

Шаги:

  1. Загрузите предобученную модель GPT-2 (она использует MHA).
  2. Замените её attention слои на MQA (используя код выше).
  3. Догрузите веса проекций Q (оставьте как есть), а для K и V усредните веса всех heads (или обучите с нуля на маленьком датасете).
  4. Сравните:
    • Perplexity на тестовом наборе.
    • Время инференса и пиковое потребление памяти для контекстов разной длины (1K, 4K, 8K).
    • Качество суммаризации (ROUGE-1/2/L).
  5. Постройте графики зависимости времени/памяти от длины контекста.

Ожидаемый результат

  • MQA даст в 2–4 раза меньший KV cache и ускорение на 20–40% на длинных контекстах.
  • Perplexity вырастет на 3–8%, ROUGE упадёт на 1–3 пункта.
  • Вы сможете аргументированно ответить на собеседовании, когда стоит жертвовать качеством ради скорости.

Связь с другими вопросами

ВопросТема
639Как работает KV cache и почему он важен для long context?
641Что такое Grouped-Query Attention (GQA) и чем отличается от MQA?
642Как работает Sliding Window Attention для long context?
643Как работает Flash Attention и как он ускоряет long context?
644Какие архитектурные решения (RoPE, ALiBi) улучшают работу с длинными контекстами?
645Как вы оцениваете качество модели на длинных контекстах?

Навигация