中文翻译暂不可用,显示俄语原文。
Как работает алгоритм ReST (Reinforced Self-Training) и когда он лучше PPO?
Краткий тезис
ReST (Reinforced Self-Training) — это итеративный алгоритм дообучения языковых моделей, который чередует генерацию ответов, фильтрацию по reward (награде) и fine-tuning на отфильтрованных данных. Он проще PPO (Policy Optimization|Proximal Policy Optimization]]), так как не требует value network и сложной оптимизации, но уступает в sample efficiency. ReST лучше PPO в задачах с чётким, легко вычислимым reward (например, правильность ответа в математике или коде), где не нужен сложный credit assignment.
1. Что такое ReST (Reinforced Self-Training)?
ReST — это метод, предложенный Google (Singh et al., 2023), который объединяет идеи self-training (самообучения) и reinforcement learning (обучения с подкреплением). Основная идея: модель генерирует множество вариантов ответов, оценивает их с помощью reward function (функции награды), отбирает лучшие и дообучается на них. Процесс повторяется итеративно.
Ключевые термины
- Self-training: модель обучается на собственных предсказаниях, отфильтрованных по некоторому критерию.
- Reward: числовая оценка качества ответа (например, 1 — правильный, 0 — неправильный).
- Fine-tuning: дообучение модели на отобранных данных с помощью стандартного supervised learning (обучения с учителем).
ReST не требует policy gradient или value function, что делает его простым в реализации и стабильным.
2. Алгоритм ReST пошагово
ReST состоит из двух фаз, которые повторяются:
Фаза 1: Grow (Рост)
- Модель (текущая политика) генерирует множество ответов для каждого запроса из датасета.
- Для каждого ответа вычисляется reward (например, точность ответа, оценка LLM-судьи).
- Ответы сортируются по reward.
Фаза 2: Improve (Улучшение)
- Выбирается подмножество ответов с высоким reward (например, top-k или выше порога).
- Модель дообучается на этих отфильтрованных данных с помощью supervised fine-tuning (SFT) — минимизация cross-entropy loss (кросс-энтропийной функции потерь) для выбранных ответов.
- После дообучения модель становится новой политикой, и процесс повторяется.
Псевдокод
def rest_iteration(model, dataset, reward_fn, threshold=0.8):
# Grow
all_responses = []
for query in dataset:
responses = model.generate(query, n_samples=10)
for resp in responses:
reward = reward_fn(query, resp)
all_responses.append((query, resp, reward))
# Filter
filtered = [(q, r) for q, r, rew in all_responses if rew >= threshold]
# Improve
model.fine_tune(filtered) # SFT on filtered responses
return model
Важно ReST не использует value network (сеть ценности) и advantage estimation (оценку преимущества), что отличает его от PPO.
3. Что такое PPO (Proximal Policy Optimization)?
PPO — это популярный алгоритм reinforcement learning для обучения политики. Он использует:
- Policy network (сеть политики) — генерирует действия (ответы).
- Value network (сеть ценности) — оценивает ожидаемую сумму наград из данного состояния.
- Clipping (клиппирование) — ограничивает изменение политики на каждом шаге для стабильности.
PPO оптимизирует surrogate objective (суррогатную целевую функцию) с помощью importance sampling (выборки по важности). Это позволяет эффективно использовать данные, собранные старой политикой, но требует тщательной настройки гиперпараметров.
Формула surrogate loss (упрощённо):
L = -E[ min(ratio * A, clip(ratio, 1-ε, 1+ε) * A) ]
где ratio — отношение вероятностей новой и старой политики, A — advantage (преимущество) из value network.
4. Сравнение ReST и PPO
| Характеристика | ReST | PPO |
|---|---|---|
| Сложность реализации | Низкая (только генерация + SFT) | Высокая (policy + value network, clipping, GAE) |
| Value network | Не нужна | Обязательна |
| Sample efficiency | Низкая (использует только лучшие ответы) | Высокая (использует все данные с importance sampling) |
| Стабильность | Высокая (нет риска расхождения политики) | Средняя (требует настройки clip range) |
| Credit assignment | Отсутствует (reward присваивается всему ответу) | Есть (advantage распределяет reward по токенам) |
| Вычислительные затраты | Низкие (один forward pass для генерации) | Высокие (два forward pass: policy + value) |
| Требования к reward | Должен быть чётким и доступным для каждого ответа | Может работать с sparse reward (редкими наградами) |
| Типичное применение | Задачи с однозначной оценкой (математика, код) | Задачи с длинным горизонтом (диалоги, игры) |
5. Когда ReST лучше PPO?
ReST превосходит PPO в следующих сценариях:
5.1 Чёткий и быстрый reward
Если reward можно вычислить сразу после генерации ответа (например, совпадение с правильным ответом в math reasoning или успешная компиляция кода), ReST эффективен. PPO в таких случаях избыточен.
5.2 Простота и скорость прототипирования
ReST можно реализовать за несколько часов, используя стандартный SFT. PPO требует настройки GAE (Generalized Advantage Estimation), value loss и entropy bonus.
5.3 Стабильность при малых данных
На небольших датасетах PPO может переобучаться или расходиться из-за сложной оптимизации. ReST с фильтрацией top-k более устойчив.
5.4 Отсутствие необходимости в credit assignment
Если reward зависит от всего ответа целиком (а не от отдельных токенов), ReST работает хорошо. Например, в задаче question answering (ответ на вопрос) — ответ либо правильный, либо нет.
Пример: Обучение модели решать арифметические задачи. Reward = 1, если ответ совпадает с эталоном. ReST генерирует 100 вариантов, отбирает правильные и дообучается. PPO здесь не даст преимущества, так как reward не требует разложения по токенам.
6. Когда PPO лучше ReST?
PPO незаменим в случаях:
6.1 Sparse reward (редкие награды)
Когда reward приходит только в конце длинной последовательности (например, в диалоге или игре). PPO с value network может оценивать промежуточные состояния и давать shaped reward (формировать награду).
6.2 Необходимость credit assignment
Если нужно понять, какие именно токены привели к успеху/неудаче. PPO через advantage распределяет reward по всем токенам, а ReST даёт одинаковый reward всему ответу.
6.3 Высокая sample efficiency
Когда генерация каждого ответа дорога (например, вызов внешнего API). PPO использует importance sampling и может обучаться на данных, собранных старой политикой, многократно.
6.4 Сложные, многомодальные задачи
Например, обучение agent (агента) в среде с частичной наблюдаемостью. PPO лучше справляется с exploration (исследованием) за счёт entropy bonus (бонуса энтропии).
7. Примеры применения ReST
- Математические рассуждения (Math reasoning): модель генерирует цепочки мыслей, reward — совпадение финального ответа. ReST улучшает точность на GSM8K, MATH.
- Генерация кода: reward — успешная компиляция и прохождение тестов. ReST повышает pass@k.
- Инструктивное следование (Instruction following): reward от LLM-судьи (например, GPT-4 оценивает полезность). ReST улучшает качество ответов.
Пример из статьи ReST на основе PaLM 2 показал улучшение на 5-10% на бенчмарках математики и кода по сравнению с обычным SFT.
8. Ограничения ReST
- Низкая sample efficiency: отбрасывается большая часть сгенерированных данных (если threshold высокий). Это дорого при использовании больших моделей.
- Зависимость от качества reward: если reward шумный или неполный, фильтрация может отобрать плохие примеры.
- Отсутствие exploration: ReST не исследует новые стратегии, а только усиливает уже хорошие ответы. Это может привести к mode collapse (коллапсу моды).
- Не подходит для задач с длинным горизонтом: где reward отложен на много шагов.
9. Связь с другими методами
ReST является частным случаем EM (Expectation-Maximization) в обучении с подкреплением: Grow — E-шаг (генерация+оценка), Improve — M-шаг (максимизация правдоподобия на лучших данных). Также ReST близок к ReST-EM (вариант с expectation-maximization) и Self-Play (самоигра) в RL.
В контексте Agentic RAG ReST может использоваться для улучшения retrieval agent (агента поиска) или reasoning agent (агента рассуждений), где reward — успешность ответа на вопрос пользователя.
Пет-проект для закрепления
Задача Обучить небольшую языковую модель (например, GPT-2 или TinyLlama) решать простые математические задачи (сложение двух чисел) с помощью ReST.
Инструменты
- Hugging Face Transformers, PyTorch
- Датасет: самодельный (1000 примеров вида "2+3=5")
- Reward: совпадение ответа (0/1)
Шаги:
- Загрузите предобученную модель (например,
distilgpt2). - Сгенерируйте для каждого запроса 10 ответов (температура 0.8).
- Вычислите reward: 1, если ответ содержит правильное число, иначе 0.
- Отфильтруйте ответы с reward=1.
- Дообучите модель на отфильтрованных данных (SFT, 1 эпоха, lr=1e-5).
- Повторите шаги 2-5 три итерации.
- Оцените accuracy на тестовом наборе (100 примеров).
Ожидаемый результат Accuracy растёт с каждой итерацией (например, с 20% до 60%). Сравните с baseline (обычный SFT на тех же данных без фильтрации).
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 340 | Как работает PPO для LLM? |
| 341 | Что такое RLHF и как он связан с PPO? |
| 342 | Как работает DPO (Direct Preference Optimization)? |
| 338 | Как обучить агента с подкреплением в Agentic RAG? |
| 343 | Что такое GRPO (Group Relative Policy Optimization)? |
| 337 | Какие методы fine-tuning используются для агентов? |
Навигация
- Предыдущий: 338
- Следующий: 340
- Индекс: 00. Индекс разборов