Что такое vanishing / exploding gradients в трансформерах и как их предотвратить?
Краткий тезис
Vanishing (затухание) и exploding (взрыв) градиентов — это проблемы обучения глубоких нейронных сетей, когда градиенты, проходя через множество слоёв, становятся экспоненциально малыми или большими. В трансформерах эти проблемы обостряются из-за глубоких residual-связей и механизма внимания. Основные методы предотвращения: pre-normalization (LayerNorm до, а не после под-слоёв), gradient clipping, правильная инициализация весов (Xavier, Kaiming) и использование residual connections (skip connections).
1. Термин: Vanishing и Exploding Gradients
Vanishing gradients — ситуация, когда градиенты, распространяемые обратно через слои сети, становятся настолько малыми (близкими к нулю), что веса ранних слоёв практически не обновляются. Это приводит к тому, что сеть не может обучаться.
Exploding gradients — противоположная ситуация: градиенты становятся настолько большими, что веса «взрываются» (становятся NaN или бесконечными), что разрушает обучение.
Математическая причина При обратном распространении ошибки (backpropagation) градиент на каждом слое умножается на матрицу весов (или её транспонированную версию). Если собственные значения матрицы весов меньше 1, градиент экспоненциально затухает; если больше 1 — экспоненциально растёт.
2. Причины в трансформерах
Трансформеры особенно подвержены этим проблемам по нескольким причинам:
- Глубина Современные трансформеры (GPT-3, LLaMA) содержат десятки или сотни слоёв. Чем глубже сеть, тем сильнее эффект умножения градиентов.
- Residual connections (skip connections): Хотя они и помогают бороться с затуханием, они также могут усиливать взрыв градиентов, если веса в основном пути слишком велики.
- Softmax attention Механизм внимания включает softmax, который может создавать «острые» распределения (почти one-hot), что приводит к большим градиентам для соответствующих позиций.
- LayerNorm Позиция LayerNorm (до или после под-слоя) критически влияет на стабильность градиентов.
3. Решение 1: Pre-normalization (LayerNorm до под-слоёв)
Pre-normalization — размещение Layer Normalization перед под-слоем (self-attention или FFN), а не после него.
| Тип | Схема | Влияние на градиенты |
|---|---|---|
| Post-LN (оригинальный трансформер) | x -> Attention -> Add -> LayerNorm -> FFN -> Add -> LayerNorm | Градиент проходит через LayerNorm и Add, что может усиливать затухание |
| Pre-LN (современные модели) | x -> LayerNorm -> Attention -> Add -> LayerNorm -> FFN -> Add | Градиент идёт напрямую через residual-связь, минуя LayerNorm, что стабилизирует обучение |
Почему Pre-LN работает В Pre-LN градиент имеет «короткий путь» через residual-связь (identity mapping), который не умножается на веса. Это предотвращает как затухание, так и взрыв. Post-LN вынуждает градиент проходить через LayerNorm, который может масштабировать градиент (особенно на ранних этапах обучения).
Пример кода (PyTorch):
import torch.nn as nn
class PreLNTransformerBlock(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, nhead)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.ReLU(),
nn.Linear(dim_feedforward, d_model)
)
def forward(self, x):
# Pre-LN: нормируем ДО под-слоя
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.ffn(self.norm2(x))
return x
4. Решение 2: Gradient Clipping
Gradient clipping — ограничение нормы градиента (или отдельных значений) перед применением оптимизатора.
Основные методы
- Norm clipping Если ||g|| > threshold, то g = g * (threshold / ||g||). Сохраняет направление градиента, но ограничивает его длину.
- Value clipping Каждый элемент градиента обрезается до [-threshold, [[Вики/confidence score|threshold.
Пример кода
import torch.nn.utils as utils
# Norm clipping (рекомендуемый метод)
utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Value clipping (реже, может искажать направление)
utils.clip_grad_value_(model.parameters(), clip_value=0.5)
Типичные пороги max_norm от 0.5 до 5.0 в зависимости от модели. Для трансформеров часто используют 1.0.
5. Решение 3: Инициализация весов
Правильная инициализация весов предотвращает взрыв/затухание на ранних этапах обучения.
| Метод | Формула | Для каких слоёв |
|---|---|---|
| Xavier (Glorot) | W ~ Uniform(-sqrt(6/(n_in + n_out)), sqrt(6/(n_in + n_out))) | Линейные слои, где активация симметрична (tanh, sigmoid) |
| Kaiming (He) | W ~ Normal(0, sqrt(2/n_in)) | Линейные слои с ReLU |
| Small init | W ~ Normal(0, 0.02) | Эмбеддинги, последние слои (часто в GPT-2) |
Почему это важно Если веса слишком большие, градиенты взрываются; если слишком маленькие — затухают. Инициализация подбирает дисперсию так, чтобы сигнал (и градиент) сохранял свою дисперсию при проходе через слой.
6. Решение 4: Residual Connections (Skip Connections)
Residual connection — добавление входа слоя к его выходу: output = x + F(x). Это создаёт «шоссе» для градиента, который может течь напрямую, минуя умножение на веса.
Формально Если у нас N слоёв, градиент на первом слое:
dL/dx_1 = dL/dx_N * (1 + sum(dF_i/dx_i))
Благодаря единице, градиент никогда не затухает полностью (если не происходит коллапс в -1).
Важное замечание Residual connections не решают проблему взрыва градиентов — они лишь гарантируют, что градиент не затухнет. Для взрыва нужны дополнительные меры (clipping, pre-norm).
7. Решение 5: Мониторинг градиентов
Мониторинг — отслеживание нормы градиента во время обучения для раннего обнаружения проблем.
Что отслеживать
- Total norm
sqrt(sum(g_i^2))— если > 100, вероятен взрыв. - Per-layer norm Если норма градиента на ранних слоях в 1000+ раз меньше, чем на поздних — затухание.
- Ratio update/weight
(lr * ||g||) / ||W||— если > 0.1, возможна нестабильность.
Инструменты TensorBoard, Weights & Biases, простой logging.
Пример логирования
total_norm = 0.0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
8. Сравнение методов
| Метод | Эффективность против vanishing | Эффективность против exploding | Сложность внедрения |
|---|---|---|---|
| Pre-normalization | Высокая | Средняя | Низкая (изменить порядок слоёв) |
| Gradient clipping | Низкая | Высокая | Очень низкая (1 строка кода) |
| Инициализация весов | Средняя | Средняя | Низкая (изменить параметры) |
| Residual connections | Высокая | Низкая | Низкая (стандарт в трансформерах) |
| Мониторинг | Обнаружение | Обнаружение | Средняя (добавить логи) |
Рекомендуемый набор Pre-LN + Gradient clipping (norm) + Xavier/Kaiming init + Residual connections обязательны. Мониторинг — для отладки.
Пет-проект для закрепления
Задача Обучить маленький трансформер (2-4 слоя, d_model=128) на синтетических данных (например, задача сложения чисел) и сравнить поведение градиентов при Post-LN vs Pre-LN, с/без gradient clipping.
Инструменты PyTorch, Hugging Face Transformers (или написать свой блок), Weights & Biases для логирования.
Шаги:
- Реализовать два варианта блока трансформера: Post-LN и Pre-LN.
- Обучить на задаче сложения двух 3-значных чисел (seq2seq).
- Логировать норму градиента на каждом шаге для первого и последнего слоя.
- Сравнить: при Post-LN норма градиента на первом слое будет в 10-100 раз меньше, чем на последнем (vanishing). При Pre-LN — примерно одинаковая.
- Добавить gradient clipping (max_norm=1.0) и показать, что exploding градиенты (если они возникают) обрезаются.
Ожидаемый результат Графики нормы градиента по шагам обучения, демонстрирующие стабилизацию при Pre-LN и clipping. Вывод: Pre-LN + clipping — минимальный набор для стабильного обучения.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 665 | Residual connections в трансформерах |
| 666 | Layer Normalization и его роль |
| 667 | Инициализация весов (Xavier, Kaiming) |
| 668 | Механизм attention и его градиенты |
| 669 | Оптимизаторы (Adam, SGD) и их влияние на градиенты |
| 670 | Gradient accumulation и его влияние на стабильность |
Навигация
- Предыдущий: 663
- Следующий: 665
- Индекс: 00. Индекс разборов