Настроить MCTS для математических задач

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить MCTS для математических задач

1. Цель задачи

Разработать агента для решения сложных математических задач уровня MATH‑500 с использованием Monte Carlo Tree Search (MCTS) в комбинации с LLM. Вам предстоит реализовать ключевые компоненты MCTS: Upper Confidence Bound (UCB) для выбора узлов, rollout (симуляцию) и backpropagation наград. Цель — научиться применять алгоритмы поиска с отложенным вознаграждением (test‑time compute) для улучшения качества ответов LLM на задачах, требующих многошагового рассуждения.

Ключевой результат Рабочий пайплайн MCTS + LLM, который решает задачи из датасета MATH‑500 с accuracy ≥ 0.4 (на подмножестве из 100 задач). Код выложен в репозиторий с воспроизводимым README.


2. Исходные данные

Что нужноОткуда взять
Датасет MATH‑500 (500 задач уровня олимпиадной математики)Hugging Face: competition_math или lighteval/MATH
LLM для генерации шагов рассужденияHugging Face: microsoft/phi-1_5, Qwen/Qwen2.5-1.5B-Instruct или любая small‑model (< 3B)
Функция проверки ответа (exact match или символьное упрощение)PyPi: math_verify (рекомендовано) или написать самому с sympy
Среда для запуска (Python 3.10+, 16 GB RAM, GPU опционально)Локальный/облачный сервер или Colab (T4)

Если нет реального LLM (например, нет доступа к GPU) — симулируем:

  1. Возьмите готовый датасет с правильными ответами (lighteval/MATH).
  2. Напишите заглушку rollout, которая случайно выбирает один из предопределённых вариантов ответа (3–4 варианта на задачу). MCTS будет обучаться выбору правильного.
  3. Замените LLM на pipeline из transformers с torch.device("cpu") и small model (< 1B) — будет медленно, но позволит проверить алгоритм.

3. Технологический стек

КомпонентИнструментыНазначение
Язык программированияPython 3.10+Ядро пайплайна
LLM InferenceHugging Face transformers, torchГенерация шагов рассуждения
ДатасетHugging Face datasetsЗагрузка MATH‑500
Проверка ответовmath_verify (или sympy)Точное сравнение математических выражений
Утилитыnumpy, tqdm, asyncio (опционально)Расчёты, прогресс, параллелизм
Версионирование и воспроизводимостьGit, requirements.txt, condaУправление кодом и зависимостями

4. Этапы выполнения

Этап 1: Подготовка среды и данных (30–60 минут)

Действия

  1. Установите окружение

    conda create -n mcts_math python=3.10
    conda activate mcts_math
    pip install transformers datasets torch numpy tqdm math-verify sympy
    
  2. Загрузите датасет MATH‑500

    from datasets import load_dataset
    dataset = load_dataset("lighteval/MATH", split="test")
    # выберите первые 100 задач для отладки
    subset = dataset.select(range(100))
    
  3. Создайте функцию проверки ответа

    from math_verify import verify, parse
    
    def is_correct(pred: str, gt: str) -> bool:
        try:
            parsed_pred = parse(pred)    # преобразует строку в math-выражение
            parsed_gt = parse(gt)
            return verify(parsed_pred, parsed_gt)
        except:
            return False
    
  4. Подготовьте LLM загрузите модель и токенизатор:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    model_name = "microsoft/phi-1_5"   # или другая small model
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    

Ожидаемый результат этапа Рабочий скрипт, который загружает данные, проверяет любой заданный ответ на точное соответствие с использованием math_verify, и готов к генерации текста.


Этап 2: Реализация узла MCTS (60–90 минут)

