中文翻译暂不可用,显示俄语原文。
Что такое KL divergence и где она применяется в LLM (RLHF, distillation)?
Краткий тезис
KL divergence (control дивергенция Кульбака-Лейблера) — это асимметричная мера «расстояния» между двумя вероятностными распределениями. В контексте LLM она используется как штраф за отклонение от исходной модели (reference policy) в RLHF, как целевая функция при дистилляции знаний (минимизация KL(student || teacher)), а также как неявная регуляризация в DPO. Понимание penalty|KL divergence необходимо для анализа стабильности обучения и баланса между изучением нового и сохранением старого поведения.
1. Определение и формула KL divergence
KL divergence (также относительная энтропия) для дискретных распределений P и Q определяется как:
KL(P || Q) = Σ_x P(x) * log( P(x) / Q(x) )
Для непрерывных распределений сумма заменяется интегралом.
Ключевые свойства
- Неотрицательность: KL(P || Q) ≥ 0, равенство только если P = Q почти всюду.
- Асимметричность: KL(P || Q) ≠ KL(Q || P) в общем случае.
- Не является метрикой: не выполняется неравенство треугольника и симметрия.
Интуиция: KL(P || Q) измеряет, сколько информации теряется, когда мы используем Q для аппроксимации P. Чем больше расхождение, тем хуже Q приближает P.
Связь с кросс-энтропией:
KL(P || Q) = H(P, Q) - H(P)
где H(P, Q) — кросс-энтропия, H(P) — энтропия P. Минимизация KL эквивалентна минимизации кросс-энтропии, если P фиксировано (как в supervised learning).
2. KL divergence в RLHF (Reinforcement Learning from Human Feedback)
RLHF — метод тонкой настройки LLM с использованием сигнала от человека. Типичная pipeline:
- SFT на демонстрациях.
- Обучение reward model на человеческих предпочтениях.
- Оптимизация policy (LLM) с помощью PPO с KL penalty.
Зачем нужна KL divergence в RLHF?
- Без штрафа policy может быстро отклониться от исходной модели (reference policy), порождая бессмысленные или небезопасные тексты.
- KL penalty ограничивает отклонение, сохраняя «знания» из предобучения и SFT.
Формула KL penalty в PPO для RLHF:
reward_total = reward_model(x, y) - β * KL(π_ref || π_θ)
где:
π_θ— текущая policy (обучаемая модель),- π_ref — reference policy (замороженная SFT-модель),
β— гиперпараметр, контролирующий силу штрафа.
Вычисление KL на практике:
- Для каждого токена вычисляется log π_θ(token | context) - log π_ref(token | context).
- Усредняется по последовательности.
Эффект β:
| β | Поведение |
|---|---|
| Очень мал (< 0.01) | Policy почти не штрафуется, может сильно отклониться, риск «overoptimization» |
| Умеренный (0.01–0.1) | Хороший баланс между следованием наградам и сохранением исходного поведения |
| Велик (> 0.1) | Policy почти не меняется, обучение неэффективно |
3. KL divergence в дистилляции (Knowledge Distillation)
Дистилляция — перенос знаний от большой модели (teacher) к маленькой (student). Цель — минимизировать расхождение между распределениями выходов.
Формула дистилляции:
L_distill = α * KL(softmax(teacher_logits / T) || softmax(student_logits / T))
где:
T— температура (temperature), сглаживает распределения,α— вес дистилляционной потери.
Почему KL(student || teacher), а не наоборот?
- Мы хотим, чтобы student аппроксимировал teacher. KL(student || teacher) штрафует, когда student присваивает высокую вероятность там, где teacher — низкую (ложные уверенности). Обратная KL (teacher || student) штрафовала бы за пропуск мод teacher, что менее критично для генерации.
Пример кода (PyTorch):
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, temperature=2.0, alpha=0.5):
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
kl = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
return alpha * kl
4. KL divergence в DPO (Direct Preference Optimization)
DPO — альтернатива RLHF, не требующая отдельной reward model. DPO использует предпочтения напрямую для обновления policy.
Неявная KL регуляризация в DPO:
- DPO выводится из задачи максимизации ожидаемой награды с KL-штрафом относительно reference policy.
- Целевая функция DPO содержит член, эквивалентный KL(π_ref || π_θ), но встроенный в логистическую потерю.
Формально:
L_DPO = -E[ log σ( β * (log π_θ(y_w | x) - log π_ref(y_w | x) - log π_θ(y_l | x) + log π_ref(y_l | x)) ) ]
где y_w — предпочитаемый ответ, y_l — непредпочитаемый. Параметр β контролирует силу неявного KL-штрафа.
Сравнение RLHF (PPO) и DPO:
| Аспект | RLHF (PPO) | DPO |
|---|---|---|
| KL регуляризация | Явный штраф в reward | Неявный через форму потери |
| Сложность | Reward model + PPO | Только policy, без reward model |
| Стабильность | Чувствителен к β | Менее чувствителен, но β всё равно важен |
| Вычислительные затраты | Высокие (4 модели) | Низкие (2 модели: policy + ref) |
5. KL divergence в других контекстах LLM
- Variational Autoencoders (VAE): KL между апостериорным и априорным распределениями в латентном пространстве.
- Bayesian deep learning: KL для вариационного вывода.
- Evaluation of generative models: KL между распределением реальных данных и сгенерированных (FID использует KL между признаками).
- Prompt tuning: иногда KL штрафует отклонение soft prompt от исходного.
6. Практические аспекты вычисления KL
Численная стабильность:
- Используйте
log_softmaxиkl_divиз PyTorch/TensorFlow. - Избегайте прямого вычисления
log(P/Q)при нулевых вероятностях — добавляйте epsilon.
Выбор направления KL:
- В RLHF обычно
KL(π_ref || π_θ)(forward KL) — штрафует, когда π_θ присваивает высокую вероятность там, где π_ref низкую. - В дистилляции —
KL(student || teacher)(reverse KL) — штрафует за ложные уверенности student.
Мониторинг KL во время обучения:
- Следите за средним KL на батч. Если KL резко растёт — возможно, обучение нестабильно.
- В RLHF типичные значения KL после нескольких шагов PPO — 0.1–1.0 nats (натуральные единицы).
7. Связь с другими метриками
| Метрика | Формула | Связь с KL |
|---|---|---|
| Cross-Entropy | H(P, Q) = -Σ P log Q | KL = H(P, Q) - H(P) |
| JS divergence | 0.5*KL(P | |
| Total Variation | 0.5 * Σ | P - Q |
8. Ограничения KL divergence
- Асимметричность: выбор направления критичен.
- Чувствительность к хвостам: если Q имеет нулевую вероятность там, где P положительна, KL уходит в бесконечность.
- Не является метрикой: не удовлетворяет неравенству треугольника.
- Интерпретация: значения KL не имеют абсолютной шкалы, только относительное сравнение.
Пет-проект для закрепления
Задача: Реализовать KL-штраф для fine-tuning небольшой LLM (например, GPT-2) на задаче генерации с предпочтениями (имитация RLHF).
Инструменты: Python, PyTorch, Hugging Face Transformers, TRL (Transformer Reinforcement Learning) или ручная реализация.
Шаги:
- Загрузите предобученную GPT-2 (reference model) и создайте копию (policy model).
- Сгенерируйте датасет предпочтений: для каждого промпта два ответа (хороший/плохой) вручную или через rule-based.
- Реализуйте PPO с KL penalty:
- Обучите policy несколько шагов, отслеживая средний KL.
- Сравните с baseline (без KL penalty) — проверьте, что модель не «сходит с ума».
Ожидаемый результат: Вы увидите, что без KL penalty модель быстро начинает генерировать бессвязный текст, а с KL penalty сохраняет осмысленность, постепенно адаптируясь к предпочтениям.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 650 | RLHF: полный цикл обучения с human feedback |
| 651 | PPO: алгоритм оптимизации policy |
| 652 | DPO: direct preference optimization |
| 653 | Knowledge distillation: методы сжатия моделей |
| 654 | Reward hacking и overoptimization в RLHF |
| 655 | Temperature sampling и его влияние на распределение |
Навигация
- Предыдущий: 656
- Следующий: 658
- Индекс: 00. Индекс разборов