Как работает нормализация перед attention (pre-norm) vs после (post-norm)?

Краткий тезис

Pre-norm и post-norm — два подхода к размещению Layer Normalization внутри блока трансформера. Post-norm (оригинальный Transformer, GPT-2) применяет нормализацию после суммирования остаточных связей, что может вызывать затухание градиентов в глубоких сетях. Pre-norm (GPT-3, Llama) применяет нормализацию перед подуровнями (attention и FFN), что улучшает поток градиентов через identity-связь и позволяет обучать более глубокие модели. Сегодня pre-norm является стандартом для большинства современных LLM.


1. Термин: Layer Normalization (LayerNorm)

LayerNorm — это метод нормализации, который вычисляет среднее и стандартное отклонение по всем нейронам одного слоя (в отличие от BatchNorm, который нормализует по батчу). Формула:

LayerNorm(x) = γ * (x - μ) / √(σ² + ε) + β

где μ и σ — среднее и стандартное отклонение по последней оси (признаки), γ и β — обучаемые параметры сдвига и масштаба, ε — малая константа для численной стабильности.

Зачем нужна LayerNorm в трансформере

  • Стабилизирует обучение, предотвращая взрыв или затухание активаций.
  • Уменьшает зависимость от инициализации весов.
  • Позволяет использовать более высокие learning rates.

В контексте pre-norm vs post-norm LayerNorm применяется к разным частям блока.


2. Термин: Residual connection (остаточная связь)

Residual connection — это архитектурный приём, при котором выход подуровня (например, attention) добавляется к его входу: output = x + F(x). Это позволяет градиентам течь напрямую через identity-путь (обходной путь), что решает проблему затухающих градиентов в глубоких сетях.

В трансформере каждый блок состоит из двух подуровней: Multi-Head Attention и Feed-Forward Network (FFN). Каждый из них окружён residual connection.


3. Post-norm (нормализация после)

Post-norm — это оригинальная архитектура из статьи «Attention is All You Need». Формула для одного подуровня:

output = LayerNorm(x + F(x))

Порядок операций:

  1. Вход x подаётся на подуровень F (attention или FFN).
  2. Результат F(x) складывается с x (residual connection).
  3. К сумме применяется LayerNorm.

Пример: GPT-2 использует post-norm.

Проблемы post-norm

  • Градиенты проходят через LayerNorm, которая может их масштабировать и сдвигать, что ухудшает поток через identity-путь.
  • В глубоких сетях (более 12–24 слоёв) градиенты могут затухать, требуя специальных техник (warmup, малые learning rates).
  • LayerNorm после сложения может «смешивать» информацию из residual и подуровня, что иногда приводит к нестабильности.

Математически градиент по x для post-norm:

∂Loss/∂x = ∂Loss/∂LayerNorm * ∂LayerNorm/∂(x+F(x)) * (1 + ∂F(x)/∂x)

Наличие ∂LayerNorm/∂(x+F(x)) может ослабить градиент, особенно если дисперсия активаций велика.


4. Pre-norm (нормализация перед)

Pre-norm — это модификация, где LayerNorm применяется перед подуровнем. Формула:

output = x + F(LayerNorm(x))

Порядок операций:

  1. Вход x нормализуется: x_norm = LayerNorm(x).
  2. Нормализованный вход подаётся на подуровень F.
  3. Результат F(x_norm) складывается с исходным x (residual connection).

Пример: GPT-3, Llama, Mistral используют pre-norm.

Преимущества pre-norm

  • Градиенты текут через identity-путь без изменений (нет LayerNorm на пути residual). Это позволяет обучать модели с сотнями слоёв.
  • Более стабильное обучение: не требуется сложная инициализация или warmup.
  • LayerNorm перед подуровнем гарантирует, что вход F имеет нулевое среднее и единичную дисперсию, что улучшает сходимость attention и FFN.

Математически градиент по x для pre-norm:

∂Loss/∂x = ∂Loss/∂output * (1 + ∂F(LayerNorm(x))/∂x)

Identity-путь (единица) остаётся нетронутым, что обеспечивает прямой поток градиентов.


5. Сравнение pre-norm vs post-norm

АспектPost-normPre-norm
Поток градиентовОслаблен из-за LayerNorm на residual-путиПрямой через identity (без нормализации)
Стабильность обученияТребует warmup и малый LRБолее стабилен, можно использовать больший LR
Глубина сетиПлохо масштабируется на >24 слояХорошо масштабируется до сотен слоёв
Производительность (perplexity)Может быть лучше при малой глубинеОбычно лучше при большой глубине
Популярные моделиTransformer (original), GPT-2GPT-3, Llama, Mistral, BLOOM, T5 (вариант)
Вычислительные затратыОдинаково (один LayerNorm на подуровень)Одинаково

