English translation is not available yet. Showing Russian content.

Как работает алгоритм ReST (Reinforced Self-Training) и когда он лучше PPO?

Краткий тезис

ReST (Reinforced Self-Training) — это итеративный алгоритм дообучения языковых моделей, который чередует генерацию ответов, фильтрацию по reward (награде) и fine-tuning на отфильтрованных данных. Он проще PPO (Policy Optimization|Proximal Policy Optimization]]), так как не требует value network и сложной оптимизации, но уступает в sample efficiency. ReST лучше PPO в задачах с чётким, легко вычислимым reward (например, правильность ответа в математике или коде), где не нужен сложный credit assignment.


1. Что такое ReST (Reinforced Self-Training)?

ReST — это метод, предложенный Google (Singh et al., 2023), который объединяет идеи self-training (самообучения) и reinforcement learning (обучения с подкреплением). Основная идея: модель генерирует множество вариантов ответов, оценивает их с помощью reward function (функции награды), отбирает лучшие и дообучается на них. Процесс повторяется итеративно.

Ключевые термины

  • Self-training: модель обучается на собственных предсказаниях, отфильтрованных по некоторому критерию.
  • Reward: числовая оценка качества ответа (например, 1 — правильный, 0 — неправильный).
  • Fine-tuning: дообучение модели на отобранных данных с помощью стандартного supervised learning (обучения с учителем).

ReST не требует policy gradient или value function, что делает его простым в реализации и стабильным.


2. Алгоритм ReST пошагово

ReST состоит из двух фаз, которые повторяются:

Фаза 1: Grow (Рост)

  • Модель (текущая политика) генерирует множество ответов для каждого запроса из датасета.
  • Для каждого ответа вычисляется reward (например, точность ответа, оценка LLM-судьи).
  • Ответы сортируются по reward.

Фаза 2: Improve (Улучшение)

  • Выбирается подмножество ответов с высоким reward (например, top-k или выше порога).
  • Модель дообучается на этих отфильтрованных данных с помощью supervised fine-tuning (SFT) — минимизация cross-entropy loss (кросс-энтропийной функции потерь) для выбранных ответов.
  • После дообучения модель становится новой политикой, и процесс повторяется.

Псевдокод

def rest_iteration(model, dataset, reward_fn, threshold=0.8):
    # Grow
    all_responses = []
    for query in dataset:
        responses = model.generate(query, n_samples=10)
        for resp in responses:
            reward = reward_fn(query, resp)
            all_responses.append((query, resp, reward))
    # Filter
    filtered = [(q, r) for q, r, rew in all_responses if rew >= threshold]
    # Improve
    model.fine_tune(filtered)  # SFT on filtered responses
    return model

Важно ReST не использует value network (сеть ценности) и advantage estimation (оценку преимущества), что отличает его от PPO.


3. Что такое PPO (Proximal Policy Optimization)?

PPO — это популярный алгоритм reinforcement learning для обучения политики. Он использует:

  • Policy network (сеть политики) — генерирует действия (ответы).
  • Value network (сеть ценности) — оценивает ожидаемую сумму наград из данного состояния.
  • Clipping (клиппирование) — ограничивает изменение политики на каждом шаге для стабильности.

PPO оптимизирует surrogate objective (суррогатную целевую функцию) с помощью importance sampling (выборки по важности). Это позволяет эффективно использовать данные, собранные старой политикой, но требует тщательной настройки гиперпараметров.

Формула surrogate loss (упрощённо):

L = -E[ min(ratio * A, clip(ratio, 1-ε, 1+ε) * A) ]

где ratio — отношение вероятностей новой и старой политики, Aadvantage (преимущество) из value network.


4. Сравнение ReST и PPO

ХарактеристикаReSTPPO
Сложность реализацииНизкая (только генерация + SFT)Высокая (policy + value network, clipping, GAE)
Value networkНе нужнаОбязательна
Sample efficiencyНизкая (использует только лучшие ответы)Высокая (использует все данные с importance sampling)
СтабильностьВысокая (нет риска расхождения политики)Средняя (требует настройки clip range)
Credit assignmentОтсутствует (reward присваивается всему ответу)Есть (advantage распределяет reward по токенам)
Вычислительные затратыНизкие (один forward pass для генерации)Высокие (два forward pass: policy + value)
Требования к rewardДолжен быть чётким и доступным для каждого ответаМожет работать с sparse reward (редкими наградами)
Типичное применениеЗадачи с однозначной оценкой (математика, код)Задачи с длинным горизонтом (диалоги, игры)

5. Когда ReST лучше PPO?

ReST превосходит PPO в следующих сценариях:

5.1 Чёткий и быстрый reward

Если reward можно вычислить сразу после генерации ответа (например, совпадение с правильным ответом в math reasoning или успешная компиляция кода), ReST эффективен. PPO в таких случаях избыточен.

