Настроить MCTS для математических задач
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить MCTS для математических задач
1. Цель задачи
Разработать агента для решения сложных математических задач уровня MATH‑500 с использованием Monte Carlo Tree Search (MCTS) в комбинации с LLM. Вам предстоит реализовать ключевые компоненты MCTS: Upper Confidence Bound (UCB) для выбора узлов, rollout (симуляцию) и backpropagation наград. Цель — научиться применять алгоритмы поиска с отложенным вознаграждением (test‑time compute) для улучшения качества ответов LLM на задачах, требующих многошагового рассуждения.
Ключевой результат Рабочий пайплайн MCTS + LLM, который решает задачи из датасета MATH‑500 с accuracy ≥ 0.4 (на подмножестве из 100 задач). Код выложен в репозиторий с воспроизводимым README.
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Датасет MATH‑500 (500 задач уровня олимпиадной математики) | Hugging Face: competition_math или lighteval/MATH |
| LLM для генерации шагов рассуждения | Hugging Face: microsoft/phi-1_5, Qwen/Qwen2.5-1.5B-Instruct или любая small‑model (< 3B) |
| Функция проверки ответа (exact match или символьное упрощение) | PyPi: math_verify (рекомендовано) или написать самому с sympy |
| Среда для запуска (Python 3.10+, 16 GB RAM, GPU опционально) | Локальный/облачный сервер или Colab (T4) |
Если нет реального LLM (например, нет доступа к GPU) — симулируем:
- Возьмите готовый датасет с правильными ответами (
lighteval/MATH). - Напишите заглушку rollout, которая случайно выбирает один из предопределённых вариантов ответа (3–4 варианта на задачу). MCTS будет обучаться выбору правильного.
- Замените LLM на pipeline из transformers с torch.device("cpu") и small model (< 1B) — будет медленно, но позволит проверить алгоритм.
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Язык программирования | Python 3.10+ | Ядро пайплайна |
| LLM Inference | Hugging Face transformers, torch | Генерация шагов рассуждения |
| Датасет | Hugging Face datasets | Загрузка MATH‑500 |
| Проверка ответов | math_verify (или sympy) | Точное сравнение математических выражений |
| Утилиты | numpy, tqdm, asyncio (опционально) | Расчёты, прогресс, параллелизм |
| Версионирование и воспроизводимость | Git, requirements.txt, conda | Управление кодом и зависимостями |
4. Этапы выполнения
Этап 1: Подготовка среды и данных (30–60 минут)
Действия
-
Установите окружение
conda create -n mcts_math python=3.10 conda activate mcts_math pip install transformers datasets torch numpy tqdm math-verify sympy -
Загрузите датасет MATH‑500
from datasets import load_dataset dataset = load_dataset("lighteval/MATH", split="test") # выберите первые 100 задач для отладки subset = dataset.select(range(100)) -
Создайте функцию проверки ответа
from math_verify import verify, parse def is_correct(pred: str, gt: str) -> bool: try: parsed_pred = parse(pred) # преобразует строку в math-выражение parsed_gt = parse(gt) return verify(parsed_pred, parsed_gt) except: return False -
Подготовьте LLM загрузите модель и токенизатор:
from transformers import AutoModelForCausalLM, AutoTokenizer model_name = "microsoft/phi-1_5" # или другая small model model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token
Ожидаемый результат этапа Рабочий скрипт, который загружает данные, проверяет любой заданный ответ на точное соответствие с использованием math_verify, и готов к генерации текста.
Этап 2: Реализация узла MCTS (60–90 минут)
Действия
-
Создайте класс MCTSNode
class MCTSNode: def __init__(self, state: str, parent=None, action=None): self.state = state # текущий текст (промпт + шаги) self.parent = parent self.action = action # последнее действие (шаг рассуждения) self.children = [] self.visits = 0 self.total_reward = 0.0 self.is_terminal = False # достигнут ответ -
Определите методы
is_fully_expanded()— проверяет, все ли возможные действия перебраны (можно ограничить максимальное количество детей).best_child(c=1.4)— выбирает ребёнка с максимумом UCB1:
Q(s,a) + c * sqrt(ln(N_parent) / n_child).- update(reward) — обновляет visits и
total_reward.
-
Параметры
- Максимальная глубина (количество шагов рассуждения): max_depth=20.
- Максимальное количество детей на узел (ширина): max_children=5 (действия = топ-5 продолжений от LLM).
- Константа UCB:
c = 1.4(настраивается).
Ожидаемый результат этапа Класс MCTSNode с реализованными методами, протестирован на простых строках.
Этап 3: Реализация UCB, rollout и backpropagation (2–3 часа)
Действия
-
Selection (выбор) — рекурсивно спускаемся, пока узел полностью раскрыт и не терминален:
def select(node): while node.is_fully_expanded() and not node.is_terminal: node = node.best_child() return node -
Expansion (расширение) — если узел не полностью раскрыт, генерируем новый шаг:
- Промпт: {state} Let’s think step by step.\nStep {n}:
- LLM генерирует несколько вариантов (beam search / top‑k sampling).
- Создаём дочерний узел для каждого варианта, добавляем к
node.children.
-
Rollout (симуляция) — от нового узла генерируем полный путь до терминального состояния (ответ) без обратного хода:
- Запускаем LLM в greedy mode (temperature=0) до получения ответа в формате
\boxed{...}или до достиженияmax_steps. - Извлекаем ответ, сверяем с правильным через
is_correct. - Награда: 1.0 если верно, −1.0 если неверно (или 0.0, но разница важна).
- Запускаем LLM в greedy mode (temperature=0) до получения ответа в формате
-
Backpropagation — обновляем статистики на всём пути от листа до корня:
def backpropagate(node, reward): while node: node.visits += 1 node.total_reward += reward node = node.parent
Настройка rollout для ускорения можно использовать короткую симуляцию (3‑5 шагов) и затем экстраполировать награду (но это усложнение; для первого раза делаем полный rollout).
Ожидаемый результат этапа Функции uct_search (один полный цикл MCTS) и основной цикл run_mcts(root_state, num_iterations=100).
Этап 4: Интеграция и тестирование на MATH‑500 (2–3 часа)
Действия
-
Напишите основной пайплайн
- Для каждой задачи из subset (100 задач):
- Сформируйте корневой узел:
MCTSNode(state=prompt). - Запустите MCTS на N итераций (начните с 50, позже увеличьте до 200).
- Выберите лучшее действие из корня (по наибольшему числу посещений).
- Сгенерируйте итоговый ответ (начиная с выбранного действия, затем greedy rollout).
- Запишите результат (правильный/неправильный).
- Сформируйте корневой узел:
- Для каждой задачи из subset (100 задач):
-
Соберите статистику
correct = 0 for i, example in enumerate(subset): final_answer = ... # полученный ответ if is_correct(final_answer, example["solution"]): correct += 1 print(f"Accuracy: {correct / len(subset):.3f}") -
Оптимизация попробуйте разные значения
c(0.5, 1.0, 1.4, 2.0) и количество итераций (50, 100, 200). Запишите результаты в лог.
Ожидаемый результат этапа Таблица с accuracy для разных гиперпараметров + лучший результат ≥ 0.4 на 100 задачах.
Этап 5: Документирование и упаковка (30–60 минут)
Действия
-
Оформите репозиторий
src/mcts.py— реализация MCTS.src/llm_utils.py— функции генерации и проверки.scripts/run_experiments.py— главный скрипт.requirements.txtREADME.mdс инструкцией и результатами.
-
В README опишите
Ожидаемый результат этапа Полный репозиторий, готовый к ревью.
5. Критерии приемки (Definition of Done)
- Реализованы класс
MCTSNodeи все четыре фазы MCTS (selection, expansion, rollout, backpropagation). - UCB1 используется для выбора узла, константа
cзадаётся параметром. - Функция rollout использует LLM для генерации шагов и возвращает награду (1/0 или ±1).
- После MCTS ответ извлекается из best‑path (наиболее посещаемый узел).
- Accuracy на подмножестве из 100 задач MATH‑500 ≥ 0.4 (при любых разумных гиперпараметрах).
- Код воспроизводим: есть
requirements.txt, скрипты запускаются без ошибок. - В репозитории есть README с описанием и результатами.
- Код использует
math_verifyдля корректной проверки математических выражений. - Реализована возможность отладки (логирование, визуализация дерева опционально).
6. Ожидаемый результат
Основной артефакт Репозиторий на GitHub/GitLab, содержащий:
src/mcts.py— реализация MCTS.src/llm_utils.py— функции для генерации и проверки.scripts/run_experiments.py— скрипт, который запускает эксперимент и выводит accuracy.results/log.csv— лог с гиперпараметрами и accuracy.
Содержание results/log.csv
model,iterations,c,accuracy,time_sec
phi-1.5,50,1.4,0.35,1200
phi-1.5,100,1.4,0.42,2400
...
Опционально График зависимости accuracy от числа итераций MCTS (PNG в папке plots).
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| LLM генерирует очень длинные рассуждения, rollout занимает > 10 сек | Ограничьте rollout глубиной (например, 5 шагов) и затем экстраполируйте награду (heuristic), или используйте маленькую модель (< 1B). |
| UCB не различает узлы (все посещения одинаковы) | Увеличьте константу c до 2.0 или измените награду на ±1 (вместо 0/1). |
math_verify не может распарсить ответ | Дополнительно реализуйте fallback: извлеките \boxed{...} и сравните строково, убрав пробелы. |
| Эксперимент на 100 задачах занимает > 30 мин | Используйте batch‑inference (передавайте несколько промптов одновременно) или сократите глубину MCTS. |
| Дерево MCTS растёт слишком быстро (память) | Ограничьте максимальный размер дерева (число узлов), удаляйте наименее перспективные ветви. |
8. Бюджет времени (оценка)
| Этап | Время |
|---|---|
| 1. Подготовка среды и данных | 1 ч |
| 2. Реализация узла MCTS | 1.5 ч |
| 3. Реализация UCB, rollout, backprop | 3 ч |
| 4. Интеграция и тестирование | 3 ч |
| 5. Документирование | 1 ч |
| Итого | 9.5 ч |
Примечание для первого раза Если вы впервые реализуете MCTS с LLM, заложите +2 часа на отладку (особенно на проверку правильности backpropagation и UCB). Для ускорения rollout можно заменить LLM на заглушку (см. Этап 2, раздел «симулируем»).
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 42 | Monte Carlo Tree Search: основные концепции |
| 65 | Upper Confidence Bound (UCB) в многоруких бандитах |
| 101 | Tree search with neural networks (AlphaZero) |
| 203 | Reward shaping для задач с многошаговым рассуждением |
| 304 | Решение математических задач с помощью LLM + поиск |
| 415 | Параллельный rollout в MCTS |
| 506 | Backpropagation наград в дереве решений |
| 607 | Test‑time compute scaling laws |
| 708 | Сравнение beam search и MCTS для генерации |
| 809 | Верификация математических ответов (sympy / math_verify) |
10. Чек-лист самопроверки
- Я проверил, что UCB1 корректен:
best_child(c)возвращает узел с максимумомQ + c * sqrt(ln N / n), а не минимумом. - Я протестировал backpropagation на дереве глубины 2: награда корректно суммируется до корня.
- Я убедился, что
math_verifyкорректно обрабатывает как\boxed{3}, так и3. - Я запустил эксперимент на 10 задачах перед полным запуском и получил ненулевую accuracy.
- В репозитории нет закомментированного мусора или путей к локальным файлам — всё воспроизводимо.