Как работает LayerNorm и RMSNorm? В чем разница и почему RMSNorm быстрее?

Краткий тезис

Layer Normalization (LayerNorm) и Root Mean Square Normalization (RMSNorm) — это методы нормализации скрытых состояний в нейронных сетях. LayerNorm вычитает среднее и делит на стандартное отклонение, что требует вычисления двух статистик. RMSNorm упрощает процесс, используя только среднеквадратичное значение (RMS) без центрирования, что даёт прирост скорости на 10–15% при сопоставимом качестве. RMSNorm стал стандартом в современных LLM (Llama, Mistral) благодаря эффективности и простоте.


1. Зачем нужна нормализация в нейронных сетях

Нормализация стабилизирует обучение, уменьшая внутренний ковариатный сдвиг (internal covariate shift) — изменение распределения активаций слоёв в процессе обучения. Без нормализации градиенты могут становиться слишком большими или маленькими, что замедляет сходимость или приводит к расходимости. Основные подходы:

  • Batch Normalization — нормализация по батчу (неудобна для RNN/трансформеров из-за зависимости от длины последовательности).
  • Layer Normalization — нормализация по признакам одного токена (независима от батча, подходит для трансформеров).
  • RMSNorm — упрощённая версия LayerNorm без центрирования.

В трансформерах нормализация применяется после каждого подуровня (self-attention, FFN) — это Post-LN (оригинальный Transformer) или перед подуровнем — Pre-LN (современные модели, стабильнее).


2. Layer Normalization: математика и интуиция

2.1 Формула

Для входного вектора ( x \in \mathbb{R}^d ) (state|скрытое состояние одного токена) LayerNorm вычисляет:

[ \mu = \frac{1}{d} \sum_{i=1}^d x_i ] [ \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 ] [ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + [epsilon](/wiki/Epsilon)}} ] [ y_i = \gamma_i \hat{x}_i + \beta_i ]

Где:

  • (\mu) — среднее, (\sigma^2) — дисперсия.
  • ([epsilon](/wiki/Epsilon)) — малая константа для численной стабильности (обычно (10^{-5})).
  • (\gamma, [beta](/wiki/beta) \in \mathbb{R}^d) — обучаемые параметры сдвига и масштаба (affine transformation).

2.2 Интуиция

LayerNorm делает распределение активаций для каждого токена стандартным нормальным (среднее 0, дисперсия 1), после чего позволяет модели выучить оптимальный сдвиг и масштаб через (\gamma, [beta](/wiki/beta)). Это устраняет зависимость от масштаба входов и стабилизирует градиенты.

2.3 Реализация на Python (упрощённо)

import torch
import torch.nn as nn

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta

# Пример
x = torch.randn(2, 4, 8)  # (batch, seq_len, d_model)
ln = nn.LayerNorm(8)
y = ln(x)

3. RMSNorm: математика и интуиция

3.1 Формула

RMSNorm отбрасывает центрирование (вычитание среднего) и использует только среднеквадратичное значение (RMS):

[ [text](/wiki/text){RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2} ] [ \hat{x}_i = \frac{x_i}{[text](/wiki/text){RMS}(x) + [epsilon](/wiki/Epsilon)} ] [ y_i = \gamma_i \hat{x}_i \quad ([text](/wiki/text){без } [beta](/wiki/beta)) ]

Обратите внимание: ([beta](/wiki/beta)) отсутствует, хотя в некоторых реализациях может быть добавлен (но обычно не используется).

3.2 Интуиция

RMSNorm масштабирует вектор так, чтобы его RMS стал равен 1, но не принуждает среднее к нулю. Это сохраняет информацию о сдвиге, которая может быть полезна для модели. На практике качество остаётся сопоставимым с LayerNorm, а вычислений становится меньше.

3.3 Реализация на Python

def rms_norm(x, gamma, eps=1e-5):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    x_norm = x / rms
    return gamma * x_norm

# Пример
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

4. Сравнение LayerNorm и RMSNorm

ХарактеристикаLayerNormRMSNorm
ЦентрированиеВычитает среднее (\mu)Не вычитает
МасштабированиеДелит на (\sigma)Делит на RMS
Обучаемые параметры(\gamma) и ([beta](/wiki/beta))Только (\gamma) (обычно)
Вычислительная сложность(O(2d)): среднее + дисперсия(O(d)): только сумма квадратов
Численная стабильностьЧуть выше из-за ([epsilon](/wiki/Epsilon)) в знаменателеАналогично
Качество (perplexity)ЭталонСопоставимо (иногда чуть хуже)
СкоростьБазоваяНа 10–15% быстрее
Применение в моделяхBERT, GPT-2, T5Llama, Mistral, Gemma, Qwen

5. Почему RMSNorm быстрее? Детальный анализ

5.1 Количество операций

