Как работает инициализация весов в LLM (Xavier, Kaiming, почему важна)?

Краткий тезис

Инициализация весов — это критический этап обучения нейросетей, определяющий, сойдётся ли модель и как быстро. Плохая инициализация приводит к затухающим или взрывающимся градиентам, делая обучение невозможным. Xavier (Glorot) и Kaiming (He) — два стандартных метода, которые математически обосновывают выбор дисперсии начальных весов так, чтобы сигнал (активации и градиенты) не затухал и не взрывался при проходе через слои. Для LLM с современными активациями (SwiGLU, GELU) чаще применяют Kaiming с поправкой на gain.


1. Проблема: затухающие и взрывающиеся градиенты

При обучении глубокой сети градиенты проходят через множество слоёв (цепное правило). Если веса инициализированы слишком малыми значениями, дисперсия активаций уменьшается с каждым слоем → градиенты затухают (gradients|vanishing gradients). Если слишком большими — дисперсия растёт → градиенты взрываются (gradients|exploding gradients). Оба случая делают обучение нестабильным или невозможным.

Пример затухания:

  • Вход: x ~ N(0,1)
  • Слой с весами W ~ N(0, 0.01): выход y = Wx ~ N(0, 0.0001)
  • После 10 слоёв: дисперсия ~ 10⁻⁴⁰ → градиенты практически нулевые.

Пример взрыва:

  • Веса W ~ N(0, 10): выход после 10 слоёв → NaN.

Цель инициализации — подобрать дисперсию весов так, чтобы дисперсия активаций (и градиентов) оставалась примерно постоянной от слоя к слою.


2. Термин: дисперсия и сигнал

Дисперсия — мера разброса значений. Для случайной величины X: Var(X) = E[(X - μ)²]. В контексте нейросетей нас интересует дисперсия выходов слоя (активаций) и дисперсия градиентов]].

Сигнал — значения, которые передаются между слоями (активации на forward pass, градиенты на backward pass). Если сигнал затухает или взрывается, обучение нарушается.

Инициализация весов — это выбор распределения (обычно равномерного или нормального) с определённой дисперсией, которая зависит от количества входов (fan_in) и выходов (fan_out) слоя.


3. Xavier (Glorot) инициализация

Разработана Xavier Glorot и Yoshua Bengio в 2010 году для сетей с симметричными активациями (tanh, sigmoid). Основная идея: дисперсия выходов каждого слоя должна быть равна дисперсии входов.

Формула:

Var(W) = 2 / (fan_in + fan_out)

где fan_in — количество входов слоя, fan_out — количество выходов.

Распределение:

  • Равномерное: W ~ U[-√(6/(fan_in+fan_out)), +√(6/(fan_in+fan_out))]
  • Нормальное: W ~ N(0, √(2/(fan_in+fan_out)))

Почему работает:

  • Для линейного слоя y = Wx, если x имеет дисперсию σ²_x, а веса независимы с дисперсией σ²_W, то Var(y) = fan_in * σ²_W * σ²_x.
  • Чтобы Var(y) = σ²_x, нужно σ²_W = 1/fan_in.
  • Но на обратном проходе градиенты проходят через транспонированную матрицу, и условие для градиентов даёт σ²_W = 1/fan_out.
  • Компромисс: σ²_W = 2/(fan_in + fan_out).

Ограничение: Xavier плохо работает с ReLU, так как ReLU обнуляет половину значений, что нарушает предположение о симметрии.


4. Kaiming (He) инициализация

Разработана Kaiming He в 2015 году специально для ReLU и его вариантов (Leaky ReLU, PReLU). Учитывает, что ReLU обнуляет отрицательные значения, уменьшая дисперсию вдвое.

Формула:

Var(W) = 2 / fan_in

где fan_in — количество входов слоя.

Распределение:

  • Нормальное: W ~ N(0, √(2/fan_in))
  • Равномерное: W ~ U[-√(6/fan_in), +√(6/fan_in)]

Почему 2/fan_in, а не 1/fan_in:

  • Для ReLU: y = max(0, Wx). Если x ~ N(0, σ²), то после ReLU дисперсия уменьшается примерно вдвое (половина значений становится нулём).
  • Чтобы компенсировать это, дисперсия весов берётся в 2 раза больше: σ²_W = 2/fan_in.

Для Leaky ReLU:

Var(W) = 2 / (1 + α²) * fan_in

где α — коэффициент наклона для отрицательных значений (обычно 0.01).

Сравнение Xavier и Kaiming:

ПараметрXavier (Glorot)Kaiming (He)
Целевые активацииtanh, sigmoidReLU, Leaky ReLU, PReLU
Формула дисперсии2/(fan_in+fan_out)2/fan_in
Учёт нелинейностиНет (предполагает симметрию)Да (компенсирует обнуление)
Поведение с ReLUДисперсия падает → затуханиеДисперсия стабильна

5. Инициализация в современных LLM

