English translation is not available yet. Showing Russian content.

Настроить search-based inference (AlphaSearch)

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить search-based inference (AlphaSearch)

1. Цель задачи

Реализовать поисковый механизм поверх LLM, использующий MCTS (Monte Carlo Tree Search) в сочетании с verifier для выбора наилучшего ответа. Цель — улучшить качество рассуждений модели на сложных задачах (математика, логика) по сравнению с обычным greedy-декодированием или best-of-N. Вы научитесь строить композитную inference-систему, комбинировать multiple rollouts, обрезать ветви и присваивать вознаграждения.

Ключевой результат Рабочий пайплайн AlphaSearch, который на тестовых задачах (например, 100 примеров из GSM8K) даёт accuracy выше, чем greedy decoding (минимум +5%).


2. Исходные данные

Что нужноОткуда взять
LLM (7B-13B, желательно с умением chain-of-thought)Hugging Face: microsoft/phi-2, mistralai/Mistral-7B-Instruct-v0.3, meta-llama/Llama-2-7b-chat-hf
Датасет сложных задач (математика)GSM8K (gsm8k), MATH (hendrycks/ma th) или арифметические цепочки
Verifier (reward model)Обученная модель (например, OpenAssistant/reward-model-deberta-v3-large-v2) или rule-based: сравнение финального числового ответа с эталоном
Инфраструктура для запускаPython 3.10+, PyTorch, CUDA (GPU 16GB+), vLLM (опционально для ускорения)
Референсный baselineGreedy decoding с temperature=0, top-p=1

