Почему в формуле attention нужно делить на √d_k? Что будет без масштабирования?

Краткий тезис

Масштабирование на √d_k в формуле scaled dot-product attention необходимо для стабилизации градиентов и предотвращения «острых» распределений softmax. Без него дисперсия скалярного произведения q·k растёт пропорционально размерности ключей d_k, что приводит к экстремально малым градиентам (gradients|vanishing gradients) и замедлению обучения. Деление нормализует дисперсию к единице, сохраняя эффективное обучение даже при больших d_k.


1. Формула attention и роль масштабирования

Attention — механизм, позволяющий модели фокусироваться на релевантных частях входной последовательности. В Transformer используется scaled dot-product attention:

Attention(Q, K, V) = softmax( (Q K^T) / √d_k ) V

Где:

  • Q (query), K (key), V (value) — матрицы, полученные из входных эмбеддингов через линейные проекции.
  • d_k — размерность ключей (и запросов).
  • softmax — функция, превращающая логиты в вероятности.

Ключевой элемент — деление на √d_k. Без него формула выглядела бы как softmax(Q K^T) V. Почему это важно?


2. Математика: дисперсия скалярного произведения

Пусть q и k— независимые случайные векторы размерностиd_k с нулевым средним и единичной дисперсией каждой компоненты (типичная инициализация). Тогда скалярное произведение:

s = q·k = Σ_{i=1}^{d_k} q_i * k_i

Каждое слагаемое имеет среднее 0 и дисперсию 1 (произведение двух независимых величин с единичной дисперсией). Сумма d_k слагаемых даёт дисперсию d_k:

Var(s) = d_k

Стандартное отклонение σ = √d_k. Для d_k = 128 типичные значения s лежат в диапазоне ~ ±√128 ≈ ±11.3.


3. Проблема без масштабирования: острый softmax

Softmax от вектора s (логитов) вычисляется как:

softmax(s_i) = exp(s_i) / Σ_j exp(s_j)

Если значения s_i велики по модулю (например, 11.3), экспоненты становятся огромными. Разница между максимальным и остальными логитами резко возрастает → распределение softmax стремится к one-hot (почти 1 на максимальном элементе, почти 0 на остальных).

Последствия

  • Потеря информации: модель не может «смешивать» информацию из разных позиций — внимание становится бинарным.
  • Малые градиенты: производная softmax в областях, где вероятность близка к 0 или 1, стремится к нулю. Градиенты, проходящие через attention, затухают (gradients|vanishing gradients).
  • Замедление обучения: обновления весов становятся крайне малыми, сходимость ухудшается.

4. Vanishing gradients и замедление обучения

Vanishing gradients — проблема, когда градиенты становятся настолько маленькими, что веса практически не обновляются. В контексте attention это проявляется так:

  • При остром softmax большинство весов внимания ≈ 0.
  • Градиент по Q и K для этих позиций ≈ 0.
  • Модель не учится выбирать правильные ключи, так как обратный сигнал не проходит.

Экспериментально показано, что без масштабирования Transformer с d_k > 64 обучается значительно медленнее или вовсе не сходится.


5. Интуиция: почему √d_k?

Деление на √d_k нормализует дисперсию скалярного произведения к 1:

Var( (q·k) / √d_k ) = (1 / d_k) * Var(q·k) = (1 / d_k) * d_k = 1

Теперь типичные значения логитов находятся в диапазоне ~ ±1. Softmax получает «умеренные» входы, распределение остаётся сглаженным (не one-hot), градиенты — ненулевыми.

Выбор √d_k, а не d_k — чтобы стандартное отклонение стало 1, а не дисперсия. Если бы делили на d_k, дисперсия стала бы 1/d_k → логиты слишком маленькие → softmax становится почти равномерным (все вероятности одинаковы), что тоже плохо для обучения.


6. Экспериментальная демонстрация (код Python)

import numpy as np
import matplotlib.pyplot as plt

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

d_k = 128
num_samples = 1000

# Генерируем случайные q и k
q = np.random.randn(num_samples, d_k)
k = np.random.randn(num_samples, d_k)

# Скалярные произведения без масштабирования
s_unscaled = np.sum(q * k, axis=1)

# С масштабированием
s_scaled = s_unscaled / np.sqrt(d_k)

# Применяем softmax к одному примеру (для наглядности)
example_unscaled = s_unscaled[:10]
example_scaled = s_scaled[:10]

print("Softmax без масштабирования:", softmax(example_unscaled))
print("Softmax с масштабированием:", softmax(example_scaled))

Результат: без масштабирования одно значение близко к 1, остальные к 0. С масштабированием — более равномерное распределение.


7. Альтернативные подходы

В некоторых вариантах attention (например, Reformer, Linformer) масштабирование может быть заменено на learnable temperature:

Attention(Q, K, V) = softmax( (Q K^T) / τ ) V

где τ — обучаемый параметр. Однако на практике фиксированное d_k работает стабильно и не требует дополнительных параметров.

Также существуют normalized attention (например, QK-normalization), где перед умножением нормализуют Q и K по норме, но это менее распространено.


8. Связь с multi-head attention

В multi-head attention каждый head работает с подпространством размерности d_k = d_model / h. Для типичных значений (d_model=512, h=8 → d_k=64) масштабирование всё равно необходимо, так как d_k=64 даёт σ=8, что уже приводит к заметному «заострению» softmax. Деление на √64=8 возвращает дисперсию к 1.


9. Пет-проект для закрепления

Задача Исследовать влияние масштабирования на обучение простого Transformer для задачи машинного перевода.

Инструменты PyTorch, небольшой датасет (например, Multi30k), библиотека torch.nn.Transformer.

Шаги:

  1. Реализовать две версии attention: scaled_dot_product_attention с делением на √d_k и без.
  2. Обучить две модели с одинаковыми гиперпараметрами (размерность, число слоёв, learning rate).
  3. Сравнить кривые loss и accuracy на валидации.
  4. Визуализировать распределение весов attention для нескольких примеров (с масштабированием и без).

Ожидаемый результат Модель без масштабирования либо не сойдётся, либо будет иметь значительно более высокий loss и острые attention-карты. Модель с масштабированием покажет гладкие распределения и лучшую сходимость.


10. Связь с другими вопросами

ВопросТема
650Что такое attention и зачем он нужен в Transformer?
651Как работает multi-head attention?
653Что такое positional encoding и зачем оно нужно?
654Почему в Transformer используется LayerNorm, а не BatchNorm?
660Как устроен механизм cross-attention в encoder-decoder?

Навигация