Что такое residual connections и зачем они нужны в трансформере?
Краткий тезис
Residual connections (остаточные связи) — это архитектурный приём, при котором выход слоя суммируется с его входом: output = F(x) + x. В трансформере они критически важны для обучения глубоких моделей (100+ слоёв), так как позволяют градиентам «обходить» нелинейные преобразования, решая проблему gradients|vanishing gradients. Без residual connections градиенты экспоненциально затухали бы, и сеть не могла бы обучаться.
1. Определение и формула
Residual connection (также skip connection или shortcut connection) — это соединение, которое передаёт входной сигнал слоя напрямую на его выход, минуя нелинейное преобразование. Математически:
output = F(x) + x
где F(x) — произвольная функция (например, multi-head attention или feed-forward network), а x — вход этого слоя. Сложение происходит поэлементно, поэтому размерности F(x) и x должны совпадать (при необходимости x проектируется через линейный слой).
Термин «остаточная» (residual) отражает, что слой учится предсказывать остаток (разницу) между желаемым выходом и входом: F(x) = output - x. Если оптимальное отображение — тождественное, слою достаточно обнулить F(x).
2. Проблема vanishing gradients в глубоких сетях
В глубоких нейронных сетях градиенты при обратном распространении ошибки проходят через цепочку умножений на производные активаций и весов. Если производные меньше 1 (например, сигмоида или tanh), градиент экспоненциально затухает с глубиной. В результате первые слои почти не получают сигнала обновления, и сеть не может обучаться.
Пример затухания (без residual):
- Градиент на слое
L:∂L/∂w₁ = ∂L/∂h_L · ∏_{i=2}^{L} (∂h_i/∂h_{i-1}) - Если каждый множитель < 0.25, то для 50 слоёв градиент уменьшается в
0.25⁵⁰ ≈ 10⁻³⁰раз.
Термин gradients|vanishing gradients (затухающие градиенты) — явление, при котором градиенты становятся настолько малыми, что веса перестают обновляться. Противоположная проблема — gradients|exploding gradients (взрывающиеся градиенты), когда градиенты растут экспоненциально.
3. Как residual connections решают проблему
Residual connection создаёт обходной путь для градиента. При обратном распространении градиент через output = F(x) + x распадается на две ветви:
∂output / ∂x = ∂F(x)/∂x + 1
Благодаря слагаемому +1 градиент никогда не затухает полностью — даже если ∂F(x)/∂x мала, единица гарантирует прохождение сигнала. Таким образом, градиенты могут «перепрыгивать» через слой F и достигать ранних слоёв.
Интуиция: сеть может обучаться как «мелкая» (используя только shortcut) и постепенно добавлять полезные нелинейные преобразования через F(x). Это позволяет строить модели с сотнями слоёв (например, GPT-3 имеет 96 слоёв, а ResNet — до 152 слоёв).
4. Роль residual connections в архитектуре трансформера
В трансформере residual connections используются в каждом блоке encoder и decoder:
- После multi-head attention: output = LayerNorm(x + Attention(x))
- После feed-forward network: output = LayerNorm(x + FFN(x))
Термин Layer Normalization — нормализация по признакам (а не по батчу), стабилизирует обучение. В современных трансформерах (например, GPT, BERT) часто применяют pre-norm (нормализация перед F(x), а не после), что улучшает сходимость.
Схема блока трансформера:
x → LayerNorm → Multi-Head Attention → + (residual) → LayerNorm → Feed-Forward → + (residual) → выход
Без residual connections каждый блок добавлял бы новые нелинейности, и градиенты быстро затухали бы. С ними градиент может проходить через всю стопку блоков почти без потерь.
5. Связь с layer normalization: pre-norm vs post-norm
В оригинальной статье «Attention is All You Need» использовалась post-norm (нормализация после сложения):
output = LayerNorm(x + Attention(x))
Однако на практике pre-norm (нормализация перед Attention/FFN) оказалась стабильнее для глубоких моделей:
x' = LayerNorm(x)
output = x + Attention(x')
| Тип | Формула | Особенности |
|---|---|---|
| Post-norm | LayerNorm(x + F(x)) | Градиенты проходят через LayerNorm, что может замедлять обучение |
| Pre-norm | x + F(LayerNorm(x)) | LayerNorm не мешает residual shortcut, градиенты лучше текут |
Большинство современных LLM (GPT, LLaMA) используют pre-norm с residual connections.
6. Преимущества residual connections
| Преимущество | Описание |
|---|---|
| Решение vanishing gradients | Градиенты могут обходить нелинейные слои, что позволяет обучать сети с 100+ слоями |
| Ускорение сходимости | Модель быстрее достигает минимума, так как градиенты не затухают |
| Улучшение обобщения | Residual connections действуют как регуляризатор, предотвращая переобучение |
| Возможность строить очень глубокие модели | Без residual connections трансформеры с 12+ слоями практически не обучались бы |
| Сохранение информации о входе | Даже если F(x) зашумляет сигнал, x остаётся нетронутым |
7. Альтернативы и родственные подходы
- Highway networks — обучаемые «вентили», которые регулируют вклад
F(x)иxчерез сигмоиду. - Dense connections (DenseNet) — каждый слой получает на вход все предыдущие выходы (concat, не сумма).
- Gated residual connections — используются в некоторых архитектурах (например, в Transformer-XL) для управления потоком информации.
Термин skip connection — синоним residual connection, чаще используется в контексте свёрточных сетей (ResNet).
8. Пример реализации на PyTorch
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, d_model, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Pre-norm residual connection для attention
attn_out, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))
x = x + self.dropout(attn_out) # residual
# Pre-norm residual connection для FFN
ffn_out = self.ffn(self.norm2(x))
x = x + self.dropout(ffn_out) # residual
return x
Ключевые моменты:
x + ...— это residual connection.- Dropout применяется только к
F(x), а не к shortcut, чтобы не зашумлять прямой путь. - LayerNorm стоит перед
F(x)(pre-norm).
9. Пет-проект для закрепления
Задача: Сравнить обучение маленького трансформера с residual connections и без них на задаче классификации текста (например, IMDb).
Инструменты: PyTorch, Hugging Face Datasets, Weights & Biases (для логирования).
Шаги:
- Реализовать два варианта encoder-трансформера: один с residual connections (как выше), другой — без них (просто последовательность Attention → FFN без сложения).
- Обучить обе модели на одинаковых данных (например, 2 эпохи, одинаковый learning rate).
- Логировать loss и accuracy на валидации.
- Построить графики: для модели без residual loss не будет уменьшаться (или будет очень медленно), а с residual — сойдётся.
Ожидаемый результат: Модель без residual connections покажет затухание градиентов (можно проверить через torch.autograd.grad — градиенты первых слоев будут близки к нулю). Модель с residual connections обучится успешно.
10. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 670 | Архитектура трансформера (общая схема) |
| 671 | Self-attention механизм |
| 673 | Layer normalization в трансформере |
| 674 | Positional encoding |
| 675 | Feed-forward network в трансформере |
| 680 | Проблема vanishing/exploding gradients в глубоких сетях |
11. Навигация
- Предыдущий: 671
- Следующий: 673
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 671
- Следующий: 673
- Индекс: 00. Индекс разборов