English translation is not available yet. Showing Russian content.

Реализовать process reward model (PRM)

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Реализовать process reward model (PRM)

1. Цель задачи

Разработать и обучить process reward model (PRM) — модель, которая присваивает численную оценку каждому шагу логического рассуждения (reasoning step), а не только финальному ответу. Это позволяет использовать пошаговое обучение (step-level training) в RL/RLHF и улучшает качество multi-step reasoning. Ключевой результат Работающая PRM, способная различать правильные и неправильные промежуточные шаги с точностью >80% на тестовом наборе reasoning-траекторий.

2. Исходные данные

Что нужноОткуда взять
Датасет с пошаговыми решениямиGSM8K (train + test) или MATH (выбрать подмножество, ~200–500 задач)
Разметка правильности каждого шагаСгенерировать автоматически с помощью LLM (GPT-4 или DeepSeek)
Базовый язык LLM для синтеза данныхOpenAI API / DeepSeek API / HuggingFace (например, Mistral-7B-Instruct)
Вычислительный ресурс (GPU)Локальный (24GB+) или облачный (Colab Pro, RunPod)
Фреймворк для обученияPyTorch + HuggingFace Transformers + TRL (для PPO/KTO) или кастомный loss

Если нет реального инструмента — симулируем:

  1. Выберите 50 задач из GSM8K (например, 3–5-шаговые).
  2. Для каждой задачи сгенерируйте 5–10 рассуждений (trajectories) с помощью LLM (через API). Пусть модель пишет ответ с <step>...</step>.
  3. Для каждого шага автоматически пометьте correct/incorrect по правилу: шаг считается правильным, если цепочка после него ведёт к верному финальному ответу (т.е. все последующие шаги верны и ответ совпадает с золотым). Шаги, после которых ответ неверен, помечаются как неправильные (даже если шаг выглядит разумно). Это даст зашумлённые, но полезные метки.
  4. Соберите итоговый датасет в формате JSONL: каждая строка — {"question": "...", "steps": [{"text": "...", "label": 0/1}, ...], "final_answer": "..."}.

3. Технологический стек

КомпонентИнструментыНазначение
Язык программированияPython 3.10+Реализация модели и скриптов
Глубокое обучениеPyTorch 2.xОбучение PRM
HuggingFace Transformerstransformers, tokenizersЗагрузка базовой модели (например, DeepSeek-Math-Base-7B)
Библиотека для RLHFTRL (Transformer Reinforcement Learning)Использование готовых loss-функций (PPO, KTO) или написание своей
Векторное представление шаговPyTorch + mean poolingАгрегация скрытых состояний шагов в одно число
ЛогированиеWeights & Biases / MLflowОтслеживание метрик обучения
Генерация данныхOpenAI API / DeepSeek APIСинтез пошаговых рассуждений
Тестированиеpytest, evaluateПроверка корректности кода и метрик

4. Этапы выполнения

Этап 1: Подготовка данных с пошаговой разметкой (2 часа)

Действия

  1. Загрузить GSM8K с HuggingFace (datasets.load_dataset("gsm8k", "main")).
  2. Написать промпт для LLM-генератора, который получает вопрос и золотой ответ и выдаёт пошаговое решение в формате:
    <step>Первый шаг.</step>
    <step>Второй шаг.</step>
    ...
    <answer>Числовой ответ</answer>
    
  3. Выбрать 200 задач (100 из train, 100 из test). Для каждой задачи сгенерировать 3 траектории (с temperature=0.7). Сохранить в CSV с колонками: question, trajectory, final_llm_answer, gold_answer.
  4. Написать скрипт разметки:
    • Разбить траекторию на шаги по тегу <step>.
    • Если final_llm_answer совпадает с gold_answer — все шаги мечаются как correct (1).
    • Иначе: идти с конца, находить первый шаг, после которого ответ стал неверным (сравнивая результат подцепочки с золотым). Все шаги до этого шага — correct, этот шаг и следующие — incorrect (0).
    • Альтернатива: использовать другой LLM как judge (например, GPT-4-mini) для оценки каждого шага. Это точнее, но дороже.
  5. Сформировать датасет PRM в виде словаря: {"input_ids": tokens_question, "step_spans": [(start,end,label), ...]}.

