English translation is not available yet. Showing Russian content.

Как работает RMSNorm (Root Mean Square Normalization) и чем лучше LayerNorm?

Краткий тезис

RMSNorm (Root Mean Square Normalization) — это упрощённая версия LayerNorm, которая нормализует входные данные только по среднеквадратичному значению (RMS), без вычитания среднего (центрирования). Это делает её на 10–15% быстрее в инференсе и обучении, сохраняя сопоставимое качество на большинстве задач. RMSNorm стала стандартом в современных LLM (Llama, Mistral, Gemma) благодаря более эффективной реализации, особенно при работе с длинными последовательностями.

1. Термин: Layer Normalization (LayerNorm)

LayerNorm — это метод нормализации, который стандартизирует активации нейронов внутри слоя, вычисляя среднее и дисперсию по всем нейронам слоя для каждого примера в батче. Формула:

[ [text](/wiki/text){LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + [epsilon](/wiki/Epsilon)}} + [beta](/wiki/beta) ]

где:

  • (x) — входной вектор размерности (d);
  • (\mu = \frac{1}{d} \sum_{i=1}^{d} x_i) — среднее по слою;
  • (\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2) — дисперсия;
  • (\gamma, [beta](/wiki/beta)) — обучаемые параметры сдвига и масштаба;
  • ([epsilon](/wiki/Epsilon)) — малая константа для численной стабильности.

Применяется в Transformer-архитектурах после каждого под-слоя (self-attention, feed-forward). Выполняет два действия: центрирование (вычитание среднего) и нормализацию (деление на стандартное отклонение).

2. Термин: RMSNorm (Root Mean Square Normalization)

RMSNorm — это вариант нормализации, предложенный в статье "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019). Вместо полной стандартизации она делит вход на среднеквадратичное значение (RMS) без центрирования:

[ [text](/wiki/text){RMSNorm}(x) = \gamma \cdot \frac{x}{[text](/wiki/text){RMS}(x)} + [beta](/wiki/beta), \quad [text](/wiki/text){RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + [epsilon](/wiki/Epsilon)} ]

Ключевые отличия от LayerNorm:

  • Нет вычитания среднего — не требуется вычислять (\mu).
  • Единственная статистика — только RMS (квадратный корень из среднего квадратов).
  • Параметр сдвига ([beta](/wiki/beta)) сохраняется, но без центрирования он играет роль обучения сдвигу, а не восстановления среднего после нормализации.

3. Сравнение RMSNorm и LayerNorm

ХарактеристикаLayerNormRMSNorm
Центрирование (вычитание среднего)ДаНет
Вычисляемые статистики(\mu) и (\sigma^2)([text](/wiki/text){RMS}(x))
Количество операций~3N + 2 корня~2N + 1 корень
Относительная скорость (инференс)Базовый~10–15% быстрее
Качество на NLP-задачахЭталонноеСопоставимо
Использование в современных LLMGPT-2, BERTLlama, Mistral, Gemma, Falcon

4. Почему RMSNorm быстрее?

Вычислительная сложность LayerNorm для вектора длины (d):

  • Вычисление (\mu): (d) операций сложения + 1 деление.
  • Вычисление (\sigma^2): (d) вычитаний, (d) возведений в квадрат, (d) сложений, 1 деление.
  • Нормализация: (d) вычитаний, (d) делений.
  • Итого: ~3(d) сложений/вычитаний, ~(d) возведений в квадрат, ~(d) делений, 1 квадратный корень.

Для RMSNorm:

  • Вычисление RMS: (d) возведений в квадрат, (d) сложений, 1 деление, 1 квадратный корень.
  • Нормализация: (d) делений.
  • Итого: ~(d) сложений, ~(d) возведений в квадрат, ~(d) делений, 1 корень.

Экономия составляет около 2(d) операций сложения/вычитания, что даёт прирост скорости на 10–15% в реальных условиях (особенно на GPU, где операции с памятью — узкое место). В обратном проходе выгода ещё больше, так как не нужно вычислять градиенты по (\mu) и (\sigma^2).

5. Почему качество остаётся сопоставимым?

