Как вы проверяете, что RLHF не сломал базовые способности модели?

Краткий тезис

RLHF (Reinforcement Learning from Human Feedback) — мощный метод выравнивания модели под предпочтения человека, но он может приводить к катастрофическому забыванию (catastrophic forgetting) базовых знаний и навыков. Проверка заключается в систематическом сравнении метрик модели до и после RLHF на наборах бенчмарков, измеряющих общие знания (MMLU), здравый смысл (HellaSwag), правдивость (TruthfulQA) и другие способности. Ухудшение более чем на 5% (или на величину, превышающую статистическую погрешность) считается тревожным сигналом и требует корректировки процесса обучения.


1. Термин: RLHF и его влияние на базовые способности

RLHF — это процесс дообучения языковой модели с использованием сигнала вознаграждения, полученного от модели предпочтений (reward model), обученной на человеческих оценках. Цель — сделать ответы модели более полезными, безвредными и честными (helpful, harmless, honest).

Однако при оптимизации под узкую цель (например, вежливость или отказ от вредных запросов) модель может «забыть» фактические знания, логику или грамматику, которые были заложены на этапе pre-training и SFT (Supervised Fine-Tuning). Это явление называется катастрофическим забыванием (catastrophic forgetting).

Почему это происходит

  • Reward hacking — модель находит короткие пути к высокому вознаграждению, игнорируя содержание.
  • Сужение распределения — RLHF смещает распределение ответов в сторону узкого набора «безопасных» паттернов.
  • Переобучение на reward model — модель начинает подстраиваться под артефакты reward model, а не под истинные предпочтения.

2. Бенчмарки для оценки базовых способностей

Для проверки используются стандартизированные тесты, покрывающие разные аспекты «интеллекта» модели. Ниже — ключевые бенчмарки.

БенчмаркЧто измеряетПример задачиТипичный формат
MMLU (Massive Multitask Language Understanding)Знания по 57 предметам (от математики до права)«Какая планета самая большая?»Multiple-choice (4 варианта)
HellaSwagЗдравый смысл и понимание причинно-следственных связейВыбрать логичное продолжение историиMultiple-choice
TruthfulQAПравдивость и склонность к галлюцинациям«Правда ли, что Земля плоская?»Генерация + оценка LLM-судьёй
ARC (AI2 Reasoning Challenge)Научные рассуждения (easy / challenge)Вопросы из школьной программыMultiple-choice
GSM8KМатематические рассуждения (цепочка шагов)«У Маши 5 яблок, у Пети — 3. Сколько всего?»Генерация ответа
BBH (BIG-Bench Hard)Сложные рассуждения (логика, планирование)Задачи на индуктивное умозаключениеMultiple-choice / генерация

Дополнительно для оценки безопасности используют Anthropic’s HH-RLHF и AdvBench, но они не входят в «базовые способности».


3. Процесс проверки: до/после RLHF

Проверка должна быть встроена в пайплайн обучения. Основные шаги:

3.1. Заморозка контрольной точки (checkpoint)

Перед началом RLHF сохраняем модель после SFT (или pre-training). Это baseline.

3.2. Запуск evaluation на бенчмарках

Для каждого бенчмарка вычисляем метрику (accuracy, F1, exact match). Используем одинаковые гиперпараметры инференса (температура, top-p, max tokens).

3.3. Периодическое тестирование во время RLHF

Каждые N шагов (например, 100 шагов PPO) повторяем evaluation на тех же бенчмарках. Это позволяет отследить момент, когда начинается забывание.

3.4. Финальное сравнение

Сравниваем метрики после RLHF с baseline. Если ухудшение превышает порог (обычно 5% относительного падения), это сигнал проблемы.

Формула для относительного изменения

Δ = (metric_after - metric_before) / metric_before * 100%

Пример:

  • MMLU до RLHF: 72.3%
  • MMLU после RLHF: 68.1%
  • Δ = (68.1 - 72.3) / 72.3 ≈ -5.8% → тревожный сигнал

4. Пороги и интерпретация

Δ (относительное падение)ИнтерпретацияДействие
< 2%Шум / допустимоПродолжать обучение
2–5%Умеренное ухудшениеУвеличить KL-штраф, добавить данные из SFT
> 5%Критическое забываниеОстановить обучение, пересмотреть гиперпараметры

Важно порог может зависеть от бенчмарка. Для MMLU падение на 5% — серьёзно, для TruthfulQA (где модель изначально может быть плохой) — менее критично.


5. Стратегии предотвращения катастрофического забывания

5.1. KL-дивергенция как штраф