Ожидаемый результат этапа Датасет из ~200–600 траекторий с пошаговыми метками (в формате HuggingFace Dataset или JSONL). Пример записи:

{
  "question": "Если 3 яблока стоят 5 долларов, сколько стоят 12 яблок?",
  "steps": [
    {"text": "Найдем цену одного яблока: 5 / 3 ≈ 1.667", "label": 1},
    {"text": "Теперь умножим на 12: 1.667 * 12 = 20", "label": 1}
  ],
  "gold_answer": "20"
}

Этап 2: Архитектура process reward model (2 часа)

Действия

  1. Выбрать базовую модель: deepseek-ai/deepseek-math-7b-base (хороша для math reasoning) или meta-llama/Llama-3.2-1B (для быстрого прототипа). Загрузить через HuggingFace.
  2. Реализовать класс ProcessRewardModel(nn.Module):
    • Загрузить base LM без головы (only embeddings + transformer layers).
    • На последнем слое повесить линейный projection hidden_size → 1 (scalar reward per token).
    • Написать forward, который принимает input_ids, attention_mask, step_token_mask (0 не-шаги, 1 начало или конец шага).
    • Получить скрытые состояния (outputs.hidden_states[-1]).
    • Собрать скрытые состояния для позиций, соответствующих концу каждого шага (или mean pool по токенам шага). Получить вектор для каждого шага.
    • Применить projection → scalar.
    • Обучить с бинарной кросс-энтропией между предсказанным reward (logit) и меткой шага (0/1).
  3. Создать даталоадер, который батчит по задачам, паддит до максимальной длины, строит маску шагов.

Ожидаемый результат этапа Класс модели с forward, даталоадер, конфигурация обучения (lr=1e-5, batch=4, grad_accum=2).

Этап 3: Обучение PRM (3–4 часа на GPU, ~30 минут на маленькой модели)

Действия

  1. Написать скрипт train_prm.py:
    • Загрузить данные, токенизировать.
    • Инициализировать модель, оптимизатор AdamW, scheduler cosine.
    • Цикл обучения: для каждого батча вычислить loss = BCE(step_rewards, step_labels), backward, optimizer.step.
    • Логировать loss на каждом 10-м шаге, accuracy на уровне шагов.
    • Сохранять чекпоинты каждые 500 шагов.
  2. Обучить на train-датасете 5 эпох (early stopping по val loss).
  3. Во время обучения оценивать на валидационном датасете (10% от train) метрики:
    • Step accuracy доля шагов, где предсказанный reward > 0.5 совпадает с меткой.
    • Trajectory accuracy для каждой траектории сравнить средний reward с финальным ответом (правильный vs неправильный) — PRM должен давать более высокий средний reward для правильных траекторий.
  4. Сохранить лучшую модель в ./prm-best/.

Ожидаемый результат этапа Обученная PRM, логи в W&B, чекпоинты.

Этап 4: Оценка и анализ (1 час)

Действия

  1. Загрузить лучшую модель.
  2. На тестовом датасете (100 задач) вычислить:
    • Step accuracy >80% .
    • PRM score correlation со значением верности финального ответа (корреляция Пирсона между средним reward шагов и binary correct).
    • Win rate в сравнении с outcome reward model (ORM): для каждой задачи сгенерировать 5 траекторий, ранжировать по PRM и ORM. Проверить, какой ранжировщик чаще ставит правильную траекторию на первое место (лучше PRM, т.к. более гранулярный).
  3. Визуализировать несколько примеров: показать вопрос, шаги, предсказанные reward и истинные метки.

