Как работает LayerNorm и RMSNorm? В чем разница и почему RMSNorm быстрее?
Краткий тезис
Layer Normalization (LayerNorm) и Root Mean Square Normalization (RMSNorm) — это методы нормализации скрытых состояний в нейронных сетях. LayerNorm вычитает среднее и делит на стандартное отклонение, что требует вычисления двух статистик. RMSNorm упрощает процесс, используя только среднеквадратичное значение (RMS) без центрирования, что даёт прирост скорости на 10–15% при сопоставимом качестве. RMSNorm стал стандартом в современных LLM (Llama, Mistral) благодаря эффективности и простоте.
1. Зачем нужна нормализация в нейронных сетях
Нормализация стабилизирует обучение, уменьшая внутренний ковариатный сдвиг (internal covariate shift) — изменение распределения активаций слоёв в процессе обучения. Без нормализации градиенты могут становиться слишком большими или маленькими, что замедляет сходимость или приводит к расходимости. Основные подходы:
- Batch Normalization — нормализация по батчу (неудобна для RNN/трансформеров из-за зависимости от длины последовательности).
- Layer Normalization — нормализация по признакам одного токена (независима от батча, подходит для трансформеров).
- RMSNorm — упрощённая версия LayerNorm без центрирования.
В трансформерах нормализация применяется после каждого подуровня (self-attention, FFN) — это Post-LN (оригинальный Transformer) или перед подуровнем — Pre-LN (современные модели, стабильнее).
2. Layer Normalization: математика и интуиция
2.1 Формула
Для входного вектора ( x \in \mathbb{R}^d ) (state|скрытое состояние одного токена) LayerNorm вычисляет:
[ \mu = \frac{1}{d} \sum_{i=1}^d x_i ] [ \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 ] [ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + [epsilon](/wiki/Epsilon)}} ] [ y_i = \gamma_i \hat{x}_i + \beta_i ]
Где:
- (\mu) — среднее, (\sigma^2) — дисперсия.
- ([epsilon](/wiki/Epsilon)) — малая константа для численной стабильности (обычно (10^{-5})).
- (\gamma, [beta](/wiki/beta) \in \mathbb{R}^d) — обучаемые параметры сдвига и масштаба (affine transformation).
2.2 Интуиция
LayerNorm делает распределение активаций для каждого токена стандартным нормальным (среднее 0, дисперсия 1), после чего позволяет модели выучить оптимальный сдвиг и масштаб через (\gamma, [beta](/wiki/beta)). Это устраняет зависимость от масштаба входов и стабилизирует градиенты.
2.3 Реализация на Python (упрощённо)
import torch
import torch.nn as nn
def layer_norm(x, gamma, beta, eps=1e-5):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + eps)
return gamma * x_norm + beta
# Пример
x = torch.randn(2, 4, 8) # (batch, seq_len, d_model)
ln = nn.LayerNorm(8)
y = ln(x)
3. RMSNorm: математика и интуиция
3.1 Формула
RMSNorm отбрасывает центрирование (вычитание среднего) и использует только среднеквадратичное значение (RMS):
[ [text](/wiki/text){RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2} ] [ \hat{x}_i = \frac{x_i}{[text](/wiki/text){RMS}(x) + [epsilon](/wiki/Epsilon)} ] [ y_i = \gamma_i \hat{x}_i \quad ([text](/wiki/text){без } [beta](/wiki/beta)) ]
Обратите внимание: ([beta](/wiki/beta)) отсутствует, хотя в некоторых реализациях может быть добавлен (но обычно не используется).
3.2 Интуиция
RMSNorm масштабирует вектор так, чтобы его RMS стал равен 1, но не принуждает среднее к нулю. Это сохраняет информацию о сдвиге, которая может быть полезна для модели. На практике качество остаётся сопоставимым с LayerNorm, а вычислений становится меньше.
3.3 Реализация на Python
def rms_norm(x, gamma, eps=1e-5):
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
x_norm = x / rms
return gamma * x_norm
# Пример
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return self.gamma * x / rms
4. Сравнение LayerNorm и RMSNorm
| Характеристика | LayerNorm | RMSNorm |
|---|---|---|
| Центрирование | Вычитает среднее (\mu) | Не вычитает |
| Масштабирование | Делит на (\sigma) | Делит на RMS |
| Обучаемые параметры | (\gamma) и ([beta](/wiki/beta)) | Только (\gamma) (обычно) |
| Вычислительная сложность | (O(2d)): среднее + дисперсия | (O(d)): только сумма квадратов |
| Численная стабильность | Чуть выше из-за ([epsilon](/wiki/Epsilon)) в знаменателе | Аналогично |
| Качество (perplexity) | Эталон | Сопоставимо (иногда чуть хуже) |
| Скорость | Базовая | На 10–15% быстрее |
| Применение в моделях | BERT, GPT-2, T5 | Llama, Mistral, Gemma, Qwen |
5. Почему RMSNorm быстрее? Детальный анализ
5.1 Количество операций
Для вектора размерности (d):
-
- Вычисление среднего: (d) сложений, 1 деление.
- Вычисление дисперсии: (d) вычитаний, (d) умножений, (d) сложений, 1 деление.
- Нормализация: (d) вычитаний, (d) делений.
- Итого: ~(4d) арифметических операций + 2 деления.
-
- Вычисление RMS: (d) умножений, (d) сложений, 1 деление, 1 квадратный корень.
- Нормализация: (d) делений.
- Итого: ~(2d) арифметических операций + 1 деление + 1 sqrt.
На практике квадратный корень — дорогая операция, но она есть и в LayerNorm (при вычислении (\sigma)). RMSNorm экономит ~(2d) операций на каждом токене. Для батча из (B) токенов и (L) слоёв экономия становится значительной.
5.2 Отсутствие центрирования
Вычисление среднего требует дополнительного прохода по данным (или одновременного накопления суммы). RMSNorm обходится одним проходом для суммы квадратов. Это уменьшает загрузку памяти и кэша, что особенно важно на GPU.
5.3 Эмпирические измерения
В статье RMSNorm (Zhang & Sennrich, 2019) показано ускорение на 10–15% на задачах машинного перевода. В современных LLM с десятками миллиардов параметров экономия времени обучения может составлять часы.
6. Где применяются LayerNorm и RMSNorm
- LayerNorm — классический выбор в трансформерах до 2022 года: BERT, GPT-2, T5, ViT.
- RMSNorm — стандарт в современных открытых LLM:
Причина перехода: RMSNorm даёт выигрыш в скорости без потери качества, что критично при обучении больших моделей.
7. Влияние на обучение и качество
7.1 Сходимость
RMSNorm обычно сходится так же быстро, как LayerNorm, или даже немного быстрее из-за меньшего количества операций. Однако в некоторых задачах (например, с очень разреженными активациями) отсутствие центрирования может приводить к нестабильности — тогда помогает добавление ([beta](/wiki/beta)) или использование Pre-LN вместо Post-LN.
7.2 Качество (perplexity)
Эксперименты показывают, что разница в perplexity между LayerNorm и RMSNorm составляет менее 0.1–0.2 пункта на стандартных бенчмарках (WikiText-103, C4). В некоторых случаях RMSNorm даже немного лучше, так как сохраняет информацию о сдвиге.
7.3 Влияние на градиенты
LayerNorm делает градиенты независимыми от масштаба входа, что улучшает conditioning. RMSNorm сохраняет это свойство для масштаба, но не для сдвига. На практике это редко становится проблемой.
8. Недостатки RMSNorm
- Отсутствие центрирования — если распределение активаций сильно смещено, RMSNorm не нормализует его, что может замедлить обучение.
- Меньше обучаемых параметров — без ([beta](/wiki/beta)) модель теряет один степень свободы, хотя это компенсируется другими слоями.
- Не подходит для некоторых архитектур — например, в моделях с Group Normalization (как в Stable Diffusion) RMSNorm не используется.
На практике эти недостатки незначительны, и RMSNorm остаётся предпочтительным выбором для LLM.
9. Связь с другими нормализациями
| Тип | Ось нормализации | Применение |
|---|---|---|
| BatchNorm | По батчу и пространству | CNN, не для последовательностей |
| LayerNorm | По признакам токена | Трансформеры, RNN |
| RMSNorm | По признакам токена (без среднего) | Современные LLM |
| GroupNorm | По группам каналов | Vision, малые батчи |
| InstanceNorm | По одному примеру | Стилизация, GAN |
10. Практические рекомендации
- Для нового проекта LLM используйте RMSNorm — быстрее, проще, качество не хуже.
- Если модель уже использует LayerNorm, замена на RMSNorm может потребовать тонкой настройки (learning rate, инициализация).
- В задачах с очень длинными последовательностями (например, 128k токенов) выигрыш в скорости от RMSNorm становится более заметным.
- Всегда используйте Pre-LN (нормализация перед подуровнем) — она стабильнее, чем Post-LN, независимо от типа нормализации.
Пет-проект для закрепления
Задача: Сравнить скорость и качество LayerNorm и RMSNorm на задаче языкового моделирования (обучение маленького трансформера с нуля).
Инструменты: PyTorch, Hugging Face Transformers, датасет WikiText-2.
Шаги:
- Реализовать два класса
LayerNormиRMSNorm(как показано выше). - Создать маленький трансформер (2 слоя, 4 головы, d_model=128) с возможностью переключения нормализации.
- Обучить обе модели на WikiText-2 (1000 шагов, одинаковый seed).
- Замерить время обучения и финальную perplexity на валидации.
- Построить график loss по шагам.
Ожидаемый результат: RMSNorm будет на 10–15% быстрее по времени на шаг, perplexity будет отличаться не более чем на 0.5 пункта.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 653 | Нормализация в трансформерах (общий обзор) |
| 655 | Функции активации (Swish, GeLU) |
| 656 | Позиционные эмбеддинги (RoPE, AliBi) |
| 657 | Архитектура Attention (Multi-Head, Grouped-Query) |
| 658 | Pre-LN vs Post-LN |
| 659 | Инициализация весов в трансформерах |
Навигация
- Предыдущий: 653
- Следующий: 655
- Индекс: 00. Индекс разборов