Действия

  1. Создайте класс MCTSNode

    class MCTSNode:
        def __init__(self, state: str, parent=None, action=None):
            self.state = state          # текущий текст (промпт + шаги)
            self.parent = parent
            self.action = action        # последнее действие (шаг рассуждения)
            self.children = []
            self.visits = 0
            self.total_reward = 0.0
            self.is_terminal = False    # достигнут ответ
    
  2. Определите методы

    • is_fully_expanded() — проверяет, все ли возможные действия перебраны (можно ограничить максимальное количество детей).
    • best_child(c=1.4) — выбирает ребёнка с максимумом UCB1:
      Q(s,a) + c * sqrt(ln(N_parent) / n_child).
    • update(reward) — обновляет visits и total_reward.
  3. Параметры

    • Максимальная глубина (количество шагов рассуждения): max_depth=20.
    • Максимальное количество детей на узел (ширина): max_children=5 (действия = топ-5 продолжений от LLM).
    • Константа UCB: c = 1.4 (настраивается).

Ожидаемый результат этапа Класс MCTSNode с реализованными методами, протестирован на простых строках.


Этап 3: Реализация UCB, rollout и backpropagation (2–3 часа)

Действия

  1. Selection (выбор) — рекурсивно спускаемся, пока узел полностью раскрыт и не терминален:

    def select(node):
        while node.is_fully_expanded() and not node.is_terminal:
            node = node.best_child()
        return node
    
  2. Expansion (расширение) — если узел не полностью раскрыт, генерируем новый шаг:

    • Промпт: {state} Let’s think step by step.\nStep {n}:
    • LLM генерирует несколько вариантов (beam search / top‑k sampling).
    • Создаём дочерний узел для каждого варианта, добавляем к node.children.
  3. Rollout (симуляция) — от нового узла генерируем полный путь до терминального состояния (ответ) без обратного хода:

    • Запускаем LLM в greedy mode (temperature=0) до получения ответа в формате \boxed{...} или до достижения max_steps.
    • Извлекаем ответ, сверяем с правильным через is_correct.
    • Награда: 1.0 если верно, −1.0 если неверно (или 0.0, но разница важна).
  4. Backpropagation — обновляем статистики на всём пути от листа до корня:

    def backpropagate(node, reward):
        while node:
            node.visits += 1
            node.total_reward += reward
            node = node.parent
    

Настройка rollout для ускорения можно использовать короткую симуляцию (3‑5 шагов) и затем экстраполировать награду (но это усложнение; для первого раза делаем полный rollout).

Ожидаемый результат этапа Функции uct_search (один полный цикл MCTS) и основной цикл run_mcts(root_state, num_iterations=100).


Этап 4: Интеграция и тестирование на MATH‑500 (2–3 часа)

Действия

  1. Напишите основной пайплайн

    • Для каждой задачи из subset (100 задач):
      • Сформируйте корневой узел: MCTSNode(state=prompt).
      • Запустите MCTS на N итераций (начните с 50, позже увеличьте до 200).
      • Выберите лучшее действие из корня (по наибольшему числу посещений).
      • Сгенерируйте итоговый ответ (начиная с выбранного действия, затем greedy rollout).
      • Запишите результат (правильный/неправильный).
  2. Соберите статистику

    correct = 0
    for i, example in enumerate(subset):
        final_answer = ...   # полученный ответ
        if is_correct(final_answer, example["solution"]):
            correct += 1
    print(f"Accuracy: {correct / len(subset):.3f}")
    
  3. Оптимизация попробуйте разные значения c (0.5, 1.0, 1.4, 2.0) и количество итераций (50, 100, 200). Запишите результаты в лог.

Ожидаемый результат этапа Таблица с accuracy для разных гиперпараметров + лучший результат ≥ 0.4 на 100 задачах.


Этап 5: Документирование и упаковка (30–60 минут)

Действия

  1. Оформите репозиторий

    • src/mcts.py — реализация MCTS.
    • src/llm_utils.py — функции генерации и проверки.
    • scripts/run_experiments.py — главный скрипт.
    • requirements.txt
    • README.md с инструкцией и результатами.
  2. В README опишите

    • Архитектуру MCTS.
    • Используемые гиперпараметры.
    • Accuracy на MATH‑500 (subset 100 задач).
    • Как воспроизвести.

Ожидаемый результат этапа Полный репозиторий, готовый к ревью.


