Как работает обратное распространение (backpropagation) в трансформере?

Краткий тезис

Обратное распространение (backpropagation) в трансформере — это последовательное применение цепного правила для вычисления градиентов функции потерь по всем обучаемым параметрам. Особенности архитектуры (механизм внимания, остаточные связи, нормализация слоёв) накладывают специфику на прохождение градиентов: через softmax, матричные умножения, SwiGLU и identity mapping. Понимание backprop критично для эффективного обучения, fine-tuning и диагностики проблем вроде затухающих градиентов.


1. Основы обратного распространения

Backpropagation — алгоритм вычисления градиента функции потерь (L) по параметрам модели с помощью цепного правила. Граф вычислений строится в прямом проходе, затем градиенты распространяются от выхода к входу.

Цепное правило: [ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} ] где (y = f(x)). Для композиции функций градиент перемножается по всем промежуточным звеньям.

В трансформере граф вычислений содержит:


2. Прямой проход (напоминание архитектуры)

Для одного слоя трансформера (decoder-only, как GPT):

  1. Вход (X) (batch, seq_len, d_model)
  2. LayerNorm → (X_{[text](/wiki/text){ln}})
  3. Multi-Head Self-Attention → (A = [text](/wiki/text){Attention}(Q, K, V))
  4. Остаточная связь → (X' = X + A)
  5. LayerNorm → (X'_{[text](/wiki/text){ln}})
  6. FFN (SwiGLU) → (F = [text](/wiki/text){FFN}(X'_{[text](/wiki/text){ln}}))
  7. Остаточная связь → (X'' = X' + F)

После последнего слоя — LM head (линейный слой без bias) и softmax для получения вероятностей токенов.


3. Градиенты через функцию потерь и softmax

Функция потерь — cross-entropy для одного токена: [ L = -\log p_{[text](/wiki/text){true}}, \quad p = [text](/wiki/text){softmax}(z) ] где (z) — логиты (выход LM head).

Градиент по логитам: [ \frac{\partial L}{\partial z_i} = p_i - y_{[text](/wiki/text){true},i} ] где (y_{[text](/wiki/text){true}}) — one-hot вектор правильного токена. Это классический результат: градиент softmax + cross-entropy упрощается до разности предсказания и истины.

Почему это важно: градиент «расталкивает» логиты: правильный токен получает положительный градиент (увеличиваем вероятность), остальные — отрицательный.


4. Обратное распространение через Multi-Head Self-Attention

Рассмотрим одну голову внимания (без учёта heads, затем суммирование). Для каждого запроса:

[ [text](/wiki/text){Attention}(Q, K, V) = [text](/wiki/text){softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V ]

Обозначим:

  • (S = QK^T / \sqrt{d_k}) — матрица «сырых» весов (scores)
  • (P = [text](/wiki/text){softmax}(S)) — веса внимания (по последней оси)
  • (O = P V) — выход attention

Градиенты распространяются в обратном порядке:

4.1 Градиент по (V):

[ \frac{\partial L}{\partial V} = P^T \cdot \frac{\partial L}{\partial O} ] (матричное умножение: транспонированная матрица весов внимания умножается на градиент выхода).

4.2 Градиент по (P):

[ \frac{\partial L}{\partial P} = \frac{\partial L}{\partial O} \cdot V^T ]

4.3 Градиент через softmax:

Для каждой строки (P_i) (веса для i-го запроса) градиент по (S_i): [ \frac{\partial L}{\partial S_i} = P_i \cdot \left( \frac{\partial L}{\partial P_i} - \left( \frac{\partial L}{\partial P_i} \cdot P_i \right) \mathbf{1}^T \right) ] где (\mathbf{1}) — вектор единиц. Это дифференцирование softmax: ( \frac{\partial P_i}{\partial S_i} = [text](/wiki/text){diag}(P_i) - P_i P_i^T ).

4.4 Градиент по (Q) и (K):

[ \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \cdot \frac{\partial L}{\partial S} \cdot K, \quad \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \cdot \left( \frac{\partial L}{\partial S} \right)^T \cdot Q ]

4.5 Объединение голов:

В multi-head attention градиенты по (Q, K, V) усредняются/суммируются по головам, затем проходят через линейные проекции (если они есть). Градиент через линейный слой: (\frac{\partial L}{\partial X} = W^T \cdot \frac{\partial L}{\partial Y}).


5. Обратное распространение через Feed-Forward Network (SwiGLU)

Современные трансформеры (LLaMA, GPT-4) используют SwiGLU: [ [text](/wiki/text){FFN}(x) = [text](/wiki/text){Swish}(x W_1) \odot (x W_2) \cdot W_3 ] где ([text](/wiki/text){Swish}(z) = z \cdot \sigma(z)) (сигмоида), (\odot) — поэлементное умножение.

Градиенты:

  • Для (W_3): (\frac{\partial L}{\partial W_3} = ([text](/wiki/text){Swish}(x W_1) \odot (x W_2))^T \cdot \frac{\partial L}{\partial [text](/wiki/text){out}})
  • Для (W_2): (\frac{\partial L}{\partial (x W_2)} = [text](/wiki/text){Swish}(x W_1) \odot \frac{\partial L}{\partial [text](/wiki/text){out}} \cdot W_3^T), затем (\frac{\partial L}{\partial W_2} = x^T \cdot \frac{\partial L}{\partial (x W_2)})
  • Для (W_1): через производную Swish: (\frac{\partial [text](/wiki/text){Swish}}{\partial z} = \sigma(z) + z \cdot \sigma(z) \cdot (1 - \sigma(z))). Затем цепное правило.

Особенность: SwiGLU нелинейна, но градиенты вычисляются поэлементно, что эффективно.


6. Обратное распространение через остаточные связи (residual connections)

Остаточная связь: (y = x + f(x)). Градиент: [ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \cdot \left( I + \frac{\partial f}{\partial x} \right) ] На практике это означает, что градиент «течёт» двумя путями:

  • Identity mapping: градиент (\frac{\partial L}{\partial y}) напрямую прибавляется к градиенту входа (x).
  • Через функцию (f): градиент умножается на (\frac{\partial f}{\partial x}).

Почему это помогает: identity mapping позволяет градиентам свободно проходить через глубокие сети, предотвращая затухание. В трансформерах это критично для обучения 100+ слоёв.


7. Обратное распространение через Layer Normalization

LayerNorm для вектора (x): [ \mu = \frac{1}{d}\sum_{i=1}^d x_i, \quad \sigma^2 = \frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2, \quad \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + [epsilon](/wiki/Epsilon)}}, \quad y_i = \gamma_i \hat{x}_i + \beta_i ]

Градиенты:

  • (\frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \hat{x}_i)
  • (\frac{\partial L}{\partial \beta_i} = \frac{\partial L}{\partial y_i})
  • (\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma_i)
  • Далее через (\frac{\partial \hat{x}}{\partial x}) (с учётом (\mu) и (\sigma)) — стандартная формула, требующая аккуратного вычисления.

Важно: LayerNorm применяется перед attention и FFN (pre-norm архитектура). Градиенты через норму «смешивают» все элементы вектора, что стабилизирует обучение.


8. Градиенты через эмбеддинги и позиционные кодировки

Token embeddings: градиент по матрице эмбеддингов (E) — это просто градиент по входу первого слоя, суммированный по позициям, где встречается каждый токен.

Positional embeddings: если они обучаемые (как в BERT), градиент вычисляется аналогично. Если фиксированные (синусоидальные) — градиент не нужен, но они участвуют в прямом проходе, и градиент через сложение просто передаётся дальше.


9. Проблемы и решения при backprop в трансформерах

ПроблемаПричинаРешение
Vanishing gradientsГлубокие сети, насыщение активацийОстаточные связи, LayerNorm, активации без насыщения (SwiGLU, GELU)
Exploding gradientsБольшие значения в attention (softmax)Gradient clipping, AdamW с weight decay
Нестабильность обученияДисперсия градиентов через attentionScale factor (1/\sqrt{d_k}), инициализация (Xavier, small init)
ПамятьХранение промежуточных активаций для backpropGradient checkpointing (recomputation), mixed precision

Gradient clipping: обрезание нормы градиента до максимального значения (например, 1.0) — стандартная практика.


10. Особенности backprop для разных архитектур

АрхитектураОсобенности backprop
Encoder-only (BERT)Двунаправленное внимание — градиенты от всех позиций ко всем. Masking (pad tokens) обнуляет градиенты.
Decoder-only (GPT)Каузальное внимание (только прошлые токены) — градиенты не текут от будущих к прошлым. Маска внимания обнуляет градиенты для запрещённых связей.
Encoder-Decoder (T5)Cross-attention: градиенты от decoder к encoder через ключи и значения. Backprop через обе части.

Cross-attention: градиенты по (K) и (V) из encoder вычисляются так же, как в self-attention, но (K) и (V) — выход encoder, поэтому градиенты передаются в encoder.


11. Практические аспекты: mixed precision и checkpointing

Mixed precision training (FP16/BF16): градиенты вычисляются в половинной точности, но накапливаются в FP32 для стабильности. Это ускоряет backprop и уменьшает память.

Gradient checkpointing: не хранятся все активации, а пересчитываются при обратном проходе. Это trade-off: меньше памяти, больше вычислений.

Recomputation в attention: для экономии памяти можно не хранить матрицу (P) (веса внимания), а пересчитать её из (Q) и (K) при backprop.


12. Связь с fine-tuning: заморозка и LoRA

При fine-tuning часто замораживают часть слоёв (например, все, кроме последних). Градиенты для замороженных параметров не вычисляются (или обнуляются).

LoRA (Low-Rank Adaptation): добавляет обучаемые матрицы (A) и (B) к весам (W). Backprop вычисляет градиенты только по (A) и (B), а исходные веса остаются неизменными. Это резко уменьшает число параметров и память для градиентов.


Пет-проект для закрепления

Задача: Реализовать минимальный однослойный трансформер (decoder-only) с нуля на NumPy и вручную вычислить градиенты через backprop для одного примера. Сравнить с autograd из PyTorch.

Инструменты: Python, NumPy, PyTorch (для проверки).

Шаги:

  1. Реализовать прямой проход: эмбеддинги, LayerNorm, one-head attention, FFN (SwiGLU), остаточные связи, LM head, softmax, cross-entropy.
  2. Вручную реализовать обратный проход для каждого модуля, используя формулы из разбора.
  3. Сравнить численные градиенты (finite differences) с аналитическими.
  4. Запустить PyTorch-версию той же модели и сравнить градиенты.

Ожидаемый результат: Понимание, как каждый компонент влияет на градиенты, и умение отлаживать backprop в реальных моделях.


Связь с другими вопросами

ВопросТема
661Архитектура трансформера (прямой проход)
664Self-attention и его варианты
665Позиционные кодировки
666Layer Normalization
667Оптимизаторы (AdamW)
668Fine-tuning трансформеров
669LoRA и другие методы адаптации
670Mixed precision training

Навигация