Авторы RMSNorm показали, что центрирование (вычитание среднего) не является критически важным для нормализации в RNN и Transformer. Нормализация по RMS уже приводит активации к стабильному диапазону, а обучаемый параметр ([beta](/wiki/beta)) может компенсировать отсутствие явного вычитания среднего. Эксперименты на машинном переводе, языковом моделировании и классификации текстов показали:

  • Потери (perplexity): разница менее 0,1–0,2 пункта.
  • Точность на GLUE/SuperGLUE: в пределах погрешности.
  • Скорость сходимости: иногда даже быстрее, особенно при больших глубинах.

Практический вывод: в современных LLM (Llama 2/3, Mistral) замена LayerNorm на RMSNorm стала стандартом, так как даёт выигрыш в скорости без заметного ухудшения.

6. Влияние на архитектуру Transformer

В оригинальном Transformer используется Post-LN (LayerNorm после под-слоя). Современные модели, использующие RMSNorm, часто применяют Pre-LN (нормализация перед под-слоем). Это дополнительно улучшает стабильность обучения. Пример:

# Pre-LN с RMSNorm
x = RMSNorm(x)
x = x + MultiHeadAttention(x)
x = RMSNorm(x)
x = x + FeedForward(x)

Такая конфигурация используется в Llama и Mistral. RMSNorm в Pre-LN даёт дополнительный прирост скорости за счёт меньшего количества операций.

7. Реализация RMSNorm на Python

Пример кода для сравнения:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))  # опционально
        self.eps = eps

    def forward(self, x):
        # x shape: (..., d_model)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms) + self.beta

# Сравнение времени на GPU
layer_norm = nn.LayerNorm(4096).cuda()
rms_norm = RMSNorm(4096).cuda()
x = torch.randn(32, 128, 4096).cuda()

# Замер (упрощённый)
import time
t0 = time.time(); _ = layer_norm(x); t1 = time.time()
t2 = time.time(); _ = rms_norm(x); t3 = time.time()
print(f"LayerNorm: {t1-t0:.4f}s, RMSNorm: {t3-t2:.4f}s")

Ожидаемый результат: RMSNorm быстрее на 10–15% (разброс зависит от размера тензора и hardware).

8. Когда выбрать RMSNorm вместо LayerNorm?

  • Когда скорость критична — крупные модели, inference на CPU/edge, сервисы с высоким RPS.
  • При обучении очень глубоких сетей — Pre-LN с RMSNorm стабильнее (меньше ваниширующих градиентов).
  • В моделях, где уже используется RMSNorm — для совместимости (Llama, Mistral).
  • Если нет необходимости в центрировании — например, когда среднее активаций неважно (ReLU-подобные функции активации нечувствительны к сдвигу).

Ограничения RMSNorm:

  • Теоретически может быть менее стабильна при очень малых значениях RMS (но эпсилон решает).
  • В некоторых задачах (например, регрессия с чувствительностью к среднему) LayerNorm может дать небольшой прирост.
  • Отсутствие центрирования означает, что нормированные значения имеют ненулевое среднее — это может влиять на распределение активаций последующих слоёв.

Пет-проект для закрепления

Задача Сравнить скорость и качество RMSNorm и LayerNorm при fine-tuning небольшой модели (например, DistilBERT) на задаче классификации текстов.

Инструменты Python, PyTorch, Hugging Face Transformers, datasets (IMDb).

Шаги:

  1. Загрузите DistilBERT (использует LayerNorm).
  2. Замените все LayerNorm на RMSNorm (заменив nn.LayerNorm на кастомный RMSNorm).
  3. Протестируйте качество (accuracy, F1) на IMDb и замерьте время одного эпоха обучения (train + eval).
  4. Сравните с оригиналом (LayerNorm).
  5. Дополнительно: замерьте время инференса на CPU (batch=1).

Ожидаемый результат RMSNorm покажет аналогичное качество (разница < 0.5%) и ускорение на 10–15% на GPU, на CPU — заметнее.

Связь с другими вопросами

ВопросТема
279Batch Normalization и Layer Normalization
281GroupNorm и другие методы нормализации
274Архитектура Transformer (расположение нормализации)
282Влияние нормализации на обучение глубоких сетей
287Оптимизации inference в LLM (слияние операций)

Навигация