中文翻译暂不可用,显示俄语原文。

Как работает attention математически (Q, K, V) и как вычислительная сложность масштабируется?

Краткий тезис

Attention — это механизм, который позволяет модели «смотреть» на все токены последовательности и взвешивать их важность для каждого текущего токена. Математически он выражается формулой Attention(Q,K,V) = softmax(QK^T / √d_k)V, где Q (Query), K (Key) и V (Value) — это линейные проекции входных данных. Основное узкое место — квадратичная сложность O(n²·d) по длине последовательности n, что делает обработку длинных контекстов дорогой и мотивирует появление оптимизаций (FlashAttention, sparse attention).

1. Термин: Attention (внимание)

Attention — это механизм, который вычисляет взвешенную сумму значений (V), где веса определяются сходством между запросом (Q) и ключами (K). В контексте трансформеров это позволяет каждому токену «общаться» с любым другим токеном в последовательности, преодолевая ограничения рекуррентных сетей.

Термин «Self-attention» (самовнимание) — частный случай, когда Q, K, V получаются из одной и той же последовательности (например, из входных эмбеддингов предложения). Это основа архитектуры Transformer.

2. Математическая формула: Q, K, V

Формула scaled dot-product attention:

Attention(Q, K, V) = softmax( (Q @ K^T) / √d_k ) @ V

Где:

  • Q (Query) — матрица запросов размерности (n, d_k). Для каждого токена мы хотим понять, на какие другие токены обратить внимание.
  • K (Key) — матрица ключей размерности (n, d_k). Каждый токен «предлагает» себя через ключ.
  • V (Value) — матрица значений размерности (n, d_v). Содержит фактическую информацию токена, которая будет передана.
  • n — длина последовательности (количество токенов).
  • d_k — размерность ключей/запросов.
  • d_v — размерность значений (часто равна d_k).
  • @ — матричное умножение.
  • softmax — функция, превращающая оценки в вероятности (сумма по строке = 1).
  • d_k — масштабирующий коэффициент, предотвращающий «взрыв» значений softmax при больших d_k.

Пошаговый разбор

  1. Скалярное произведение Q и K^T: Q @ K^T даёт матрицу (n, n), где элемент (i, j) — это «оценка внимания» токена i к токену j. Чем выше оценка, тем больше токен i будет «смотреть» на токен j.
  2. Масштабирование: Деление на √d_k стабилизирует градиенты. Без этого при больших d_k значения скалярного произведения становятся большими, softmax уходит в режим «почти one-hot» (очень острые вероятности), и градиенты затухают.
  3. Softmax: Применяется к каждой строке матрицы (n, n). Превращает оценки в вероятности (веса внимания), которые в сумме дают 1.
  4. Умножение на V: softmax(...) @ V даёт новую матрицу (n, d_v). Каждая строка — это взвешенная сумма всех V, где веса — это вероятности из softmax. Токен i получает информацию от всех токенов, пропорционально их важности.

Пример на Python (упрощённо):

import numpy as np

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T  # (n, n)
    scaled_scores = scores / np.sqrt(d_k)
    weights = np.exp(scaled_scores) / np.sum(np.exp(scaled_scores), axis=-1, keepdims=True)  # softmax
    output = weights @ V  # (n, d_v)
    return output, weights

# Пример: n=3, d_k=4
Q = np.random.randn(3, 4)
K = np.random.randn(3, 4)
V = np.random.randn(3, 4)

output, weights = attention(Q, K, V)
print("Output shape:", output.shape)  # (3, 4)
print("Attention weights:\n", weights)  # (3, 3)

3. Multi-Head Attention (MHA)

Вместо одного attention используется несколько «голов» (heads). Каждая голова учится фокусироваться на разных аспектах данных (например, синтаксис, семантика, позиция).

Формула MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O

Где:

  • head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)
  • W_Q_i, W_K_i, W_V_i — обучаемые матрицы проекций для каждой головы (размерность d_model x d_k).
  • W_O — выходная проекция (размерность h * d_v x d_model).
  • h — количество голов (обычно 8, 16, 32).
  • d_model — размерность модели (например, 768 для BERT-base).

Зачем нужно multi-head Одна голова может усреднять внимание, а несколько голов позволяют модели захватывать сложные, многогранные зависимости.