5. Критерии приемки (Definition of Done)

  • Реализованы класс MCTSNode и все четыре фазы MCTS (selection, expansion, rollout, backpropagation).
  • UCB1 используется для выбора узла, константа c задаётся параметром.
  • Функция rollout использует LLM для генерации шагов и возвращает награду (1/0 или ±1).
  • После MCTS ответ извлекается из best‑path (наиболее посещаемый узел).
  • Accuracy на подмножестве из 100 задач MATH‑500 ≥ 0.4 (при любых разумных гиперпараметрах).
  • Код воспроизводим: есть requirements.txt, скрипты запускаются без ошибок.
  • В репозитории есть README с описанием и результатами.
  • Код использует math_verify для корректной проверки математических выражений.
  • Реализована возможность отладки (логирование, визуализация дерева опционально).

6. Ожидаемый результат

Основной артефакт Репозиторий на GitHub/GitLab, содержащий:

  • src/mcts.py — реализация MCTS.
  • src/llm_utils.py — функции для генерации и проверки.
  • scripts/run_experiments.py — скрипт, который запускает эксперимент и выводит accuracy.
  • results/log.csv — лог с гиперпараметрами и accuracy.

Содержание results/log.csv

model,iterations,c,accuracy,time_sec
phi-1.5,50,1.4,0.35,1200
phi-1.5,100,1.4,0.42,2400
...

Опционально График зависимости accuracy от числа итераций MCTS (PNG в папке plots).


7. Возможные сложности и их решение

СложностьРешение
LLM генерирует очень длинные рассуждения, rollout занимает > 10 секОграничьте rollout глубиной (например, 5 шагов) и затем экстраполируйте награду (heuristic), или используйте маленькую модель (< 1B).
UCB не различает узлы (все посещения одинаковы)Увеличьте константу c до 2.0 или измените награду на ±1 (вместо 0/1).
math_verify не может распарсить ответДополнительно реализуйте fallback: извлеките \boxed{...} и сравните строково, убрав пробелы.
Эксперимент на 100 задачах занимает > 30 минИспользуйте batch‑inference (передавайте несколько промптов одновременно) или сократите глубину MCTS.
Дерево MCTS растёт слишком быстро (память)Ограничьте максимальный размер дерева (число узлов), удаляйте наименее перспективные ветви.

8. Бюджет времени (оценка)

ЭтапВремя
1. Подготовка среды и данных1 ч
2. Реализация узла MCTS1.5 ч
3. Реализация UCB, rollout, backprop3 ч
4. Интеграция и тестирование3 ч
5. Документирование1 ч
Итого9.5 ч

Примечание для первого раза Если вы впервые реализуете MCTS с LLM, заложите +2 часа на отладку (особенно на проверку правильности backpropagation и UCB). Для ускорения rollout можно заменить LLM на заглушку (см. Этап 2, раздел «симулируем»).


9. Связанные вопросы из базы знаний

ВопросТема
42Monte Carlo Tree Search: основные концепции
65Upper Confidence Bound (UCB) в многоруких бандитах
101Tree search with neural networks (AlphaZero)
203Reward shaping для задач с многошаговым рассуждением
304Решение математических задач с помощью LLM + поиск
415Параллельный rollout в MCTS
506Backpropagation наград в дереве решений
607Test‑time compute scaling laws
708Сравнение beam search и MCTS для генерации
809Верификация математических ответов (sympy / math_verify)

10. Чек-лист самопроверки

  • Я проверил, что UCB1 корректен: best_child(c) возвращает узел с максимумом Q + c * sqrt(ln N / n), а не минимумом.
  • Я протестировал backpropagation на дереве глубины 2: награда корректно суммируется до корня.
  • Я убедился, что math_verify корректно обрабатывает как \boxed{3}, так и 3.
  • Я запустил эксперимент на 10 задачах перед полным запуском и получил ненулевую accuracy.
  • В репозитории нет закомментированного мусора или путей к локальным файлам — всё воспроизводимо.