English translation is not available yet. Showing Russian content.

Как работает KL penalty в RLHF и как подобрать коэффициент?

Краткий тезис

penalty|KL penalty — это штраф за отклонение от reference модели, добавляемый к reward в RLHF. Он стабилизирует обучение, предотвращает переобучение reward model и сохраняет разнообразие генераций. Коэффициент β (обычно 0.01–0.1) подбирается через sweep на валидации по метрикам качества ответов и KL-расхождения.


1. Термины: RLHF, KL penalty, reference model, reward model

RLHF (Reinforcement Learning from Human Feedback) — метод дообучения LLM, при котором модель учится генерировать ответы, максимизирующие reward, полученный от обученной на человеческих предпочтениях reward model.

KL penalty — штраф за расхождение Кульбака-Лейблера (KL divergence) между текущей политикой (обучаемой моделью) и reference моделью (замороженной версией до RL-этапа). Добавляется к reward, чтобы ограничить отклонение.

Reference model — копия модели до начала RL-обучения (обычно после SFT). Её веса не обновляются, используется как якорь.

Reward model — модель, обученная предсказывать оценку (reward) для пары (prompt, response) на основе человеческих предпочтений. В RLHF она даёт сигнал для оптимизации.


2. Зачем нужен KL penalty в RLHF?

Без KL penalty модель может:

  • Переобучиться под reward model: найти «дыры» в reward model и генерировать неестественные, но высокооцененные ответы (reward hacking).
  • Потерять разнообразие: сойтись к узкому набору шаблонов, которые дают высокий reward.
  • Забыть исходные знания: сильно отклониться от базового SFT-распределения, ухудшив качество на обычных запросах.

KL penalty решает эти проблемы, штрафуя за большие отклонения от reference модели. Это аналог регуляризации в supervised learning.


3. Математическая формулировка

Итоговая функция вознаграждения в RLHF:

reward_total = reward_model(prompt, response) - β * KL(P_θ || P_ref)

Где:

  • reward_model(prompt, response) — оценка от reward model.
  • β — коэффициент, контролирующий силу штрафа.
  • KL(P_θ || P_ref)KL divergence между распределением текущей модели (P_θ) и reference модели (P_ref) для данной генерации.

KL divergence вычисляется как:

KL(P_θ || P_ref) = Σ_t P_θ(token_t | context) * log(P_θ(token_t | context) / P_ref(token_t | context))

Суммируется по всем токенам ответа. Чем больше различаются вероятности токенов, тем выше штраф.


4. Как KL penalty работает на практике (в PPO)

В алгоритме PPO (Proximal Policy Optimization) KL penalty встраивается в loss-функцию. Стандартный подход — adaptive KL penalty, где β динамически подстраивается:

  • Если KL слишком мал (< target_kl * 0.5) → β уменьшается (штраф слабее).
  • Если KL слишком велик (> target_kl * 1.5) → β увеличивается (штраф сильнее).

Псевдокод шага PPO с KL penalty:

# Для каждого батча
for prompt, response in batch:
    # Получаем логиты текущей и reference модели
    logits_θ = model(prompt, response)
    logits_ref = ref_model(prompt, response)
    
    # Вычисляем KL divergence (по токенам)
    kl = kl_divergence(logits_θ, logits_ref)
    
    # Reward от reward model
    r = reward_model(prompt, response)
    
    # Итоговый reward с penalty
    r_total = r - beta * kl
    
    # PPO loss (clipped surrogate objective)
    loss = -ppo_loss(logits_θ, r_total, old_logits)
    
    # Обновление β (adaptive)
    if kl < target_kl * 0.5:
        beta *= 0.9
    elif kl > target_kl * 1.5:
        beta *= 1.1

5. Как подобрать коэффициент β?

β — гиперпараметр, который нельзя вычислить аналитически. Подбор проводится через sweep (перебор значений) на валидационном наборе.

Шаги подбора:

  1. Выбрать диапазон: обычно β ∈ [0.001, 1.0] с логарифмической шкалой (например, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0).
  2. Обучить несколько моделей с разными β на фиксированном числе шагов (например, 1000 шагов PPO).
  3. Оценить каждую модель на валидационном наборе по трём метрикам:
    • Reward (средний score от reward model) — должен расти.
    • KL divergence (среднее KL между θ и ref) — не должен превышать target_kl (обычно 0.01–0.1 на токен).
    • Качество ответов (человеческая оценка или автоматические метрики: BLEU, ROUGE, GPT-4 eval).
  4. Выбрать β, дающий наилучший баланс: высокий reward при умеренном KL и приемлемом качестве.

Типичный диапазон β

  • β = 0.01–0.1 — рабочий диапазон для большинства задач.
  • β < 0.01 — слабый штраф, риск reward hacking.
  • β > 0.1 — сильный штраф, модель почти не отклоняется от reference, обучение неэффективно.

6. Влияние β на поведение модели

βKL divergenceRewardКачество ответовРиски
Очень малый (<0.01)ВысокийВысокий (возможно, за счёт хаков)Падает (неестественные ответы)Reward hacking, потеря разнообразия
Оптимальный (0.01–0.1)Умеренный (0.01–0.05 на токен)Умеренно высокийХорошееСбалансировано
Очень большой (>0.1)НизкийНизкий (штраф доминирует)Близко к referenceСлабое обучение, нет улучшений

7. Альтернативы и сравнение

DPO (Direct Preference Optimization) — не использует KL penalty явно, но в loss заложено ограничение на отклонение от reference. Коэффициент β там тоже есть, но подбирается аналогично.

GRPO (Group Relative Policy Optimization) — использует групповые baseline и KL penalty, но β часто фиксирован (например, 0.04).

МетодKL penaltyПодбор β
PPO (RLHF)Явный, adaptiveSweep + adaptive
DPOНеявный (в loss)Sweep
GRPOЯвный, фиксированныйSweep, но реже

8. Пет-проект для закрепления

Задача Реализовать мини-RLHF с KL penalty для небольшой модели (GPT-2) на синтетических предпочтениях.

Инструменты Python, PyTorch, Hugging Face Transformers, TRL (Transformer Reinforcement Learning).

Шаги:

  1. Взять предобученный GPT-2 (SFT).
  2. Создать синтетическую reward model: простая функция, оценивающая длину ответа и наличие ключевых слов.
  3. Реализовать PPO с KL penalty (можно использовать trl.PPOTrainer).
  4. Провести sweep β = [0.001, 0.01, 0.1, 1.0] на 200 шагах обучения.
  5. Для каждого β построить графики: reward, KL, качество ответов (например, по метрике «доля ответов с ключевыми словами»).

Ожидаемый результат Вы увидите, что при β=0.001 модель быстро переобучается (высокий reward, но ответы состоят только из ключевых слов), при β=1.0 reward почти не растёт, а при β=0.01–0.1 достигается баланс.


9. Связь с другими вопросами

ВопросТема
330Архитектура Agentic RAG
331Tool use в агентах
333PPO в RLHF
334Обучение reward model
335RLHF: полный пайплайн
336DPO vs RLHF

10. Навигация


Навигация