Ожидаемый результат этапа Отчёт с метриками, графиками, качественными примерами. Файл evaluation_report.md.

5. Критерии приемки (Definition of Done)

  • Датасет с пошаговой разметкой собран, содержит не менее 200 траекторий.
  • Код реализации PRM (модель, обучение, инференс) лежит в Git-репозитории.
  • Модель обучена, чекпоинт сохранён в открытом формате (HuggingFace или PyTorch).
  • На тестовом датасете step accuracy > 80%.
  • Средний reward PRM для правильных траекторий хотя бы на 0.3 выше, чем для неправильных (по шкале от 0 до 1).
  • При сравнении с ORM (baseline) PRM показывает лучший win rate (>55%).
  • Есть скрипт inference.py, который по вопросу выводит список шагов и их reward.
  • В репозитории есть README с инструкцией по воспроизведению.
  • Все зависимости зафиксированы в requirements.txt или pyproject.toml.

6. Ожидаемый результат

  • Основной артефакт Python-пакет с модулями data/, model/, train.py, evaluate.py.
  • Файл чекпоинта prm-best/pytorch_model.bin (или LoRA-адаптеры).
  • Отчёт evaluation_report.md с таблицами метрик (step accuracy, correlation, win rate) и тремя качественными примерами.
  • Опционально Сравнительный анализ с outcome reward model (если реализована ORM на том же датасете).

7. Возможные сложности и их решение

СложностьРешение
Зашумлённость автоматически размеченных шаговИспользовать LLM-асессор (GPT-4) для перепроверки; использовать hard-label только когда цепочка однозначна; добавить soft-label (вероятность).
Большие различия в длине шагов (некоторые шаги длинные, другие короткие)Mean pooling по всем токенам шага; добавить регуляризацию (dropout).
Модель не обучается (loss не падает)Уменьшить learning rate (1e-5 → 1e-6), увеличить batch size; проверить правильность маски шагов; попробовать добавить голову с нормализацией (LayerNorm).
Нехватка памяти GPU (12GB)Использовать PEFT (LoRA) для base модели; уменьшить max_length (512 токенов); использовать gradient checkpointing.
Плохая корреляция с финальным ответомУвеличить датасет, добавить hard-negative mining (траектории с одним неправильным шагом); использовать loss, штрафующий за высокий reward для неправильных траекторий (contrastive).

8. Бюджет времени (оценка)

ЭтапВремя (часы)
Этап 1: Подготовка данных2.0
Этап 2: Архитектура модели2.0
Этап 3: Обучение3.5
Этап 4: Оценка1.0
Итого8.5
Примечание для первого разаУвеличьте на 2–3 часа на отладку и повторные эксперименты

9. Связанные вопросы из базы знаний

ВопросТема
45Что такое reward model в контексте RLHF?
48Отличие outcome reward model от process reward model
52Как строится step-level supervision?
101Синтез пошаговых данных с помощью LLM
207Введение в PyTorch для обучения моделей
310Использование HuggingFace Transformers для Sequence Classification
405PPO для языковых моделей (статья)
508Оценка качества reward model (correlation, accuracy)
612Уменьшение памяти: gradient checkpointing и LoRA
789Пайплайн обучения пошаговой награды для математики (Math PRM)

10. Чек-лист самопроверки

  • Я корректно извлёк и разметил шаги из траекторий (формат step-масок корректен).
  • Модель forward возвращает scalar reward для каждого шага, а не для каждого токена.
  • Функция потерь считает BCE по шагам, а не по токенам.
  • При инференсе я могу подать question и получить список шагов с reward`ами.
  • Я сравнил предсказания PRM с наивной стратегией «все шаги — 1» и убедился, что метрики лучше.
  • Репозиторий содержит README с примером запуска и воспроизводимыми результатами.
  • Я зафиксировал используемые версии библиотек в requirements.txt.