Как работает нормализация перед attention (pre-norm) vs после (post-norm)?
Краткий тезис
Pre-norm и post-norm — два подхода к размещению Layer Normalization внутри блока трансформера. Post-norm (оригинальный Transformer, GPT-2) применяет нормализацию после суммирования остаточных связей, что может вызывать затухание градиентов в глубоких сетях. Pre-norm (GPT-3, Llama) применяет нормализацию перед подуровнями (attention и FFN), что улучшает поток градиентов через identity-связь и позволяет обучать более глубокие модели. Сегодня pre-norm является стандартом для большинства современных LLM.
1. Термин: Layer Normalization (LayerNorm)
LayerNorm — это метод нормализации, который вычисляет среднее и стандартное отклонение по всем нейронам одного слоя (в отличие от BatchNorm, который нормализует по батчу). Формула:
LayerNorm(x) = γ * (x - μ) / √(σ² + ε) + β
где μ и σ — среднее и стандартное отклонение по последней оси (признаки), γ и β — обучаемые параметры сдвига и масштаба, ε — малая константа для численной стабильности.
Зачем нужна LayerNorm в трансформере
- Стабилизирует обучение, предотвращая взрыв или затухание активаций.
- Уменьшает зависимость от инициализации весов.
- Позволяет использовать более высокие learning rates.
В контексте pre-norm vs post-norm LayerNorm применяется к разным частям блока.
2. Термин: Residual connection (остаточная связь)
Residual connection — это архитектурный приём, при котором выход подуровня (например, attention) добавляется к его входу: output = x + F(x). Это позволяет градиентам течь напрямую через identity-путь (обходной путь), что решает проблему затухающих градиентов в глубоких сетях.
В трансформере каждый блок состоит из двух подуровней: Multi-Head Attention и Feed-Forward Network (FFN). Каждый из них окружён residual connection.
3. Post-norm (нормализация после)
Post-norm — это оригинальная архитектура из статьи «Attention is All You Need». Формула для одного подуровня:
output = LayerNorm(x + F(x))
Порядок операций:
- Вход x подаётся на подуровень F (attention или FFN).
- Результат F(x) складывается с x (residual connection).
- К сумме применяется LayerNorm.
Пример: GPT-2 использует post-norm.
Проблемы post-norm
- Градиенты проходят через LayerNorm, которая может их масштабировать и сдвигать, что ухудшает поток через identity-путь.
- В глубоких сетях (более 12–24 слоёв) градиенты могут затухать, требуя специальных техник (warmup, малые learning rates).
- LayerNorm после сложения может «смешивать» информацию из residual и подуровня, что иногда приводит к нестабильности.
Математически градиент по x для post-norm:
∂Loss/∂x = ∂Loss/∂LayerNorm * ∂LayerNorm/∂(x+F(x)) * (1 + ∂F(x)/∂x)
Наличие ∂LayerNorm/∂(x+F(x)) может ослабить градиент, особенно если дисперсия активаций велика.
4. Pre-norm (нормализация перед)
Pre-norm — это модификация, где LayerNorm применяется перед подуровнем. Формула:
output = x + F(LayerNorm(x))
Порядок операций:
- Вход x нормализуется: x_norm = LayerNorm(x).
- Нормализованный вход подаётся на подуровень F.
- Результат F(x_norm) складывается с исходным x (residual connection).
Пример: GPT-3, Llama, Mistral используют pre-norm.
Преимущества pre-norm
- Градиенты текут через identity-путь без изменений (нет LayerNorm на пути residual). Это позволяет обучать модели с сотнями слоёв.
- Более стабильное обучение: не требуется сложная инициализация или warmup.
- LayerNorm перед подуровнем гарантирует, что вход F имеет нулевое среднее и единичную дисперсию, что улучшает сходимость attention и FFN.
Математически градиент по x для pre-norm:
∂Loss/∂x = ∂Loss/∂output * (1 + ∂F(LayerNorm(x))/∂x)
Identity-путь (единица) остаётся нетронутым, что обеспечивает прямой поток градиентов.
5. Сравнение pre-norm vs post-norm
| Аспект | Post-norm | Pre-norm |
|---|---|---|
| Поток градиентов | Ослаблен из-за LayerNorm на residual-пути | Прямой через identity (без нормализации) |
| Стабильность обучения | Требует warmup и малый LR | Более стабилен, можно использовать больший LR |
| Глубина сети | Плохо масштабируется на >24 слоя | Хорошо масштабируется до сотен слоёв |
| Производительность (perplexity) | Может быть лучше при малой глубине | Обычно лучше при большой глубине |
| Популярные модели | Transformer (original), GPT-2 | GPT-3, Llama, Mistral, BLOOM, T5 (вариант) |
| Вычислительные затраты | Одинаково (один LayerNorm на подуровень) | Одинаково |
6. Почему pre-norm стал стандартом
Эмпирические исследования (например, «On Layer Normalization in the Transformer Architecture» от Xiong et al., 2020) показали, что pre-norm:
- Позволяет обучать модели без warmup (или с очень коротким warmup).
- Даёт более низкий loss при одинаковом количестве шагов обучения.
- Улучшает сходимость для глубоких моделей (12+ слоёв).
Кроме того, pre-norm упрощает реализацию: не нужно заботиться о специальной инициализации весов (например, T5 использует pre-norm с инициализацией по умолчанию).
Исключения Некоторые модели (например,早期的 Transformer) используют post-norm, но в современных LLM (GPT-3, Llama, Mistral, Falcon) pre-norm является де-факто стандартом.
7. Варианты pre-norm: Pre-LN и Sandwich-Norm
- Pre-LN (обычный pre-norm): LayerNorm только перед каждым подуровнем. Используется в GPT-3, Llama.
- Sandwich-Norm: дополнительная LayerNorm после последнего residual connection в блоке (т.е. output = LayerNorm(x + F(LayerNorm(x)))). Используется в некоторых моделях (например, T5). Это компромисс: сохраняет прямой градиент через residual, но добавляет нормализацию на выходе блока.
Sandwich-Norm может улучшить стабильность на очень глубоких моделях (100+ слоёв), но добавляет один лишний LayerNorm на блок.
8. Влияние на attention и FFN
Attention: Pre-norm нормализует входные запросы (Q), ключи (K) и значения (V) через LayerNorm. Это уменьшает внутренний ковариатный сдвиг и позволяет attention лучше фокусироваться на релевантных токенах. Post-norm нормализует уже после attention, что может «смазывать» информацию.
FFN Pre-norm стабилизирует активации FFN (обычно ReLU или SwiGLU), предотвращая взрыв значений. Post-norm может приводить к большим вариациям в выходе FFN, что требует более осторожной настройки.
9. Практические рекомендации
- Для новых моделей используйте pre-norm (Pre-LN). Это безопасный выбор, который работает для любой глубины.
- Если глубина < 12 слоёв post-norm может дать slightly better perplexity (как в оригинальном Transformer), но разница обычно незначительна.
- Для очень глубоких моделей (> 48 слоёв): рассмотрите Sandwich-Norm или pre-norm с дополнительной нормализацией на выходе.
- При fine-tuning pre-norm модели легче адаптируются к новым задачам без сбоев.
- При обучении с нуля pre-norm позволяет использовать более высокий learning rate (например, 3e-4 вместо 1e-4) и сократить warmup.
10. Связь с другими техниками
- Warmup post-norm требует обязательного warmup (линейное увеличение LR), pre-norm может обходиться без него.
- Weight initialization post-norm чувствителен к инициализации (например, T5 использует инициализацию с малыми значениями), pre-norm более робастен.
- Gradient clipping pre-norm реже требует clipping, так как градиенты не взрываются.
- Dropout: pre-norm может сочетаться с dropout внутри подуровней без потери стабильности.
Пет-проект для закрепления
Задача Реализовать простой трансформер с возможностью переключения pre/post-norm, обучить на датасете Tiny Shakespeare и сравнить сходимость.
Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (для логирования).
Шаги:
- Напишите класс
TransformerBlockс параметромnorm_type('pre'или'post'). - Соберите модель из 6 блоков, embedding-слоя и выходного линейного слоя.
- Загрузите датасет Tiny Shakespeare (текст ~1 MB).
- Обучите две модели (pre и post) с одинаковыми гиперпараметрами (learning rate 3e-4, batch size 32, 10 эпох).
- Логируйте loss на каждом шаге.
- Постройте график loss по шагам для обеих моделей.
Ожидаемый результат Pre-norm модель покажет более быстрое снижение loss и меньшую финальную perplexity. Post-norm может расходиться без warmup (попробуйте обучить без warmup — post-norm, скорее всего, не сойдётся).
Код (фрагмент):
import torch
import torch.nn as nn
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, norm_type='pre'):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, n_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm_type = norm_type
def forward(self, x):
if self.norm_type == 'pre':
# Pre-norm
x = x + self.attention(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.ffn(self.norm2(x))
else:
# Post-norm
x = self.norm1(x + self.attention(x, x, x)[0])
x = self.norm2(x + self.ffn(x))
return x
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 672 | Архитектура трансформера (общая структура блока) |
| 674 | Механизмы внимания (Multi-Head Attention) |
| 675 | Позиционное кодирование (Positional Encoding) |
| 676 | Feed-Forward Network (FFN) в трансформере |
| 677 | Residual connections (остаточные связи) |
| 678 | Layer Normalization (детальный разбор) |
Навигация
- Предыдущий: 672
- Следующий: 674
- Индекс: 00. Индекс разборов