English translation is not available yet. Showing Russian content.

Что такое KL divergence и где она применяется в LLM (RLHF, distillation)?

Краткий тезис

KL divergence (control дивергенция Кульбака-Лейблера) — это асимметричная мера «расстояния» между двумя вероятностными распределениями. В контексте LLM она используется как штраф за отклонение от исходной модели (reference policy) в RLHF, как целевая функция при дистилляции знаний (минимизация KL(student || teacher)), а также как неявная регуляризация в DPO. Понимание penalty|KL divergence необходимо для анализа стабильности обучения и баланса между изучением нового и сохранением старого поведения.


1. Определение и формула KL divergence

KL divergence (также относительная энтропия) для дискретных распределений P и Q определяется как:

KL(P || Q) = Σ_x P(x) * log( P(x) / Q(x) )

Для непрерывных распределений сумма заменяется интегралом.

Ключевые свойства

  • Неотрицательность: KL(P || Q) ≥ 0, равенство только если P = Q почти всюду.
  • Асимметричность: KL(P || Q) ≠ KL(Q || P) в общем случае.
  • Не является метрикой: не выполняется неравенство треугольника и симметрия.

Интуиция: KL(P || Q) измеряет, сколько информации теряется, когда мы используем Q для аппроксимации P. Чем больше расхождение, тем хуже Q приближает P.

Связь с кросс-энтропией:

KL(P || Q) = H(P, Q) - H(P)

где H(P, Q) — кросс-энтропия, H(P) — энтропия P. Минимизация KL эквивалентна минимизации кросс-энтропии, если P фиксировано (как в supervised learning).


2. KL divergence в RLHF (Reinforcement Learning from Human Feedback)

RLHF — метод тонкой настройки LLM с использованием сигнала от человека. Типичная pipeline:

  1. SFT на демонстрациях.
  2. Обучение reward model на человеческих предпочтениях.
  3. Оптимизация policy (LLM) с помощью PPO с KL penalty.

Зачем нужна KL divergence в RLHF?

  • Без штрафа policy может быстро отклониться от исходной модели (reference policy), порождая бессмысленные или небезопасные тексты.
  • KL penalty ограничивает отклонение, сохраняя «знания» из предобучения и SFT.

Формула KL penalty в PPO для RLHF:

reward_total = reward_model(x, y) - β * KL(π_ref || π_θ)

где:

  • π_θ — текущая policy (обучаемая модель),
  • π_ref — reference policy (замороженная SFT-модель),
  • β — гиперпараметр, контролирующий силу штрафа.

Вычисление KL на практике:

  • Для каждого токена вычисляется log π_θ(token | context) - log π_ref(token | context).
  • Усредняется по последовательности.

Эффект β:

βПоведение
Очень мал (< 0.01)Policy почти не штрафуется, может сильно отклониться, риск «overoptimization»
Умеренный (0.01–0.1)Хороший баланс между следованием наградам и сохранением исходного поведения
Велик (> 0.1)Policy почти не меняется, обучение неэффективно

3. KL divergence в дистилляции (Knowledge Distillation)

Дистилляция — перенос знаний от большой модели (teacher) к маленькой (student). Цель — минимизировать расхождение между распределениями выходов.

Формула дистилляции:

L_distill = α * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

где:

  • T — температура (temperature), сглаживает распределения,
  • α — вес дистилляционной потери.

Почему KL(student || teacher), а не наоборот?

  • Мы хотим, чтобы student аппроксимировал teacher. KL(student || teacher) штрафует, когда student присваивает высокую вероятность там, где teacher — низкую (ложные уверенности). Обратная KL (teacher || student) штрафовала бы за пропуск мод teacher, что менее критично для генерации.

Пример кода (PyTorch):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0, alpha=0.5):
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
    return alpha * kl

4. KL divergence в DPO (Direct Preference Optimization)

DPO — альтернатива RLHF, не требующая отдельной reward model. DPO использует предпочтения напрямую для обновления policy.

Неявная KL регуляризация в DPO:

  • DPO выводится из задачи максимизации ожидаемой награды с KL-штрафом относительно reference policy.
  • Целевая функция DPO содержит член, эквивалентный KL(π_ref || π_θ), но встроенный в логистическую потерю.