6. Почему pre-norm стал стандартом

Эмпирические исследования (например, «On Layer Normalization in the Transformer Architecture» от Xiong et al., 2020) показали, что pre-norm:

  • Позволяет обучать модели без warmup (или с очень коротким warmup).
  • Даёт более низкий loss при одинаковом количестве шагов обучения.
  • Улучшает сходимость для глубоких моделей (12+ слоёв).

Кроме того, pre-norm упрощает реализацию: не нужно заботиться о специальной инициализации весов (например, T5 использует pre-norm с инициализацией по умолчанию).

Исключения Некоторые модели (например,早期的 Transformer) используют post-norm, но в современных LLM (GPT-3, Llama, Mistral, Falcon) pre-norm является де-факто стандартом.


7. Варианты pre-norm: Pre-LN и Sandwich-Norm

  • Pre-LN (обычный pre-norm): LayerNorm только перед каждым подуровнем. Используется в GPT-3, Llama.
  • Sandwich-Norm: дополнительная LayerNorm после последнего residual connection в блоке (т.е. output = LayerNorm(x + F(LayerNorm(x)))). Используется в некоторых моделях (например, T5). Это компромисс: сохраняет прямой градиент через residual, но добавляет нормализацию на выходе блока.

Sandwich-Norm может улучшить стабильность на очень глубоких моделях (100+ слоёв), но добавляет один лишний LayerNorm на блок.


8. Влияние на attention и FFN

Attention: Pre-norm нормализует входные запросы (Q), ключи (K) и значения (V) через LayerNorm. Это уменьшает внутренний ковариатный сдвиг и позволяет attention лучше фокусироваться на релевантных токенах. Post-norm нормализует уже после attention, что может «смазывать» информацию.

FFN Pre-norm стабилизирует активации FFN (обычно ReLU или SwiGLU), предотвращая взрыв значений. Post-norm может приводить к большим вариациям в выходе FFN, что требует более осторожной настройки.


9. Практические рекомендации

  • Для новых моделей используйте pre-norm (Pre-LN). Это безопасный выбор, который работает для любой глубины.
  • Если глубина < 12 слоёв post-norm может дать slightly better perplexity (как в оригинальном Transformer), но разница обычно незначительна.
  • Для очень глубоких моделей (> 48 слоёв): рассмотрите Sandwich-Norm или pre-norm с дополнительной нормализацией на выходе.
  • При fine-tuning pre-norm модели легче адаптируются к новым задачам без сбоев.
  • При обучении с нуля pre-norm позволяет использовать более высокий learning rate (например, 3e-4 вместо 1e-4) и сократить warmup.

10. Связь с другими техниками

  • Warmup post-norm требует обязательного warmup (линейное увеличение LR), pre-norm может обходиться без него.
  • Weight initialization post-norm чувствителен к инициализации (например, T5 использует инициализацию с малыми значениями), pre-norm более робастен.
  • Gradient clipping pre-norm реже требует clipping, так как градиенты не взрываются.
  • Dropout: pre-norm может сочетаться с dropout внутри подуровней без потери стабильности.

Пет-проект для закрепления

Задача Реализовать простой трансформер с возможностью переключения pre/post-norm, обучить на датасете Tiny Shakespeare и сравнить сходимость.

Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (для логирования).

Шаги:

  1. Напишите класс TransformerBlock с параметром norm_type ('pre' или 'post').
  2. Соберите модель из 6 блоков, embedding-слоя и выходного линейного слоя.
  3. Загрузите датасет Tiny Shakespeare (текст ~1 MB).
  4. Обучите две модели (pre и post) с одинаковыми гиперпараметрами (learning rate 3e-4, batch size 32, 10 эпох).
  5. Логируйте loss на каждом шаге.
  6. Постройте график loss по шагам для обеих моделей.

Ожидаемый результат Pre-norm модель покажет более быстрое снижение loss и меньшую финальную perplexity. Post-norm может расходиться без warmup (попробуйте обучить без warmup — post-norm, скорее всего, не сойдётся).

Код (фрагмент):

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, norm_type='pre'):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm_type = norm_type

    def forward(self, x):
        if self.norm_type == 'pre':
            # Pre-norm
            x = x + self.attention(self.norm1(x), self.norm1(x), self.norm1(x))[0]
            x = x + self.ffn(self.norm2(x))
        else:
            # Post-norm
            x = self.norm1(x + self.attention(x, x, x)[0])
            x = self.norm2(x + self.ffn(x))
        return x

Связь с другими вопросами

ВопросТема
672Архитектура трансформера (общая структура блока)
674Механизмы внимания (Multi-Head Attention)
675Позиционное кодирование (Positional Encoding)
676Feed-Forward Network (FFN) в трансформере
677Residual connections (остаточные связи)
678Layer Normalization (детальный разбор)

Навигация