Современные LLM (GPT-2, LLaMA, Mistral) используют:

  • Функции активации: SwiGLU, GELU, SiLU (они не обнуляют отрицательные значения полностью, но имеют нелинейность).
  • Нормализацию: LayerNorm (или RMSNorm) после каждого подуровня, что стабилизирует дисперсию независимо от инициализации.

Практика:

  • Для линейных слоёв (W_q, W_k, W_v, W_o, FFN) часто используют Kaiming normal с gain=√2 (для компенсации SwiGLU или GELU).
  • Для эмбеддингов (token embeddings, positional embeddings) — нормальное распределение с малой дисперсией (например, N(0, 0.02)).
  • Для выходного слоя (lm_head) — часто нулевая инициализация или Kaiming с малым gain.

Пример из LLaMA:

def init_weights(module):
    if isinstance(module, nn.Linear):
        torch.nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

Почему не Xavier:

  • SwiGLU и GELU не симметричны, но и не обнуляют половину значений. Kaiming с gain=√2 эмпирически даёт лучшую сходимость.
  • LayerNorm делает инициализацию менее критичной, но хорошая инициализация всё равно ускоряет обучение.

6. Другие методы инициализации

МетодОписаниеКогда используется
OrthogonalВеса инициализируются ортогональной матрицей (W^T W = I). Сохраняет норму сигнала.RNN, LSTM (для борьбы с затуханием)
Small initВеса ~ N(0, 0.01). Простая, но часто приводит к затуханию.Мелкие сети, fine-tuning
Zero initВсе веса = 0.Только для bias или residual connections (чтобы начать с identity)
Pre-trainingВеса берутся из предобученной модели.Fine-tuning, transfer learning

Для residual connections:

  • В трансформерах часто инициализируют последний линейный слой в каждом residual block нулями (zero init), чтобы на старте модель была identity mapping. Это стабилизирует обучение глубоких сетей (метод из GPT-2).

7. Практические рекомендации для LLM

  1. Для линейных слоёв (attention, FFN): Kaiming normal с gain=√2 (или 1.0 для ReLU).
  2. Для эмбеддингов: нормальное распределение с std=0.02 (или 1/√d_model).
  3. Для выходного слоя (lm_head): часто используют tied embeddings (веса эмбеддингов и lm_head общие) или инициализируют нулями.
  4. Для bias: нули.
  5. Для LayerNorm: weight = 1, bias = 0.
  6. Для residual blocks: последний линейный слой в blockzero init (если используется pre-norm архитектура).

Проверка инициализации:

  • После инициализации прогоните один forward pass с batch данных и проверьте дисперсию активаций на каждом слое. Она не должна резко расти или падать.
  • Если loss на старте не падает (или сразу NaN) — проблема в инициализации.

8. Связь с другими аспектами обучения

  • Оптимизаторы: Adam, AdamW адаптивно подбирают learning rate для каждого параметра, что частично компенсирует плохую инициализацию. Но хорошая инициализация всё равно важна для скорости сходимости.
  • Learning rate schedule: Warmup (постепенное увеличение LR) помогает стабилизировать обучение при плохой инициализации.
  • Normalization: LayerNorm, BatchNorm, RMSNorm уменьшают зависимость от инициализации, но не отменяют её полностью.
  • Weight decay: Регуляризация, которая штрафует большие веса, может взаимодействовать с инициализацией.

Пет-проект для закрепления

Задача: Сравнить влияние Xavier, Kaiming и плохой инициализации на обучение маленького трансформера.

Инструменты: PyTorch, Hugging Face Transformers (или написать свой трансформер), Weights & Biases (для логирования).

Шаги:

  1. Создайте простой трансформер (2-3 слоя, d_model=128, 4 головы) для задачи классификации текста (например, IMDb).
  2. Реализуйте три варианта инициализации:
    • Xavier uniform (для всех линейных слоёв)
    • Kaiming normal (gain=√2)
    • Плохая: нормальное распределение с std=1.0 (взрыв) или std=1e-6 (затухание)
  3. Обучите каждую модель с одинаковым learning rate (1e-4), batch size, количеством эпох.
  4. Логируйте loss, accuracy, градиенты (norm) на каждом шаге.
  5. Постройте графики: loss vs step, gradient norm vs step.

Ожидаемый результат:

  • Kaiming: стабильное обучение, loss монотонно падает.
  • Xavier: loss падает медленнее, возможны колебания.
  • Плохая инициализация: loss не падает (затухание) или становится NaN (взрыв).

Дополнительно: Попробуйте добавить LayerNorm и повторите эксперимент — увидите, что нормализация сглаживает разницу, но Kaiming всё равно даёт лучший старт.


Связь с другими вопросами

ВопросТема
664Архитектура трансформера (где применяются линейные слои)
666Оптимизаторы (Adam, AdamW) и их взаимодействие с инициализацией
667LayerNorm и другие методы нормализации
668Функции активации (ReLU, GELU, SwiGLU)
669Проблема затухающих/взрывающихся градиентов
670Fine-tuning LLM (как инициализация влияет на дообучение)

Навигация