Формально:

L_DPO = -E[ log σ( β * (log π_θ(y_w | x) - log π_ref(y_w | x) - log π_θ(y_l | x) + log π_ref(y_l | x)) ) ]

где y_w — предпочитаемый ответ, y_l — непредпочитаемый. Параметр β контролирует силу неявного KL-штрафа.

Сравнение RLHF (PPO) и DPO:

АспектRLHF (PPO)DPO
KL регуляризацияЯвный штраф в rewardНеявный через форму потери
СложностьReward model + PPOТолько policy, без reward model
СтабильностьЧувствителен к βМенее чувствителен, но β всё равно важен
Вычислительные затратыВысокие (4 модели)Низкие (2 модели: policy + ref)

5. KL divergence в других контекстах LLM

  • Variational Autoencoders (VAE): KL между апостериорным и априорным распределениями в латентном пространстве.
  • Bayesian deep learning: KL для вариационного вывода.
  • Evaluation of generative models: KL между распределением реальных данных и сгенерированных (FID использует KL между признаками).
  • Prompt tuning: иногда KL штрафует отклонение soft prompt от исходного.

6. Практические аспекты вычисления KL

Численная стабильность:

  • Используйте log_softmax и kl_div из PyTorch/TensorFlow.
  • Избегайте прямого вычисления log(P/Q) при нулевых вероятностях — добавляйте epsilon.

Выбор направления KL:

  • В RLHF обычно KL(π_ref || π_θ) (forward KL) — штрафует, когда π_θ присваивает высокую вероятность там, где π_ref низкую.
  • В дистилляции — KL(student || teacher) (reverse KL) — штрафует за ложные уверенности student.

Мониторинг KL во время обучения:

  • Следите за средним KL на батч. Если KL резко растёт — возможно, обучение нестабильно.
  • В RLHF типичные значения KL после нескольких шагов PPO — 0.1–1.0 nats (натуральные единицы).

7. Связь с другими метриками

МетрикаФормулаСвязь с KL
Cross-EntropyH(P, Q) = -Σ P log QKL = H(P, Q) - H(P)
JS divergence0.5*KL(P
Total Variation0.5 * ΣP - Q

8. Ограничения KL divergence

  • Асимметричность: выбор направления критичен.
  • Чувствительность к хвостам: если Q имеет нулевую вероятность там, где P положительна, KL уходит в бесконечность.
  • Не является метрикой: не удовлетворяет неравенству треугольника.
  • Интерпретация: значения KL не имеют абсолютной шкалы, только относительное сравнение.

Пет-проект для закрепления

Задача: Реализовать KL-штраф для fine-tuning небольшой LLM (например, GPT-2) на задаче генерации с предпочтениями (имитация RLHF).

Инструменты: Python, PyTorch, Hugging Face Transformers, TRL (Transformer Reinforcement Learning) или ручная реализация.

Шаги:

  1. Загрузите предобученную GPT-2 (reference model) и создайте копию (policy model).
  2. Сгенерируйте датасет предпочтений: для каждого промпта два ответа (хороший/плохой) вручную или через rule-based.
  3. Реализуйте PPO с KL penalty:
    • Для каждого батча вычислите logits policy и reference.
    • Рассчитайте KL на токен: log_policy - log_ref.
    • Добавьте -β * KL к reward.
  4. Обучите policy несколько шагов, отслеживая средний KL.
  5. Сравните с baseline (без KL penalty) — проверьте, что модель не «сходит с ума».

Ожидаемый результат: Вы увидите, что без KL penalty модель быстро начинает генерировать бессвязный текст, а с KL penalty сохраняет осмысленность, постепенно адаптируясь к предпочтениям.


Связь с другими вопросами

ВопросТема
650RLHF: полный цикл обучения с human feedback
651PPO: алгоритм оптимизации policy
652DPO: direct preference optimization
653Knowledge distillation: методы сжатия моделей
654Reward hacking и overoptimization в RLHF
655Temperature sampling и его влияние на распределение

Навигация