Как работает обратное распространение (backpropagation) в трансформере?
Краткий тезис
Обратное распространение (backpropagation) в трансформере — это последовательное применение цепного правила для вычисления градиентов функции потерь по всем обучаемым параметрам. Особенности архитектуры (механизм внимания, остаточные связи, нормализация слоёв) накладывают специфику на прохождение градиентов: через softmax, матричные умножения, SwiGLU и identity mapping. Понимание backprop критично для эффективного обучения, fine-tuning и диагностики проблем вроде затухающих градиентов.
1. Основы обратного распространения
Backpropagation — алгоритм вычисления градиента функции потерь (L) по параметрам модели с помощью цепного правила. Граф вычислений строится в прямом проходе, затем градиенты распространяются от выхода к входу.
Цепное правило: [ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} ] где (y = f(x)). Для композиции функций градиент перемножается по всем промежуточным звеньям.
В трансформере граф вычислений содержит:
- Эмбеддинги (token + positional)
- Multi-Head Self-Attention (MHSA)
- Feed-Forward Network (FFN) с активацией SwiGLU
- Остаточные связи (residual connections)
- Layer Normalization (LayerNorm)
- Функция потерь (обычно cross-entropy)
2. Прямой проход (напоминание архитектуры)
Для одного слоя трансформера (decoder-only, как GPT):
- Вход (X) (batch, seq_len, d_model)
- LayerNorm → (X_{[text](/wiki/text){ln}})
- Multi-Head Self-Attention → (A = [text](/wiki/text){Attention}(Q, K, V))
- Остаточная связь → (X' = X + A)
- LayerNorm → (X'_{[text](/wiki/text){ln}})
- FFN (SwiGLU) → (F = [text](/wiki/text){FFN}(X'_{[text](/wiki/text){ln}}))
- Остаточная связь → (X'' = X' + F)
После последнего слоя — LM head (линейный слой без bias) и softmax для получения вероятностей токенов.
3. Градиенты через функцию потерь и softmax
Функция потерь — cross-entropy для одного токена: [ L = -\log p_{[text](/wiki/text){true}}, \quad p = [text](/wiki/text){softmax}(z) ] где (z) — логиты (выход LM head).
Градиент по логитам: [ \frac{\partial L}{\partial z_i} = p_i - y_{[text](/wiki/text){true},i} ] где (y_{[text](/wiki/text){true}}) — one-hot вектор правильного токена. Это классический результат: градиент softmax + cross-entropy упрощается до разности предсказания и истины.
Почему это важно: градиент «расталкивает» логиты: правильный токен получает положительный градиент (увеличиваем вероятность), остальные — отрицательный.
4. Обратное распространение через Multi-Head Self-Attention
Рассмотрим одну голову внимания (без учёта heads, затем суммирование). Для каждого запроса:
[ [text](/wiki/text){Attention}(Q, K, V) = [text](/wiki/text){softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V ]
Обозначим:
- (S = QK^T / \sqrt{d_k}) — матрица «сырых» весов (scores)
- (P = [text](/wiki/text){softmax}(S)) — веса внимания (по последней оси)
- (O = P V) — выход attention
Градиенты распространяются в обратном порядке:
4.1 Градиент по (V):
[ \frac{\partial L}{\partial V} = P^T \cdot \frac{\partial L}{\partial O} ] (матричное умножение: транспонированная матрица весов внимания умножается на градиент выхода).
4.2 Градиент по (P):
[ \frac{\partial L}{\partial P} = \frac{\partial L}{\partial O} \cdot V^T ]
4.3 Градиент через softmax:
Для каждой строки (P_i) (веса для i-го запроса) градиент по (S_i): [ \frac{\partial L}{\partial S_i} = P_i \cdot \left( \frac{\partial L}{\partial P_i} - \left( \frac{\partial L}{\partial P_i} \cdot P_i \right) \mathbf{1}^T \right) ] где (\mathbf{1}) — вектор единиц. Это дифференцирование softmax: ( \frac{\partial P_i}{\partial S_i} = [text](/wiki/text){diag}(P_i) - P_i P_i^T ).
4.4 Градиент по (Q) и (K):
[ \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \cdot \frac{\partial L}{\partial S} \cdot K, \quad \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \cdot \left( \frac{\partial L}{\partial S} \right)^T \cdot Q ]
4.5 Объединение голов:
В multi-head attention градиенты по (Q, K, V) усредняются/суммируются по головам, затем проходят через линейные проекции (если они есть). Градиент через линейный слой: (\frac{\partial L}{\partial X} = W^T \cdot \frac{\partial L}{\partial Y}).
5. Обратное распространение через Feed-Forward Network (SwiGLU)
Современные трансформеры (LLaMA, GPT-4) используют SwiGLU: [ [text](/wiki/text){FFN}(x) = [text](/wiki/text){Swish}(x W_1) \odot (x W_2) \cdot W_3 ] где ([text](/wiki/text){Swish}(z) = z \cdot \sigma(z)) (сигмоида), (\odot) — поэлементное умножение.
Градиенты:
- Для (W_3): (\frac{\partial L}{\partial W_3} = ([text](/wiki/text){Swish}(x W_1) \odot (x W_2))^T \cdot \frac{\partial L}{\partial [text](/wiki/text){out}})
- Для (W_2): (\frac{\partial L}{\partial (x W_2)} = [text](/wiki/text){Swish}(x W_1) \odot \frac{\partial L}{\partial [text](/wiki/text){out}} \cdot W_3^T), затем (\frac{\partial L}{\partial W_2} = x^T \cdot \frac{\partial L}{\partial (x W_2)})
- Для (W_1): через производную Swish: (\frac{\partial [text](/wiki/text){Swish}}{\partial z} = \sigma(z) + z \cdot \sigma(z) \cdot (1 - \sigma(z))). Затем цепное правило.
Особенность: SwiGLU нелинейна, но градиенты вычисляются поэлементно, что эффективно.
6. Обратное распространение через остаточные связи (residual connections)
Остаточная связь: (y = x + f(x)). Градиент: [ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \cdot \left( I + \frac{\partial f}{\partial x} \right) ] На практике это означает, что градиент «течёт» двумя путями:
- Identity mapping: градиент (\frac{\partial L}{\partial y}) напрямую прибавляется к градиенту входа (x).
- Через функцию (f): градиент умножается на (\frac{\partial f}{\partial x}).
Почему это помогает: identity mapping позволяет градиентам свободно проходить через глубокие сети, предотвращая затухание. В трансформерах это критично для обучения 100+ слоёв.
7. Обратное распространение через Layer Normalization
LayerNorm для вектора (x): [ \mu = \frac{1}{d}\sum_{i=1}^d x_i, \quad \sigma^2 = \frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2, \quad \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + [epsilon](/wiki/Epsilon)}}, \quad y_i = \gamma_i \hat{x}_i + \beta_i ]
Градиенты:
- (\frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \hat{x}_i)
- (\frac{\partial L}{\partial \beta_i} = \frac{\partial L}{\partial y_i})
- (\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma_i)
- Далее через (\frac{\partial \hat{x}}{\partial x}) (с учётом (\mu) и (\sigma)) — стандартная формула, требующая аккуратного вычисления.
Важно: LayerNorm применяется перед attention и FFN (pre-norm архитектура). Градиенты через норму «смешивают» все элементы вектора, что стабилизирует обучение.
8. Градиенты через эмбеддинги и позиционные кодировки
Token embeddings: градиент по матрице эмбеддингов (E) — это просто градиент по входу первого слоя, суммированный по позициям, где встречается каждый токен.
Positional embeddings: если они обучаемые (как в BERT), градиент вычисляется аналогично. Если фиксированные (синусоидальные) — градиент не нужен, но они участвуют в прямом проходе, и градиент через сложение просто передаётся дальше.
9. Проблемы и решения при backprop в трансформерах
| Проблема | Причина | Решение |
|---|---|---|
| Vanishing gradients | Глубокие сети, насыщение активаций | Остаточные связи, LayerNorm, активации без насыщения (SwiGLU, GELU) |
| Exploding gradients | Большие значения в attention (softmax) | Gradient clipping, AdamW с weight decay |
| Нестабильность обучения | Дисперсия градиентов через attention | Scale factor (1/\sqrt{d_k}), инициализация (Xavier, small init) |
| Память | Хранение промежуточных активаций для backprop | Gradient checkpointing (recomputation), mixed precision |
Gradient clipping: обрезание нормы градиента до максимального значения (например, 1.0) — стандартная практика.
10. Особенности backprop для разных архитектур
| Архитектура | Особенности backprop |
|---|---|
| Encoder-only (BERT) | Двунаправленное внимание — градиенты от всех позиций ко всем. Masking (pad tokens) обнуляет градиенты. |
| Decoder-only (GPT) | Каузальное внимание (только прошлые токены) — градиенты не текут от будущих к прошлым. Маска внимания обнуляет градиенты для запрещённых связей. |
| Encoder-Decoder (T5) | Cross-attention: градиенты от decoder к encoder через ключи и значения. Backprop через обе части. |
Cross-attention: градиенты по (K) и (V) из encoder вычисляются так же, как в self-attention, но (K) и (V) — выход encoder, поэтому градиенты передаются в encoder.
11. Практические аспекты: mixed precision и checkpointing
Mixed precision training (FP16/BF16): градиенты вычисляются в половинной точности, но накапливаются в FP32 для стабильности. Это ускоряет backprop и уменьшает память.
Gradient checkpointing: не хранятся все активации, а пересчитываются при обратном проходе. Это trade-off: меньше памяти, больше вычислений.
Recomputation в attention: для экономии памяти можно не хранить матрицу (P) (веса внимания), а пересчитать её из (Q) и (K) при backprop.
12. Связь с fine-tuning: заморозка и LoRA
При fine-tuning часто замораживают часть слоёв (например, все, кроме последних). Градиенты для замороженных параметров не вычисляются (или обнуляются).
LoRA (Low-Rank Adaptation): добавляет обучаемые матрицы (A) и (B) к весам (W). Backprop вычисляет градиенты только по (A) и (B), а исходные веса остаются неизменными. Это резко уменьшает число параметров и память для градиентов.
Пет-проект для закрепления
Задача: Реализовать минимальный однослойный трансформер (decoder-only) с нуля на NumPy и вручную вычислить градиенты через backprop для одного примера. Сравнить с autograd из PyTorch.
Инструменты: Python, NumPy, PyTorch (для проверки).
Шаги:
- Реализовать прямой проход: эмбеддинги, LayerNorm, one-head attention, FFN (SwiGLU), остаточные связи, LM head, softmax, cross-entropy.
- Вручную реализовать обратный проход для каждого модуля, используя формулы из разбора.
- Сравнить численные градиенты (finite differences) с аналитическими.
- Запустить PyTorch-версию той же модели и сравнить градиенты.
Ожидаемый результат: Понимание, как каждый компонент влияет на градиенты, и умение отлаживать backprop в реальных моделях.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 661 | Архитектура трансформера (прямой проход) |
| 664 | Self-attention и его варианты |
| 665 | Позиционные кодировки |
| 666 | Layer Normalization |
| 667 | Оптимизаторы (AdamW) |
| 668 | Fine-tuning трансформеров |
| 669 | LoRA и другие методы адаптации |
| 670 | Mixed precision training |
Навигация
- Предыдущий: 662
- Следующий: 664
- Индекс: 00. Индекс разборов