Как работает attention математически (Q, K, V) и как вычислительная сложность масштабируется?
Краткий тезис
Attention — это механизм, который позволяет модели «смотреть» на все токены последовательности и взвешивать их важность для каждого текущего токена. Математически он выражается формулой Attention(Q,K,V) = softmax(QK^T / √d_k)V, где Q (Query), K (Key) и V (Value) — это линейные проекции входных данных. Основное узкое место — квадратичная сложность O(n²·d) по длине последовательности n, что делает обработку длинных контекстов дорогой и мотивирует появление оптимизаций (FlashAttention, sparse attention).
1. Термин: Attention (внимание)
Attention — это механизм, который вычисляет взвешенную сумму значений (V), где веса определяются сходством между запросом (Q) и ключами (K). В контексте трансформеров это позволяет каждому токену «общаться» с любым другим токеном в последовательности, преодолевая ограничения рекуррентных сетей.
Термин «Self-attention» (самовнимание) — частный случай, когда Q, K, V получаются из одной и той же последовательности (например, из входных эмбеддингов предложения). Это основа архитектуры Transformer.
2. Математическая формула: Q, K, V
Формула scaled dot-product attention:
Attention(Q, K, V) = softmax( (Q @ K^T) / √d_k ) @ V
Где:
- Q (Query) — матрица запросов размерности (n, d_k). Для каждого токена мы хотим понять, на какие другие токены обратить внимание.
- K (Key) — матрица ключей размерности (n, d_k). Каждый токен «предлагает» себя через ключ.
- V (Value) — матрица значений размерности
(n, d_v). Содержит фактическую информацию токена, которая будет передана. - n — длина последовательности (количество токенов).
- d_k — размерность ключей/запросов.
- d_v — размерность значений (часто равна d_k).
- @ — матричное умножение.
- softmax — функция, превращающая оценки в вероятности (сумма по строке = 1).
- √d_k — масштабирующий коэффициент, предотвращающий «взрыв» значений softmax при больших d_k.
Пошаговый разбор
- Скалярное произведение Q и K^T:
Q @ K^Tдаёт матрицу(n, n), где элемент(i, j)— это «оценка внимания» токена i к токену j. Чем выше оценка, тем больше токен i будет «смотреть» на токен j. - Масштабирование: Деление на √d_k стабилизирует градиенты. Без этого при больших d_k значения скалярного произведения становятся большими, softmax уходит в режим «почти one-hot» (очень острые вероятности), и градиенты затухают.
- Softmax: Применяется к каждой строке матрицы
(n, n). Превращает оценки в вероятности (веса внимания), которые в сумме дают 1. - Умножение на V: softmax(...) @ V даёт новую матрицу
(n, d_v). Каждая строка — это взвешенная сумма всех V, где веса — это вероятности из softmax. Токен i получает информацию от всех токенов, пропорционально их важности.
Пример на Python (упрощённо):
import numpy as np
def attention(Q, K, V):
d_k = Q.shape[-1]
scores = Q @ K.T # (n, n)
scaled_scores = scores / np.sqrt(d_k)
weights = np.exp(scaled_scores) / np.sum(np.exp(scaled_scores), axis=-1, keepdims=True) # softmax
output = weights @ V # (n, d_v)
return output, weights
# Пример: n=3, d_k=4
Q = np.random.randn(3, 4)
K = np.random.randn(3, 4)
V = np.random.randn(3, 4)
output, weights = attention(Q, K, V)
print("Output shape:", output.shape) # (3, 4)
print("Attention weights:\n", weights) # (3, 3)
3. Multi-Head Attention (MHA)
Вместо одного attention используется несколько «голов» (heads). Каждая голова учится фокусироваться на разных аспектах данных (например, синтаксис, семантика, позиция).
Формула
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O
Где:
- head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)
W_Q_i, W_K_i, W_V_i— обучаемые матрицы проекций для каждой головы (размерность d_model x d_k).W_O— выходная проекция (размерность h * d_v x d_model).h— количество голов (обычно 8, 16, 32).- d_model — размерность модели (например, 768 для BERT-base).
Зачем нужно multi-head Одна голова может усреднять внимание, а несколько голов позволяют модели захватывать сложные, многогранные зависимости.
4. Вычислительная сложность (Quadratic Bottleneck)
Основная операция — матричное умножение Q @ K^T, которое даёт матрицу (n, n). Сложность этой операции: O(n² · d_k).
| Компонент | Сложность | Примечание |
|---|---|---|
| Q @ K^T | O(n² · d_k) | Квадратичная по n. Главный bottleneck. |
| softmax | O(n²) | Применяется к каждой строке матрицы (n, n). |
| weights @ V | O(n² · d_v) | Ещё одно квадратичное умножение. |
| Итого | O(n² · d) | Где d = d_k ≈ d_v (обычно). |
Почему это проблема При n = 1000, матрица внимания имеет 1 млн элементов. При n = 100 000 (например, обработка книги) — 10 млрд элементов. Это требует огромной памяти (O(n²)) и времени.
Термин «Quadratic bottleneck» — фундаментальное ограничение Transformer, которое делает обработку длинных последовательностей (более 8k-16k токенов) непрактичной без специальных оптимизаций.
5. Оптимизации: как бороться с O(n²)
5.1 FlashAttention
- Идея: Не материализовать матрицу
(n, n)в памяти целиком. Вычислять attention блоками (tiling) на GPU, используя быструю, но маленькую память (SRAM). - Сложность: O(n² · d) по времени, но O(n) по памяти (не хранит всю матрицу).
- Результат: Ускорение в 2-4 раза и значительное снижение потребления памяти.
5.2 Sparse Attention
- Идея: Вычислять attention не для всех пар токенов, а только для выбранных (например, локальное окно + глобальные токены).
- Примеры: Longformer, BigBird, Sparse Transformers.
- Сложность: O(n · k · d), где k — константа (размер окна или количество глобальных токенов).
5.3 Linear Attention
- Идея: Заменить softmax на другую функцию, которая позволяет изменить порядок умножения:
(Q @ K^T) @ V = Q @ (K^T @ V). Это даёт сложность O(n · d²) вместо O(n² · d). - Примеры: Linear Transformers, Performer (с FAVOR+).
- Недостаток: Может быть менее точным, чем полный softmax attention.
5.4 KV-Cache (для инференса)
- Идея: При генерации каждого нового токена не пересчитывать K и V для всех предыдущих токенов, а сохранять их в кэше.
- Сложность: O(n) на шаг генерации (только для нового токена), но O(n²) для префилла (первого шага).
6. Практические следствия
- Длина контекста: Модели с полным attention (GPT-4, Claude) обычно имеют контекст 8k-128k токенов. Дальнейшее увеличение требует оптимизаций.
- Выбор модели: Для задач с длинными документами (анализ книг, код) выбирают модели с sparse attention (Mistral, Mixtral) или FlashAttention.
- Fine-tuning: При обучении на длинных последовательностях нужно учитывать ограничения памяти GPU. Используют gradient checkpointing и mixed precision.
Пет-проект для закрепления
Задача: Реализовать упрощённый Transformer с attention и сравнить время выполнения для разных длин последовательностей.
Инструменты: Python, PyTorch, numpy, matplotlib.
Шаги:
- Реализуйте
scaled_dot_product_attention(Q, K, V)с нуля на PyTorch. - Реализуйте
MultiHeadAttentionс 4 головами. - Создайте синтетические данные:
n = [64, 128, 256, 512, 1024],d_model = 256. - Замерьте время выполнения forward pass для каждого n (усредните по 10 запускам).
- Постройте график зависимости времени от n. Аппроксимируйте кривую квадратичной функцией
t = a * n² + b. - Добавьте FlashAttention (используйте
torch.nn.functional.scaled_dot_product_attentionсenable_flash=True) и сравните результаты.
Ожидаемый результат: Вы увидите, что время растёт как O(n²) для наивной реализации, и как O(n) или O(n log n) для оптимизированной. График наглядно демонстрирует quadratic bottleneck.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 275 | Архитектура Transformer: encoder-decoder vs decoder-only |
| 277 | Positional encoding (RoPE, ALiBi) |
| 278 | KV-cache и его оптимизация |
| 279 | Sparse attention и Long Context |
| 280 | FlashAttention и memory-bound операции |
| 281 | Mixture of Experts (MoE) |
Навигация
- Предыдущий: 275
- Следующий: 277
- Индекс: 00. Индекс разборов