5.2 Простота и скорость прототипирования

ReST можно реализовать за несколько часов, используя стандартный SFT. PPO требует настройки GAE (Generalized Advantage Estimation), value loss и entropy bonus.

5.3 Стабильность при малых данных

На небольших датасетах PPO может переобучаться или расходиться из-за сложной оптимизации. ReST с фильтрацией top-k более устойчив.

5.4 Отсутствие необходимости в credit assignment

Если reward зависит от всего ответа целиком (а не от отдельных токенов), ReST работает хорошо. Например, в задаче question answering (ответ на вопрос) — ответ либо правильный, либо нет.

Пример: Обучение модели решать арифметические задачи. Reward = 1, если ответ совпадает с эталоном. ReST генерирует 100 вариантов, отбирает правильные и дообучается. PPO здесь не даст преимущества, так как reward не требует разложения по токенам.


6. Когда PPO лучше ReST?

PPO незаменим в случаях:

6.1 Sparse reward (редкие награды)

Когда reward приходит только в конце длинной последовательности (например, в диалоге или игре). PPO с value network может оценивать промежуточные состояния и давать shaped reward (формировать награду).

6.2 Необходимость credit assignment

Если нужно понять, какие именно токены привели к успеху/неудаче. PPO через advantage распределяет reward по всем токенам, а ReST даёт одинаковый reward всему ответу.

6.3 Высокая sample efficiency

Когда генерация каждого ответа дорога (например, вызов внешнего API). PPO использует importance sampling и может обучаться на данных, собранных старой политикой, многократно.

6.4 Сложные, многомодальные задачи

Например, обучение agent (агента) в среде с частичной наблюдаемостью. PPO лучше справляется с exploration (исследованием) за счёт entropy bonus (бонуса энтропии).


7. Примеры применения ReST

  • Математические рассуждения (Math reasoning): модель генерирует цепочки мыслей, reward — совпадение финального ответа. ReST улучшает точность на GSM8K, MATH.
  • Генерация кода: reward — успешная компиляция и прохождение тестов. ReST повышает pass@k.
  • Инструктивное следование (Instruction following): reward от LLM-судьи (например, GPT-4 оценивает полезность). ReST улучшает качество ответов.

Пример из статьи ReST на основе PaLM 2 показал улучшение на 5-10% на бенчмарках математики и кода по сравнению с обычным SFT.


8. Ограничения ReST

  • Низкая sample efficiency: отбрасывается большая часть сгенерированных данных (если threshold высокий). Это дорого при использовании больших моделей.
  • Зависимость от качества reward: если reward шумный или неполный, фильтрация может отобрать плохие примеры.
  • Отсутствие exploration: ReST не исследует новые стратегии, а только усиливает уже хорошие ответы. Это может привести к mode collapse (коллапсу моды).
  • Не подходит для задач с длинным горизонтом: где reward отложен на много шагов.

9. Связь с другими методами

ReST является частным случаем EM (Expectation-Maximization) в обучении с подкреплением: Grow — E-шаг (генерация+оценка), Improve — M-шаг (максимизация правдоподобия на лучших данных). Также ReST близок к ReST-EM (вариант с expectation-maximization) и Self-Play (самоигра) в RL.

В контексте Agentic RAG ReST может использоваться для улучшения retrieval agent (агента поиска) или reasoning agent (агента рассуждений), где reward — успешность ответа на вопрос пользователя.


Пет-проект для закрепления

Задача Обучить небольшую языковую модель (например, GPT-2 или TinyLlama) решать простые математические задачи (сложение двух чисел) с помощью ReST.

Инструменты

  • Hugging Face Transformers, PyTorch
  • Датасет: самодельный (1000 примеров вида "2+3=5")
  • Reward: совпадение ответа (0/1)

Шаги:

  1. Загрузите предобученную модель (например, distilgpt2).
  2. Сгенерируйте для каждого запроса 10 ответов (температура 0.8).
  3. Вычислите reward: 1, если ответ содержит правильное число, иначе 0.
  4. Отфильтруйте ответы с reward=1.
  5. Дообучите модель на отфильтрованных данных (SFT, 1 эпоха, lr=1e-5).
  6. Повторите шаги 2-5 три итерации.
  7. Оцените accuracy на тестовом наборе (100 примеров).

Ожидаемый результат Accuracy растёт с каждой итерацией (например, с 20% до 60%). Сравните с baseline (обычный SFT на тех же данных без фильтрации).


Связь с другими вопросами

ВопросТема
340Как работает PPO для LLM?
341Что такое RLHF и как он связан с PPO?
342Как работает DPO (Direct Preference Optimization)?
338Как обучить агента с подкреплением в Agentic RAG?
343Что такое GRPO (Group Relative Policy Optimization)?
337Какие методы fine-tuning используются для агентов?

Навигация