Как вы дебажите training instability (loss spikes, divergence)?

Краткий тезис

Training instability — это ситуация, когда loss ведёт себя непредсказуемо: резкие скачки (spikes) или уход в бесконечность (divergence). Причины лежат в гиперпараметрах (слишком высокий learning rate, отсутствие gradient clipping), данных (NaN, выбросы), precision (переполнение в fp16) или архитектуре (softmax overflow в LLM). Диагностика требует логирования градиентных норм, активаций и loss-кривых, а стабилизация достигается комбинацией warmup, gradient clipping, корректировки learning rate и использования precision precision precision training|mixed precision с loss scaling.


1. Термин: Training Instability (нестабильность обучения)

Training instability — это состояние, при котором loss-функция не сходится гладко, а демонстрирует аномальное поведение. Основные симптомы:

  • Loss spikes — резкие кратковременные скачки loss (например, с 2.5 до 15.0 за один шаг).
  • Divergenceloss неуклонно растёт или уходит в NaN/Inf, обучение «взрывается».
  • Oscillations — loss колеблется с большой амплитудой, не снижаясь.

Нестабильность критична для LLM (особенно 70B+), так как один сбойный шаг может испортить всю контрольную точку (checkpoint). В контексте Agentic RAG instability может возникнуть при fine-tuning агентов (RL, PPO) или при обучении ретриверов.


2. Причины loss spikes

2.1 Слишком высокий learning rate

Если learning rate (скорость обучения) превышает оптимальное значение, градиенты делают слишком большие шаги, перескакивая минимум. Это приводит к резкому росту loss.

Диагностика построить график learning rate и loss — spike совпадает с моментом, когда LR ещё не снизился (например, после warmup).

2.2 Gradient explosion (взрыв градиентов)

Gradient explosion — когда норма градиента становится огромной (например, >1000). Это происходит из-за глубоких сетей (RNN, Transformer) или плохой инициализации. В LLM часто связано с attention softmax (см. раздел 7).

Диагностика логировать gradient norm (норму градиента) — если она резко возрастает перед spike, причина в explosion.

2.3 Проблемы с данными

  • NaN в данных — пропуски, деление на ноль.
  • Outliers — экстремальные значения в loss (например, токен с очень низкой вероятностью).
  • Некорректная маска — attention mask, пропускающая pad-токены, может давать неправильные логиты.

3. Причины divergence (расходимости)

3.1 Отсутствие warmup

Warmup — постепенное увеличение learning rate с нуля до целевого значения в начале обучения. Без warmup модель сразу получает большой градиентный шаг, что может привести к divergence, особенно при использовании AdamW (адаптивный оптимизатор с коррекцией момента).

3.2 Плохая инициализация

Если веса инициализированы слишком большими значениями, активации насыщаются (например, tanh или sigmoid), градиенты затухают или взрываются. Для Transformer стандарт — инициализация из Xavier или Kaiming с учётом глубины.

3.3 Высокий batch size

При увеличении batch size (размер батча) оценка градиента становится точнее, но если learning rate не скорректирован (линейное масштабирование), шаги могут быть слишком большими. Это вызывает divergence.

3.4 Проблемы с precision (точностью)

При использовании mixed precision (fp16) числа с плавающей точкой могут переполняться (overflow) или терять точность (underflow). Если loss scaling не настроен, градиенты обнуляются или становятся Inf.


4. Диагностика: что логировать

Для эффективного дебага нужно логировать не только loss, но и внутренние состояния модели. Основные метрики:

МетрикаЧто показываетКак логировать
LossОсновной сигнал нестабильностиКаждый шаг или каждые N шагов
Gradient normВзрыв градиентовtorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) до клиппинга
Gradient varianceНестабильность направленияСреднее и std градиентов по слоям
Activation statisticsНасыщение нейроновСреднее, std, min/max активаций (например, после attention)
Learning rateСвязь spikes с LRТекущее значение LR
Loss scale (для mixed precision)Переполнение fp16scaler.get_scale() из torch.cuda.amp
NaN/Inf countПроблемы с данными или вычислениямиПроверка каждого тензора на NaN/Inf

Инструменты Weights & Biases (W&B), TensorBoard, MLflow. В W&B можно построить панель с loss, gradient norm и LR на одном графике.


5. Техники стабилизации

5.1 Gradient clipping (клиппинг градиентов)

Ограничивает норму градиента сверху. Стандартное значение max_norm=1.0 для LLM.

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Когда помогает при gradient explosion. Если spikes исчезают после клиппинга — причина в explosion.

5.2 Warmup steps

Постепенное увеличение LR от 0 до целевого за warmup_steps (обычно 500–2000 шагов). Реализация через linear schedule:

from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps)

Когда помогает при divergence в начале обучения.

5.3 Learning rate schedule (расписание LR)

После warmup обычно используют cosine decay или linear decay. Cosine decay более плавный, снижает риск spikes на поздних этапах.

5.4 Уменьшение learning rate

Если spikes появляются после warmup, попробуйте уменьшить базовый LR в 2–10 раз. Для LLM типичные значения: 1e-5 – 5e-5 для fine-tuning, 1e-4 – 3e-4 для pre-training.

5.5 Weight decay (регуляризация)

Weight decay (L2-регуляризация) помогает предотвратить взрыв весов. Для AdamW weight_decay=0.01 – 0.1.

5.6 Gradient accumulation (накопление градиентов)

Позволяет увеличить эффективный batch size без роста памяти. Если spikes связаны с шумом градиента, накопление сглаживает обновления.


6. Precision issues: mixed precision и loss scaling

При обучении LLM (70B+) часто используют mixed precision (fp16 + fp32). Проблемы:

  • Overflow — градиенты становятся Inf (слишком большие для fp16).
  • Underflow — градиенты становятся 0 (слишком маленькие).

Решение loss scaling — умножение loss на коэффициент перед backward, чтобы градиенты не обнулялись. PyTorch AMP автоматически управляет scaling:

scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    loss = model(inputs)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Если loss scale падает до 1 и градиенты всё ещё Inf — проблема не в precision, а в explosion. Если scale часто уменьшается — нужно увеличить init_scale или проверить данные.

Для моделей 70B+ также важно проверить масштабирование градиентов при использовании tensor parallelism (например, в Megatron-LM). Неправильное масштабирование между устройствами может вызвать spikes.


7. Специфика LLM: attention softmax overflow

В Transformer внимание вычисляется через softmax от скалярного произведения Q и K. Если логиты (значения до softmax) становятся большими (например, >100), softmax даёт экстремальные вероятности (почти 0 или 1), градиенты затухают или взрываются.

Причины

  • Большие веса в Q/K проекциях.
  • Длинные последовательности (скалярное произведение растёт с длиной).
  • Отсутствие scale factor (деление на sqrt(d_k)).

Диагностика логировать logit statistics (max, min, mean) перед softmax. Если max > 50–100 — риск overflow.

Решение

  • Использовать Flash Attention (автоматически масштабирует).
  • Добавить logit clipping (ограничение max значения).
  • Уменьшить initialization range для Q/K весов.
  • Проверить positional encoding (не даёт ли она большие значения).

8. Инструменты мониторинга

ИнструментОсобенности
Weights & Biases (W&B)Интерактивные дашборды, логирование градиентов, активаций, автоматические алерты при spikes
TensorBoardВстроен в PyTorch, лёгкий, но менее гибкий
MLflowХорош для экспериментов, но хуже для real-time мониторинга
Gradient Pulse (от W&B)Специализированный инструмент для анализа градиентов

Рекомендация в production используйте W&B с алертами на loss > threshold или gradient norm > max_norm.


9. Пет-проект для закрепления

Задача Обучить небольшую LLM (GPT-2, 124M) на синтетических данных (например, случайные последовательности токенов) с намеренно внесённой нестабильностью. Применить техники стабилизации и сравнить loss curves.

Инструменты PyTorch, Hugging Face Transformers, W&B (опционально), Jupyter Notebook.

Шаги:

  1. Baseline (нестабильный): learning rate = 1e-3, без gradient clipping, без warmup, batch size = 4. Обучите 500 шагов. Зафиксируйте spikes и divergence.
  2. Добавить gradient clipping max_norm=1.0. Запустите снова. Сравните loss и gradient norm.
  3. Добавить warmup linear warmup 200 steps. Запустите. Убедитесь, что divergence исчезла.
  4. Использовать mixed precision включите torch.cuda.amp. Проверьте, не появились ли spikes из-за loss scaling.
  5. Оптимальная конфигурация LR=1e-4, clipping=1.0, warmup=200, cosine decay. Обучите 2000 шагов. Постройте график loss.

Ожидаемый результат Вы увидите, как каждая техника влияет на стабильность. Baseline даст spikes и NaN, после добавления clipping и warmup loss станет гладким. Mixed precision может потребовать настройки loss scale.


10. Связь с другими вопросами

ВопросТема
484Как оценивать качество fine-tuning?
486Как выбирать learning rate?
487Как обрабатывать выбросы в данных?
488Как использовать mixed precision?
489Как дебажить переобучение?
490Как работает gradient checkpointing?

11. Навигация


Навигация