В алгоритме PPO (Proximal Policy Optimization) добавляют штраф за отклонение от исходной политики (модели до RLHF). Коэффициент KL-штрафа (β) контролирует баланс между выравниванием и сохранением знаний.

# Псевдокод для PPO с KL-штрафом
kl_penalty = beta * kl_divergence(ref_logits, policy_logits)
reward = human_reward - kl_penalty

5.2. Смешивание данных (data mixing)

В каждой батче PPO перемешивают данные из RLHF-датасета и из оригинального SFT-датасета (например, 80% RLHF + 20% SFT). Это напоминает модель о «старых» задачах.

5.3. Gradual fine-tuning (постепенное дообучение)

Начинают с малого learning rate и постепенно увеличивают его, чтобы модель не делала резких шагов.

5.4. Early stopping

Мониторят метрики на валидационном наборе базовых бенчмарков и останавливают обучение при первых признаках ухудшения.

5.5. Использование референсной модели (reference model)

В PPO всегда хранят копию модели до RLHF (reference model) и считают KL-дивергенцию между текущей политикой и reference. Это стандартный приём.


6. Инструменты и библиотеки

ИнструментНазначение
HuggingFace Transformers + TRLРеализация PPO, SFT, evaluation
DeepSpeedОптимизация памяти для больших моделей
lm-evaluation-harnessСтандартизированный запуск бенчмарков (MMLU, HellaSwag и др.)
Weights & BiasesЛогирование метрик в реальном времени
RayРаспределённое обучение и evaluation

Пример команды для evaluation через lm-evaluation-harness:

lm_eval --model hf --model_args pretrained=my_model_after_rlhf \
        --tasks mmlu,hellaswag,truthfulqa \
        --batch_size 8 --output_path results.json

7. Пример кода: сравнение до/после

import json
from lm_eval import evaluator

def evaluate_model(model_name, tasks):
    results = evaluator.simple_evaluate(
        model="hf",
        model_args=f"pretrained={model_name}",
        tasks=tasks,
        batch_size=8
    )
    return {task: results["results"][task]["acc,none"] for task in tasks}

# Baseline
before = evaluate_model("my_model_sft", ["mmlu", "hellaswag"])
print("Before RLHF:", before)

# После RLHF
after = evaluate_model("my_model_rlhf", ["mmlu", "hellaswag"])
print("After RLHF:", after)

# Сравнение
for task in before:
    delta = (after[task] - before[task]) / before[task] * 100
    print(f"{task}: {before[task]:.2f}% -> {after[task]:.2f}% (Δ={delta:.2f}%)")
    if delta < -5:
        print("  ⚠️ Критическое ухудшение!")

8. Ограничения и нюансы

  • Перекрытие бенчмарков (contamination): если модель видела тестовые вопросы во время pre-training, метрики будут завышены. Используйте свежие версии бенчмарков (MMLU-Pro, HellaSwag v2).
  • Статистическая значимость разница в 1–2% может быть шумом. Запускайте evaluation несколько раз с разными seed и считайте доверительные интервалы.
  • Выбор порога для некоторых задач (например, TruthfulQA) ухудшение на 10% может быть приемлемым, если модель стала безопаснее. Оценивайте trade-off.
  • Мультимодальные модели для них добавляют бенчмарки вроде VQAv2, COCO Captions.

9. Пет-проект для закрепления

Задача Реализовать пайплайн проверки влияния RLHF на базовые способности небольшой модели (например, GPT-2 или Pythia-1B).

Инструменты Python, HuggingFace Transformers, TRL, lm-evaluation-harness, Weights & Biases.

Шаги:

  1. Взять предобученную модель (например, gpt2).
  2. Провести SFT на небольшом датасете инструкций (Alpaca).
  3. Сохранить baseline — запустить evaluation на MMLU и HellaSwag.
  4. Обучить reward model на датасете предпочтений (Anthropic HH-RLHF).
  5. Запустить PPO (из TRL) с разными значениями KL-штрафа (β = 0.1, 0.5, 1.0).
  6. После каждых 50 шагов PPO повторять evaluation.
  7. Построить график зависимости метрик от шага обучения.
  8. Сделать вывод: при каком β забывание минимально.

Ожидаемый результат Вы научитесь на практике отслеживать catastrophic forgetting и подбирать гиперпараметры RLHF для баланса между выравниванием и сохранением знаний.


10. Связь с другими вопросами

ВопросТема
335Что такое RLHF и зачем он нужен?
336Как работает PPO в контексте RLHF?
338Как оценивать качество reward model?
339Как обеспечить safety при RLHF?
340Что такое alignment tax и как его минимизировать?
341Как сравнивать разные алгоритмы выравнивания (DPO, RLHF)?

11. Навигация


Навигация