Для вектора размерности (d):

  • LayerNorm:

    • Вычисление среднего: (d) сложений, 1 деление.
    • Вычисление дисперсии: (d) вычитаний, (d) умножений, (d) сложений, 1 деление.
    • Нормализация: (d) вычитаний, (d) делений.
    • Итого: ~(4d) арифметических операций + 2 деления.
  • RMSNorm:

    • Вычисление RMS: (d) умножений, (d) сложений, 1 деление, 1 квадратный корень.
    • Нормализация: (d) делений.
    • Итого: ~(2d) арифметических операций + 1 деление + 1 sqrt.

На практике квадратный корень — дорогая операция, но она есть и в LayerNorm (при вычислении (\sigma)). RMSNorm экономит ~(2d) операций на каждом токене. Для батча из (B) токенов и (L) слоёв экономия становится значительной.

5.2 Отсутствие центрирования

Вычисление среднего требует дополнительного прохода по данным (или одновременного накопления суммы). RMSNorm обходится одним проходом для суммы квадратов. Это уменьшает загрузку памяти и кэша, что особенно важно на GPU.

5.3 Эмпирические измерения

В статье RMSNorm (Zhang & Sennrich, 2019) показано ускорение на 10–15% на задачах машинного перевода. В современных LLM с десятками миллиардов параметров экономия времени обучения может составлять часы.


6. Где применяются LayerNorm и RMSNorm

Причина перехода: RMSNorm даёт выигрыш в скорости без потери качества, что критично при обучении больших моделей.


7. Влияние на обучение и качество

7.1 Сходимость

RMSNorm обычно сходится так же быстро, как LayerNorm, или даже немного быстрее из-за меньшего количества операций. Однако в некоторых задачах (например, с очень разреженными активациями) отсутствие центрирования может приводить к нестабильности — тогда помогает добавление ([beta](/wiki/beta)) или использование Pre-LN вместо Post-LN.

7.2 Качество (perplexity)

Эксперименты показывают, что разница в perplexity между LayerNorm и RMSNorm составляет менее 0.1–0.2 пункта на стандартных бенчмарках (WikiText-103, C4). В некоторых случаях RMSNorm даже немного лучше, так как сохраняет информацию о сдвиге.

7.3 Влияние на градиенты

LayerNorm делает градиенты независимыми от масштаба входа, что улучшает conditioning. RMSNorm сохраняет это свойство для масштаба, но не для сдвига. На практике это редко становится проблемой.


8. Недостатки RMSNorm

  1. Отсутствие центрирования — если распределение активаций сильно смещено, RMSNorm не нормализует его, что может замедлить обучение.
  2. Меньше обучаемых параметров — без ([beta](/wiki/beta)) модель теряет один степень свободы, хотя это компенсируется другими слоями.
  3. Не подходит для некоторых архитектур — например, в моделях с Group Normalization (как в Stable Diffusion) RMSNorm не используется.

На практике эти недостатки незначительны, и RMSNorm остаётся предпочтительным выбором для LLM.


9. Связь с другими нормализациями

ТипОсь нормализацииПрименение
BatchNormПо батчу и пространствуCNN, не для последовательностей
LayerNormПо признакам токенаТрансформеры, RNN
RMSNormПо признакам токена (без среднего)Современные LLM
GroupNormПо группам каналовVision, малые батчи
InstanceNormПо одному примеруСтилизация, GAN

10. Практические рекомендации

  • Для нового проекта LLM используйте RMSNorm — быстрее, проще, качество не хуже.
  • Если модель уже использует LayerNorm, замена на RMSNorm может потребовать тонкой настройки (learning rate, инициализация).
  • В задачах с очень длинными последовательностями (например, 128k токенов) выигрыш в скорости от RMSNorm становится более заметным.
  • Всегда используйте Pre-LN (нормализация перед подуровнем) — она стабильнее, чем Post-LN, независимо от типа нормализации.

Пет-проект для закрепления

Задача: Сравнить скорость и качество LayerNorm и RMSNorm на задаче языкового моделирования (обучение маленького трансформера с нуля).

Инструменты: PyTorch, Hugging Face Transformers, датасет WikiText-2.

Шаги:

  1. Реализовать два класса LayerNorm и RMSNorm (как показано выше).
  2. Создать маленький трансформер (2 слоя, 4 головы, d_model=128) с возможностью переключения нормализации.
  3. Обучить обе модели на WikiText-2 (1000 шагов, одинаковый seed).
  4. Замерить время обучения и финальную perplexity на валидации.
  5. Построить график loss по шагам.

Ожидаемый результат: RMSNorm будет на 10–15% быстрее по времени на шаг, perplexity будет отличаться не более чем на 0.5 пункта.


Связь с другими вопросами

ВопросТема
653Нормализация в трансформерах (общий обзор)
655Функции активации (Swish, GeLU)
656Позиционные эмбеддинги (RoPE, AliBi)
657Архитектура Attention (Multi-Head, Grouped-Query)
658Pre-LN vs Post-LN
659Инициализация весов в трансформерах

Навигация