English translation is not available yet. Showing Russian content.

Что такое residual connections и зачем они нужны в трансформере?

Краткий тезис

Residual connections (остаточные связи) — это архитектурный приём, при котором выход слоя суммируется с его входом: output = F(x) + x. В трансформере они критически важны для обучения глубоких моделей (100+ слоёв), так как позволяют градиентам «обходить» нелинейные преобразования, решая проблему gradients|vanishing gradients. Без residual connections градиенты экспоненциально затухали бы, и сеть не могла бы обучаться.


1. Определение и формула

Residual connection (также skip connection или shortcut connection) — это соединение, которое передаёт входной сигнал слоя напрямую на его выход, минуя нелинейное преобразование. Математически:

output = F(x) + x

где F(x) — произвольная функция (например, multi-head attention или feed-forward network), а x — вход этого слоя. Сложение происходит поэлементно, поэтому размерности F(x) и x должны совпадать (при необходимости x проектируется через линейный слой).

Термин «остаточная» (residual) отражает, что слой учится предсказывать остаток (разницу) между желаемым выходом и входом: F(x) = output - x. Если оптимальное отображение — тождественное, слою достаточно обнулить F(x).


2. Проблема vanishing gradients в глубоких сетях

В глубоких нейронных сетях градиенты при обратном распространении ошибки проходят через цепочку умножений на производные активаций и весов. Если производные меньше 1 (например, сигмоида или tanh), градиент экспоненциально затухает с глубиной. В результате первые слои почти не получают сигнала обновления, и сеть не может обучаться.

Пример затухания (без residual):

  • Градиент на слое L: ∂L/∂w₁ = ∂L/∂h_L · ∏_{i=2}^{L} (∂h_i/∂h_{i-1})
  • Если каждый множитель < 0.25, то для 50 слоёв градиент уменьшается в 0.25⁵⁰ ≈ 10⁻³⁰ раз.

Термин gradients|vanishing gradients (затухающие градиенты) — явление, при котором градиенты становятся настолько малыми, что веса перестают обновляться. Противоположная проблема — gradients|exploding gradients (взрывающиеся градиенты), когда градиенты растут экспоненциально.


3. Как residual connections решают проблему

Residual connection создаёт обходной путь для градиента. При обратном распространении градиент через output = F(x) + x распадается на две ветви:

∂output / ∂x = ∂F(x)/∂x + 1

Благодаря слагаемому +1 градиент никогда не затухает полностью — даже если ∂F(x)/∂x мала, единица гарантирует прохождение сигнала. Таким образом, градиенты могут «перепрыгивать» через слой F и достигать ранних слоёв.

Интуиция: сеть может обучаться как «мелкая» (используя только shortcut) и постепенно добавлять полезные нелинейные преобразования через F(x). Это позволяет строить модели с сотнями слоёв (например, GPT-3 имеет 96 слоёв, а ResNet — до 152 слоёв).


4. Роль residual connections в архитектуре трансформера

В трансформере residual connections используются в каждом блоке encoder и decoder:

Термин Layer Normalization — нормализация по признакам (а не по батчу), стабилизирует обучение. В современных трансформерах (например, GPT, BERT) часто применяют pre-norm (нормализация перед F(x), а не после), что улучшает сходимость.

Схема блока трансформера:

x → LayerNorm → Multi-Head Attention → + (residual) → LayerNorm → Feed-Forward → + (residual) → выход

Без residual connections каждый блок добавлял бы новые нелинейности, и градиенты быстро затухали бы. С ними градиент может проходить через всю стопку блоков почти без потерь.


5. Связь с layer normalization: pre-norm vs post-norm

В оригинальной статье «Attention is All You Need» использовалась post-norm (нормализация после сложения):

output = LayerNorm(x + Attention(x))

Однако на практике pre-norm (нормализация перед Attention/FFN) оказалась стабильнее для глубоких моделей:

x' = LayerNorm(x)
output = x + Attention(x')
ТипФормулаОсобенности
Post-normLayerNorm(x + F(x))Градиенты проходят через LayerNorm, что может замедлять обучение
Pre-normx + F(LayerNorm(x))LayerNorm не мешает residual shortcut, градиенты лучше текут

Большинство современных LLM (GPT, LLaMA) используют pre-norm с residual connections.


6. Преимущества residual connections

ПреимуществоОписание
Решение vanishing gradientsГрадиенты могут обходить нелинейные слои, что позволяет обучать сети с 100+ слоями
Ускорение сходимостиМодель быстрее достигает минимума, так как градиенты не затухают
Улучшение обобщенияResidual connections действуют как регуляризатор, предотвращая переобучение
Возможность строить очень глубокие моделиБез residual connections трансформеры с 12+ слоями практически не обучались бы
Сохранение информации о входеДаже если F(x) зашумляет сигнал, x остаётся нетронутым

7. Альтернативы и родственные подходы

  • Highway networks — обучаемые «вентили», которые регулируют вклад F(x) и x через сигмоиду.
  • Dense connections (DenseNet) — каждый слой получает на вход все предыдущие выходы (concat, не сумма).
  • Gated residual connections — используются в некоторых архитектурах (например, в Transformer-XL) для управления потоком информации.

Термин skip connection — синоним residual connection, чаще используется в контексте свёрточных сетей (ResNet).


8. Пример реализации на PyTorch

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-norm residual connection для attention
        attn_out, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))
        x = x + self.dropout(attn_out)   # residual

        # Pre-norm residual connection для FFN
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)    # residual
        return x

Ключевые моменты:

  • x + ... — это residual connection.
  • Dropout применяется только к F(x), а не к shortcut, чтобы не зашумлять прямой путь.
  • LayerNorm стоит перед F(x) (pre-norm).

9. Пет-проект для закрепления

Задача: Сравнить обучение маленького трансформера с residual connections и без них на задаче классификации текста (например, IMDb).

Инструменты: PyTorch, Hugging Face Datasets, Weights & Biases (для логирования).

Шаги:

  1. Реализовать два варианта encoder-трансформера: один с residual connections (как выше), другой — без них (просто последовательность AttentionFFN без сложения).
  2. Обучить обе модели на одинаковых данных (например, 2 эпохи, одинаковый learning rate).
  3. Логировать loss и accuracy на валидации.
  4. Построить графики: для модели без residual loss не будет уменьшаться (или будет очень медленно), а с residual — сойдётся.

Ожидаемый результат: Модель без residual connections покажет затухание градиентов (можно проверить через torch.autograd.grad — градиенты первых слоев будут близки к нулю). Модель с residual connections обучится успешно.


10. Связь с другими вопросами

ВопросТема
670Архитектура трансформера (общая схема)
671Self-attention механизм
673Layer normalization в трансформере
674Positional encoding
675Feed-forward network в трансформере
680Проблема vanishing/exploding gradients в глубоких сетях

11. Навигация


Навигация