Как обучается Reward Model (RM)? Вход: (prompt, answer_chosen, answer_rejected), выход: скаляр.
Краткий тезис
Reward Model (RM) — это нейросеть, которая учится предсказывать человеческие предпочтения между парами ответов. Она принимает на вход триплет (prompt, answer_chosen, answer_rejected) и выдаёт скалярный reward, отражающий, насколько ответ соответствует предпочтениям человека. Обучение строится на pairwise сравнениях с использованием Bradley-Terry loss, что позволяет модели ранжировать ответы без абсолютных меток. Архитектурно RM — это языковая модель с линейной головой регрессии, дообученная на данных предпочтений. Ключевой вызов — калибровка: модель должна не только правильно ранжировать, но и давать согласованные оценки в разных контекстах.
2. Потеря: Bradley-Terry (σ(r_chosen - r_rejected))
Обучение RM формулируется как задача бинарной классификации: для пары (answer_chosen, answer_rejected) модель должна предсказать, что chosen предпочтительнее. Используется вероятностная модель Bradley-Terry:
[ P([text](/wiki/text){chosen} \succ [text](/wiki/text){rejected}) = \sigma(r_{[text](/wiki/text){chosen}} - r_{[text](/wiki/text){rejected}}) ]
где (\sigma) — сигмоида, (r_{[text](/wiki/text){chosen}}) и (r_{[text](/wiki/text){rejected}}) — скалярные выходы RM для двух ответов.
Функция потерь — отрицательное логарифмическое правдоподобие (binary cross-entropy):
[ \mathcal{L} = -\log \sigma(r_{[text](/wiki/text){chosen}} - r_{[text](/wiki/text){rejected}}) ]
Интуиция:
- Если модель правильно оценивает, что
chosenлучше, разность положительна, сигмоида близка к 1, лосс мал. - Если модель ошибается (разность отрицательна), лосс большой, градиент толкает веса так, чтобы увеличить разрыв.
Почему Bradley-Terry?
- Не требует абсолютных меток (например, "ответ 4.5 балла"), только относительные предпочтения.
- Легко дифференцируема, совместима с SGD.
- Может быть обобщена на множественное ранжирование (Plackett-Luce), но на практике pairwise достаточно.
Пример кода (PyTorch):
import torch.nn.functional as F
def bradley_terry_loss(r_chosen, r_rejected):
# r_chosen, r_rejected: (batch_size,)
logits = r_chosen - r_rejected # (batch_size,)
loss = -F.logsigmoid(logits).mean()
return loss
3. Можно обучать на pairwise или ranking
Существует два основных режима сбора данных и обучения RM:
3.1 Pairwise (бинарные сравнения)
- Данные: для каждого промпта собирается ровно одна пара
(chosen, rejected). - Преимущества: простота сбора (человеку легче выбрать лучший из двух ответов, чем оценить каждый по шкале).
- Недостатки: теряется информация о том, насколько один ответ лучше другого (только порядок).
- Обучение: каждая пара даёт один лосс, батч состоит из множества независимых пар.
3.2 Ranking (множественное ранжирование)
- Данные: для одного промпта собирается список из (K) ответов, отсортированных по предпочтению (например, рейтинг 1..K).
- Преимущества: более информативно — модель учится различать не только соседние ранги, но и дистанции.
- Потеря: используется обобщение Bradley-Terry — Plackett-Luce или ListMLE.
- На практике: часто преобразуют ranking в несколько pairwise пар (каждый ответ сравнивается с каждым нижестоящим), что даёт (O(K^2)) пар. Это увеличивает количество обучающих примеров, но может привести к дисбалансу.
Рекомендация: для большинства сценариев достаточно pairwise, так как он проще и менее требователен к данным. Ranking полезен, когда нужно тонкое ранжирование (например, в соревновательных задачах).
4. Важно: калибровка
Калибровка Reward Model — критический аспект, влияющий на стабильность PPO и качество финальной политики.
Проблемы:
- Дрейф шкалы: RM может со временем выдавать всё большие или меньшие значения, что нарушает обучение RL (reward hacking).
- Несогласованность: для разных промптов модель может давать несопоставимые оценки (например, для сложного вопроса reward=10, для простого — 5, хотя оба ответа идеальны).
- Переобучение на артефакты: RM может выучить поверхностные признаки (длина ответа, наличие списков), а не реальное качество.
Методы калибровки:
- Нормализация: вычитать среднее и делить на стандартное отклонение reward в батче или по всей истории.
- Clipping: ограничить reward диапазоном (например, [-1, 1]) для стабильности PPO.
- Регуляризация: добавить L2-штраф на веса головы или использовать dropout.
- Аугментация данных: добавлять примеры с «шумными» предпочтениями, чтобы модель не запоминала точные метки.
- Ensemble: обучать несколько RM и усреднять их выходы — снижает variance и улучшает калибровку (используется в DeepSeek-R1).
Практический совет: после обучения RM всегда проверяйте распределение reward на валидационном наборе — оно должно быть примерно нормальным с нулевым средним. Если среднее сильно смещено, применяйте пост-обработку.
5. Пет-проект для закрепления
Задача: Обучить Reward Model на синтетических данных предпочтений для задачи «полезный и безвредный ассистент».
Инструменты:
- Python, PyTorch / Hugging Face Transformers
- Датасет: Anthropic/hh-rlhf (реальные человеческие предпочтения) или синтезировать с помощью LLM (например, GPT-4 генерирует
chosenиrejected). - Базовая модель: DistilGPT2 (лёгкая, для быстрого прототипирования).
Шаги:
- Загрузите предобученный DistilGPT2 и добавьте линейную голову (один нейрон).
- Подготовьте датасет: для каждого промпта возьмите пару
(chosen, rejected), токенизируйтеprompt + answerс паддингом до одинаковой длины. - Реализуйте Bradley-Terry loss (см. код выше).
- Обучите модель на 1-2 эпохи с learning rate 1e-5, батч 8.
- Оцените accuracy на валидации: доля пар, где
r_chosen > r_rejected. - Визуализируйте распределение reward для chosen и rejected — они должны быть разделены.
Ожидаемый результат:
- Accuracy > 70% на валидации (для DistilGPT2 на hh-rlhf).
- Модель стабильно выдаёт более высокий reward для chosen ответов.
- Понимание, как калибровка влияет на разброс значений.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 329 | PPO использует reward от RM для обновления политики. |
| 983 | Общий контекст: RM — компонент пайплайна RLHF. |
| 985 | Процесс сбора пар для обучения RM. |
Навигация
- Предыдущий: 983
- Следующий: 985
- Индекс: 00. Индекс разборов