Как вы дебажите training instability (loss spikes, divergence)?
Краткий тезис
Training instability — это ситуация, когда loss ведёт себя непредсказуемо: резкие скачки (spikes) или уход в бесконечность (divergence). Причины лежат в гиперпараметрах (слишком высокий learning rate, отсутствие gradient clipping), данных (NaN, выбросы), precision (переполнение в fp16) или архитектуре (softmax overflow в LLM). Диагностика требует логирования градиентных норм, активаций и loss-кривых, а стабилизация достигается комбинацией warmup, gradient clipping, корректировки learning rate и использования precision precision precision training|mixed precision с loss scaling.
1. Термин: Training Instability (нестабильность обучения)
Training instability — это состояние, при котором loss-функция не сходится гладко, а демонстрирует аномальное поведение. Основные симптомы:
- Loss spikes — резкие кратковременные скачки loss (например, с 2.5 до 15.0 за один шаг).
- Divergence — loss неуклонно растёт или уходит в NaN/Inf, обучение «взрывается».
- Oscillations — loss колеблется с большой амплитудой, не снижаясь.
Нестабильность критична для LLM (особенно 70B+), так как один сбойный шаг может испортить всю контрольную точку (checkpoint). В контексте Agentic RAG instability может возникнуть при fine-tuning агентов (RL, PPO) или при обучении ретриверов.
2. Причины loss spikes
2.1 Слишком высокий learning rate
Если learning rate (скорость обучения) превышает оптимальное значение, градиенты делают слишком большие шаги, перескакивая минимум. Это приводит к резкому росту loss.
Диагностика построить график learning rate и loss — spike совпадает с моментом, когда LR ещё не снизился (например, после warmup).
2.2 Gradient explosion (взрыв градиентов)
Gradient explosion — когда норма градиента становится огромной (например, >1000). Это происходит из-за глубоких сетей (RNN, Transformer) или плохой инициализации. В LLM часто связано с attention softmax (см. раздел 7).
Диагностика логировать gradient norm (норму градиента) — если она резко возрастает перед spike, причина в explosion.
2.3 Проблемы с данными
- NaN в данных — пропуски, деление на ноль.
- Outliers — экстремальные значения в loss (например, токен с очень низкой вероятностью).
- Некорректная маска — attention mask, пропускающая pad-токены, может давать неправильные логиты.
3. Причины divergence (расходимости)
3.1 Отсутствие warmup
Warmup — постепенное увеличение learning rate с нуля до целевого значения в начале обучения. Без warmup модель сразу получает большой градиентный шаг, что может привести к divergence, особенно при использовании AdamW (адаптивный оптимизатор с коррекцией момента).
3.2 Плохая инициализация
Если веса инициализированы слишком большими значениями, активации насыщаются (например, tanh или sigmoid), градиенты затухают или взрываются. Для Transformer стандарт — инициализация из Xavier или Kaiming с учётом глубины.
3.3 Высокий batch size
При увеличении batch size (размер батча) оценка градиента становится точнее, но если learning rate не скорректирован (линейное масштабирование), шаги могут быть слишком большими. Это вызывает divergence.
3.4 Проблемы с precision (точностью)
При использовании mixed precision (fp16) числа с плавающей точкой могут переполняться (overflow) или терять точность (underflow). Если loss scaling не настроен, градиенты обнуляются или становятся Inf.
4. Диагностика: что логировать
Для эффективного дебага нужно логировать не только loss, но и внутренние состояния модели. Основные метрики:
| Метрика | Что показывает | Как логировать |
|---|---|---|
| Loss | Основной сигнал нестабильности | Каждый шаг или каждые N шагов |
| Gradient norm | Взрыв градиентов | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) до клиппинга |
| Gradient variance | Нестабильность направления | Среднее и std градиентов по слоям |
| Activation statistics | Насыщение нейронов | Среднее, std, min/max активаций (например, после attention) |
| Learning rate | Связь spikes с LR | Текущее значение LR |
| Loss scale (для mixed precision) | Переполнение fp16 | scaler.get_scale() из torch.cuda.amp |
| NaN/Inf count | Проблемы с данными или вычислениями | Проверка каждого тензора на NaN/Inf |
Инструменты Weights & Biases (W&B), TensorBoard, MLflow. В W&B можно построить панель с loss, gradient norm и LR на одном графике.
5. Техники стабилизации
5.1 Gradient clipping (клиппинг градиентов)
Ограничивает норму градиента сверху. Стандартное значение max_norm=1.0 для LLM.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Когда помогает при gradient explosion. Если spikes исчезают после клиппинга — причина в explosion.
5.2 Warmup steps
Постепенное увеличение LR от 0 до целевого за warmup_steps (обычно 500–2000 шагов). Реализация через linear schedule:
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps)
Когда помогает при divergence в начале обучения.
5.3 Learning rate schedule (расписание LR)
После warmup обычно используют cosine decay или linear decay. Cosine decay более плавный, снижает риск spikes на поздних этапах.
5.4 Уменьшение learning rate
Если spikes появляются после warmup, попробуйте уменьшить базовый LR в 2–10 раз. Для LLM типичные значения: 1e-5 – 5e-5 для fine-tuning, 1e-4 – 3e-4 для pre-training.
5.5 Weight decay (регуляризация)
Weight decay (L2-регуляризация) помогает предотвратить взрыв весов. Для AdamW weight_decay=0.01 – 0.1.
5.6 Gradient accumulation (накопление градиентов)
Позволяет увеличить эффективный batch size без роста памяти. Если spikes связаны с шумом градиента, накопление сглаживает обновления.
6. Precision issues: mixed precision и loss scaling
При обучении LLM (70B+) часто используют mixed precision (fp16 + fp32). Проблемы:
- Overflow — градиенты становятся Inf (слишком большие для fp16).
- Underflow — градиенты становятся 0 (слишком маленькие).
Решение loss scaling — умножение loss на коэффициент перед backward, чтобы градиенты не обнулялись. PyTorch AMP автоматически управляет scaling:
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
loss = model(inputs)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Если loss scale падает до 1 и градиенты всё ещё Inf — проблема не в precision, а в explosion. Если scale часто уменьшается — нужно увеличить init_scale или проверить данные.
Для моделей 70B+ также важно проверить масштабирование градиентов при использовании tensor parallelism (например, в Megatron-LM). Неправильное масштабирование между устройствами может вызвать spikes.
7. Специфика LLM: attention softmax overflow
В Transformer внимание вычисляется через softmax от скалярного произведения Q и K. Если логиты (значения до softmax) становятся большими (например, >100), softmax даёт экстремальные вероятности (почти 0 или 1), градиенты затухают или взрываются.
Причины
- Большие веса в Q/K проекциях.
- Длинные последовательности (скалярное произведение растёт с длиной).
- Отсутствие scale factor (деление на sqrt(d_k)).
Диагностика логировать logit statistics (max, min, mean) перед softmax. Если max > 50–100 — риск overflow.
Решение
- Использовать Flash Attention (автоматически масштабирует).
- Добавить logit clipping (ограничение max значения).
- Уменьшить initialization range для Q/K весов.
- Проверить positional encoding (не даёт ли она большие значения).
8. Инструменты мониторинга
| Инструмент | Особенности |
|---|---|
| Weights & Biases (W&B) | Интерактивные дашборды, логирование градиентов, активаций, автоматические алерты при spikes |
| TensorBoard | Встроен в PyTorch, лёгкий, но менее гибкий |
| MLflow | Хорош для экспериментов, но хуже для real-time мониторинга |
| Gradient Pulse (от W&B) | Специализированный инструмент для анализа градиентов |
Рекомендация в production используйте W&B с алертами на loss > threshold или gradient norm > max_norm.
9. Пет-проект для закрепления
Задача Обучить небольшую LLM (GPT-2, 124M) на синтетических данных (например, случайные последовательности токенов) с намеренно внесённой нестабильностью. Применить техники стабилизации и сравнить loss curves.
Инструменты PyTorch, Hugging Face Transformers, W&B (опционально), Jupyter Notebook.
Шаги:
- Baseline (нестабильный): learning rate = 1e-3, без gradient clipping, без warmup, batch size = 4. Обучите 500 шагов. Зафиксируйте spikes и divergence.
- Добавить gradient clipping
max_norm=1.0. Запустите снова. Сравните loss и gradient norm. - Добавить warmup linear warmup 200 steps. Запустите. Убедитесь, что divergence исчезла.
- Использовать mixed precision включите
torch.cuda.amp. Проверьте, не появились ли spikes из-за loss scaling. - Оптимальная конфигурация LR=1e-4, clipping=1.0, warmup=200, cosine decay. Обучите 2000 шагов. Постройте график loss.
Ожидаемый результат Вы увидите, как каждая техника влияет на стабильность. Baseline даст spikes и NaN, после добавления clipping и warmup loss станет гладким. Mixed precision может потребовать настройки loss scale.
10. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 484 | Как оценивать качество fine-tuning? |
| 486 | Как выбирать learning rate? |
| 487 | Как обрабатывать выбросы в данных? |
| 488 | Как использовать mixed precision? |
| 489 | Как дебажить переобучение? |
| 490 | Как работает gradient checkpointing? |
11. Навигация
- Предыдущий: 484
- Следующий: 486
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 484
- Следующий: 486
- Индекс: 00. Индекс разборов