4. Вычислительная сложность (Quadratic Bottleneck)

Основная операция — матричное умножение Q @ K^T, которое даёт матрицу (n, n). Сложность этой операции: O(n² · d_k).

КомпонентСложностьПримечание
Q @ K^TO(n² · d_k)Квадратичная по n. Главный bottleneck.
softmaxO(n²)Применяется к каждой строке матрицы (n, n).
weights @ VO(n² · d_v)Ещё одно квадратичное умножение.
ИтогоO(n² · d)Где d = d_k ≈ d_v (обычно).

Почему это проблема При n = 1000, матрица внимания имеет 1 млн элементов. При n = 100 000 (например, обработка книги) — 10 млрд элементов. Это требует огромной памяти (O(n²)) и времени.

Термин «Quadratic bottleneck» — фундаментальное ограничение Transformer, которое делает обработку длинных последовательностей (более 8k-16k токенов) непрактичной без специальных оптимизаций.

5. Оптимизации: как бороться с O(n²)

5.1 FlashAttention

  • Идея: Не материализовать матрицу (n, n) в памяти целиком. Вычислять attention блоками (tiling) на GPU, используя быструю, но маленькую память (SRAM).
  • Сложность: O(n² · d) по времени, но O(n) по памяти (не хранит всю матрицу).
  • Результат: Ускорение в 2-4 раза и значительное снижение потребления памяти.

5.2 Sparse Attention

  • Идея: Вычислять attention не для всех пар токенов, а только для выбранных (например, локальное окно + глобальные токены).
  • Примеры: Longformer, BigBird, Sparse Transformers.
  • Сложность: O(n · k · d), где k — константа (размер окна или количество глобальных токенов).

5.3 Linear Attention

  • Идея: Заменить softmax на другую функцию, которая позволяет изменить порядок умножения: (Q @ K^T) @ V = Q @ (K^T @ V). Это даёт сложность O(n · d²) вместо O(n² · d).
  • Примеры: Linear Transformers, Performer (с FAVOR+).
  • Недостаток: Может быть менее точным, чем полный softmax attention.

5.4 KV-Cache (для инференса)

  • Идея: При генерации каждого нового токена не пересчитывать K и V для всех предыдущих токенов, а сохранять их в кэше.
  • Сложность: O(n) на шаг генерации (только для нового токена), но O(n²) для префилла (первого шага).

6. Практические следствия

  • Длина контекста: Модели с полным attention (GPT-4, Claude) обычно имеют контекст 8k-128k токенов. Дальнейшее увеличение требует оптимизаций.
  • Выбор модели: Для задач с длинными документами (анализ книг, код) выбирают модели с sparse attention (Mistral, Mixtral) или FlashAttention.
  • Fine-tuning: При обучении на длинных последовательностях нужно учитывать ограничения памяти GPU. Используют gradient checkpointing и mixed precision.

Пет-проект для закрепления

Задача: Реализовать упрощённый Transformer с attention и сравнить время выполнения для разных длин последовательностей.

Инструменты: Python, PyTorch, numpy, matplotlib.

Шаги:

  1. Реализуйте scaled_dot_product_attention(Q, K, V) с нуля на PyTorch.
  2. Реализуйте MultiHeadAttention с 4 головами.
  3. Создайте синтетические данные: n = [64, 128, 256, 512, 1024], d_model = 256.
  4. Замерьте время выполнения forward pass для каждого n (усредните по 10 запускам).
  5. Постройте график зависимости времени от n. Аппроксимируйте кривую квадратичной функцией t = a * n² + b.
  6. Добавьте FlashAttention (используйте torch.nn.functional.scaled_dot_product_attention с enable_flash=True) и сравните результаты.

Ожидаемый результат: Вы увидите, что время растёт как O(n²) для наивной реализации, и как O(n) или O(n log n) для оптимизированной. График наглядно демонстрирует quadratic bottleneck.

Связь с другими вопросами

ВопросТема
275Архитектура Transformer: encoder-decoder vs decoder-only
277Positional encoding (RoPE, ALiBi)
278KV-cache и его оптимизация
279Sparse attention и Long Context
280FlashAttention и memory-bound операции
281Mixture of Experts (MoE)

Навигация