Если нет реального verifier — симулируем:

  1. Используем часть датасета с известными ответами (train split). Для каждого rollout проверяем, совпадает ли финальный ответ с правильным (постобработка: извлечь число после ## или Answer:).
  2. Для шагов внутри MCTS используем прогнозирование правильности по правилу: ближайший rollout score = 1, если промежуточный шаг ведёт к правильному финалу (имитация через несколько симуляций до конца).

3. Технологический стек

КомпонентИнструментыНазначение
Модель LLMHuggingFace Transformers, vLLMГенерация рассуждений и токенов
MCTSPython (библиотеки: mcts, tree или реализация вручную)Поиск по дереву мыслей
Verifier / Reward modelHuggingFace AutoModelForSequenceClassification / кастомный классОценка качества полного или частичного ответа
Дашборд / логированиеWeights & Biases, TensorBoardСравнение метрик разных стратегий
ДатасетHuggingFace datasetsGSM8K, MATH
ОкружениеPython 3.10, PyTorch, CUDAВычисления

4. Этапы выполнения

Этап 1: Подготовка окружения и baseline (1 час)

Действия

  1. Установить зависимости: pip install transformers torch datasets vllm swig (swig для mcts, если нужно).
  2. Загрузить модель (например, Mistral-7B-Instruct) и токенайзер.
  3. Загрузить датасет GSM8K (100 примеров из test split для оценки).
  4. Реализовать greedy-декодирование: model.generate(... temperature=0.0, max_new_tokens=512).
  5. Постобработка: извлечь финальный ответ (число). Вычислить accuracy на 100 примерах.
  6. Зафиксировать результат как baseline.

Ожидаемый результат этапа Скрипт baseline.py, accuracy на GSM8K (например, 42% для 7B модели). Вывод в STDERR: Greedy accuracy: 0.42.


Этап 2: Реализация MCTS-дерева (3 часа)

Действия

  1. Определить структуру узла: state (текущий частичный рассуждение), parent, children, visits, value.
  2. Написать функции:
    • select(node) — выбор узла по UCB1: Q + c * sqrt(log(N_parent) / (1 + n)).
    • expand(node) — генерация k вариантов следующего шага (например, top-k из LLM с temperature=0.8).
    • simulate(node) — rollout от заданного состояния до конца (greedy).
    • backpropagate(node, reward) — обновление Q и visits по пути.
  3. Ограничить количество симуляций (например, 50 на корневой узел).
  4. Интегрировать LLM: при expand вызывать model.generate с несколькими примерами (можно через do_sample=True, num_return_sequences=k).
  5. После завершения MCTS выбрать лучший путь (тот, у которого максимальный Q в корне, или полный путь с наибольшим value на листе).

Ожидаемый результат этапа Функция mcts_inference(prompt) -> final_answer. Тест на 5 примерах — получает ответы, время ~30 секунд на пример.


Этап 3: Интеграция verifier (2 часа)

Действия

  1. Выбрать подход:
    • Option A Rule-based verifier: для каждого полного rollout извлечь финальное число и сравнить с правильным (если датасет train split). Использовать это как reward для backpropagation.
    • Option B Reward model: загрузить предобученный ревард-модель (например, OpenAssistant/reward-model-deberta-v3-large-v2). Подавать полный текст ответа (с промптом) в модель, получить scalar reward (0-1). Использовать как score в backpropagation.
  2. Если выбрали Option B, написать обёртку: verifier.score(prompt + generation) -> float.
  3. Модифицировать simulate(node): после генерации полного ответа вызвать verifier, получить reward.
  4. Убедиться, что reward попадает в backpropagation правильно.

Важно Для Option B reward модель должна быть совместима по токенизации (часто отдельный токенайзер). Использовать AutoTokenizer.from_pretrained.

Ожидаемый результат этапа MCTS теперь использует verifier. Запустить на 10 примерах — accuracy должна быть не ниже baseline (если нет, проверить масштаб reward, c параметр UCB).


Этап 4: Оптимизация и тестирование (2 часа)

Действия

  1. Настроить гиперпараметры MCTS:
    • c (exploration constant) — попробовать 0.5, 1.0, 2.0.
    • num_simulations — 20, 50, 100.
    • expand_k (число детей) — 3, 5.
    • temperature при generate — 0.6, 0.8, 1.0.
  2. Сравнить accuracy на тестовых 100 примерах: Greedy (baseline), Best-of-5 (простой sampling с выбором через verifier), AlphaSearch (MCTS+verifier).
  3. Построить таблицу/график сравнения.
  4. Замерить среднее время инференса на один пример для каждой стратегии.

Ожидаемый результат этапа Таблица метрик:

СтратегияAccuracyСреднее время (с)
Greedy42%3
Best-of-547%15
AlphaSearch50%45

Этап 5: Анализ и документирование (1 час)

Действия

  1. Задокументировать архитектуру: как соединены LLM, MCTS, verifier.
  2. Написать небольшой отчет: какие примеры улучшились, какие ухудшились, почему.
  3. Зафиксировать лучшие гиперпараметры и конфигурацию.
  4. Выгрузить код в репозиторий (GitHub/GitLab) с README.

Ожидаемый результат этапа Файл REPORT.md с результатами и выводами.


5. Критерии приемки (Definition of Done)

  • Реализован MCTS с UCB1, expand, simulate, backpropagate.
  • Verifier интегрирован в simulate (rule-based или reward model).
  • Альфа-версия запускается на 100 примерах GSM8K.
  • Accuracy AlphaSearch выше, чем greedy (разница > 5 процентных пунктов).
  • Получен хотя бы один из двух вариантов verifier (rule-based или модель).
  • Создана таблица сравнения greedy, best-of-5, AlphaSearch.
  • Зафиксировано среднее время инференса.
  • Код оформлен в виде модулей (mcts.py, verifier.py, inference.py) с README.
  • Воспроизводимость: requirements.txt, скрипт run_experiments.sh или python main.py.

6. Ожидаемый результат

Основной артефакт Репозиторий с папкой alphasearch, содержащей:

  • mcts.py — реализация дерева MCTS.
  • verifier.py — класс verifier (rule-based или модель).
  • inference.py — функция answer = alpha_search(prompt, model, verifier).
  • run_experiments.py — сравнение стратегий, вывод метрик.
  • REPORT.md — отчёт с таблицей accuracy и времени.
  • requirements.txt — зависимости.

Дополнительно График accuracy_vs_time.png (опционально).


7. Возможные сложности и их решение

СложностьРешение
LLM слишком медленная для множества симуляцийИспользовать vLLM для batch inference; уменьшить симуляции до 20; использовать маленькую модель (Phi-2)
Verifier (reward model) несовместим по токенайзеруВзять rule-based verifier, который можно натренировать на валидации ответа по числу
MCTS зависает в рекурсии / бесконечные rolloutОграничить глубину rollouts (max_new_tokens=512) и количество симуляций; добавить максимальное количество токенов на узел
UCB1 не сходитсяУвеличить exploration constant c; использовать нормализацию reward (z-score)
Разница в форматах вывода модели (например, ответ без спецметок)Стандартизировать постобработку: извлекать последнее число после символа =, Result:, или answer:. Использовать regex r'[-+]?\d+\.?\d*'
GPU Out-of-Memory при использовании нескольких моделейИспользовать одну модель для генерации, verifier загрузить на CPU (если медленно — на GPU с offloading)

8. Бюджет времени (оценка)

ЭтапВремя
Этап 1: baseline1 час
Этап 2: MCTS3 часа
Этап 3: verifier2 часа
Этап 4: оптимизация и тест2 часа
Этап 5: документирование1 час
Итого9 часов

Примечание Если используется предобученный reward model, этап 3 может занять меньше (1 час). Для первого раза рекомендуется заложить +2 часа на отладку скрытых багов.


9. Связанные вопросы из базы знаний

ВопросТема
12MCTS в AlphaGo/AlphaZero
17Beam search и лучшие стратегии декодирования
23Chain-of-Thought prompting
34Reward models и RLHF
41Тестово-временное масштабирование (test-time compute)
55Self-consistency decoding
78Verifiers и их обучение на reasoning tasks
94Best-of-N sampling
112UCB1 и exploration-exploitation
205Масштабирование inference для математических задач

10. Чек-лист самопроверки

  • Я разобрался(лась) с архитектурой MCTS: selection, expansion, simulation, backpropagation.
  • Я выбрал(а) подходящий verifier и настроил(а) его связку с LLM.
  • Я проверил(а) baseline accuracy и убедился(лась), что AlphaSearch даёт прирост.
  • Я задокументировал(а) гиперпараметры, чтобы эксперимент можно было воспроизвести.
  • Я написал(а) простой тест на 1–2 примерах, который показывает корректность выдачи.