English translation is not available yet. Showing Russian content.
Как работает KL penalty в RLHF и как подобрать коэффициент?
Краткий тезис
penalty|KL penalty — это штраф за отклонение от reference модели, добавляемый к reward в RLHF. Он стабилизирует обучение, предотвращает переобучение reward model и сохраняет разнообразие генераций. Коэффициент β (обычно 0.01–0.1) подбирается через sweep на валидации по метрикам качества ответов и KL-расхождения.
1. Термины: RLHF, KL penalty, reference model, reward model
RLHF (Reinforcement Learning from Human Feedback) — метод дообучения LLM, при котором модель учится генерировать ответы, максимизирующие reward, полученный от обученной на человеческих предпочтениях reward model.
KL penalty — штраф за расхождение Кульбака-Лейблера (KL divergence) между текущей политикой (обучаемой моделью) и reference моделью (замороженной версией до RL-этапа). Добавляется к reward, чтобы ограничить отклонение.
Reference model — копия модели до начала RL-обучения (обычно после SFT). Её веса не обновляются, используется как якорь.
Reward model — модель, обученная предсказывать оценку (reward) для пары (prompt, response) на основе человеческих предпочтений. В RLHF она даёт сигнал для оптимизации.
2. Зачем нужен KL penalty в RLHF?
Без KL penalty модель может:
- Переобучиться под reward model: найти «дыры» в reward model и генерировать неестественные, но высокооцененные ответы (reward hacking).
- Потерять разнообразие: сойтись к узкому набору шаблонов, которые дают высокий reward.
- Забыть исходные знания: сильно отклониться от базового SFT-распределения, ухудшив качество на обычных запросах.
KL penalty решает эти проблемы, штрафуя за большие отклонения от reference модели. Это аналог регуляризации в supervised learning.
3. Математическая формулировка
Итоговая функция вознаграждения в RLHF:
reward_total = reward_model(prompt, response) - β * KL(P_θ || P_ref)
Где:
- reward_model(prompt, response) — оценка от reward model.
- β — коэффициент, контролирующий силу штрафа.
- KL(P_θ || P_ref) — KL divergence между распределением текущей модели (P_θ) и reference модели (P_ref) для данной генерации.
KL divergence вычисляется как:
KL(P_θ || P_ref) = Σ_t P_θ(token_t | context) * log(P_θ(token_t | context) / P_ref(token_t | context))
Суммируется по всем токенам ответа. Чем больше различаются вероятности токенов, тем выше штраф.
4. Как KL penalty работает на практике (в PPO)
В алгоритме PPO (Proximal Policy Optimization) KL penalty встраивается в loss-функцию. Стандартный подход — adaptive KL penalty, где β динамически подстраивается:
- Если KL слишком мал (< target_kl * 0.5) → β уменьшается (штраф слабее).
- Если KL слишком велик (> target_kl * 1.5) → β увеличивается (штраф сильнее).
Псевдокод шага PPO с KL penalty:
# Для каждого батча
for prompt, response in batch:
# Получаем логиты текущей и reference модели
logits_θ = model(prompt, response)
logits_ref = ref_model(prompt, response)
# Вычисляем KL divergence (по токенам)
kl = kl_divergence(logits_θ, logits_ref)
# Reward от reward model
r = reward_model(prompt, response)
# Итоговый reward с penalty
r_total = r - beta * kl
# PPO loss (clipped surrogate objective)
loss = -ppo_loss(logits_θ, r_total, old_logits)
# Обновление β (adaptive)
if kl < target_kl * 0.5:
beta *= 0.9
elif kl > target_kl * 1.5:
beta *= 1.1
5. Как подобрать коэффициент β?
β — гиперпараметр, который нельзя вычислить аналитически. Подбор проводится через sweep (перебор значений) на валидационном наборе.
Шаги подбора:
- Выбрать диапазон: обычно β ∈ [0.001, 1.0] с логарифмической шкалой (например, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0).
- Обучить несколько моделей с разными β на фиксированном числе шагов (например, 1000 шагов PPO).
- Оценить каждую модель на валидационном наборе по трём метрикам:
- Reward (средний score от reward model) — должен расти.
- KL divergence (среднее KL между θ и ref) — не должен превышать target_kl (обычно 0.01–0.1 на токен).
- Качество ответов (человеческая оценка или автоматические метрики: BLEU, ROUGE, GPT-4 eval).
- Выбрать β, дающий наилучший баланс: высокий reward при умеренном KL и приемлемом качестве.
Типичный диапазон β
- β = 0.01–0.1 — рабочий диапазон для большинства задач.
- β < 0.01 — слабый штраф, риск reward hacking.
- β > 0.1 — сильный штраф, модель почти не отклоняется от reference, обучение неэффективно.
6. Влияние β на поведение модели
| β | KL divergence | Reward | Качество ответов | Риски |
|---|---|---|---|---|
| Очень малый (<0.01) | Высокий | Высокий (возможно, за счёт хаков) | Падает (неестественные ответы) | Reward hacking, потеря разнообразия |
| Оптимальный (0.01–0.1) | Умеренный (0.01–0.05 на токен) | Умеренно высокий | Хорошее | Сбалансировано |
| Очень большой (>0.1) | Низкий | Низкий (штраф доминирует) | Близко к reference | Слабое обучение, нет улучшений |
7. Альтернативы и сравнение
DPO (Direct Preference Optimization) — не использует KL penalty явно, но в loss заложено ограничение на отклонение от reference. Коэффициент β там тоже есть, но подбирается аналогично.
GRPO (Group Relative Policy Optimization) — использует групповые baseline и KL penalty, но β часто фиксирован (например, 0.04).
| Метод | KL penalty | Подбор β |
|---|---|---|
| PPO (RLHF) | Явный, adaptive | Sweep + adaptive |
| DPO | Неявный (в loss) | Sweep |
| GRPO | Явный, фиксированный | Sweep, но реже |
8. Пет-проект для закрепления
Задача Реализовать мини-RLHF с KL penalty для небольшой модели (GPT-2) на синтетических предпочтениях.
Инструменты Python, PyTorch, Hugging Face Transformers, TRL (Transformer Reinforcement Learning).
Шаги:
- Взять предобученный GPT-2 (SFT).
- Создать синтетическую reward model: простая функция, оценивающая длину ответа и наличие ключевых слов.
- Реализовать PPO с KL penalty (можно использовать
trl.PPOTrainer). - Провести sweep β = [0.001, 0.01, 0.1, 1.0] на 200 шагах обучения.
- Для каждого β построить графики: reward, KL, качество ответов (например, по метрике «доля ответов с ключевыми словами»).
Ожидаемый результат Вы увидите, что при β=0.001 модель быстро переобучается (высокий reward, но ответы состоят только из ключевых слов), при β=1.0 reward почти не растёт, а при β=0.01–0.1 достигается баланс.
9. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 330 | Архитектура Agentic RAG |
| 331 | Tool use в агентах |
| 333 | PPO в RLHF |
| 334 | Обучение reward model |
| 335 | RLHF: полный пайплайн |
| 336 | DPO vs RLHF |
10. Навигация
- Предыдущий: 331
- Следующий: 333
- Индекс: 00. Индекс разборов
Навигация
- Предыдущий: 331
- Следующий: 333
- Индекс: 00. Индекс разборов