Настроить wave decoding для коротких ответов
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить wave decoding для коротких ответов
1. Цель задачи
Реализовать и протестировать метод wave decoding (вариант speculative/blockwise параллельной генерации) для получения коротких ответов длиной 5–10 токенов. Добиться ускорения генерации не менее чем в 2 раза по сравнению с обычной авторегрессивной генерацией, сохранив качество ответов (perplexity или семантическая близость не хуже baseline).
Ключевой результат Рабочий пайплайн wave decoding, измеренное ускорение ≥2× на коротких запросах, подтверждённое бенчмарком.
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Предобученные модели (draft + target) | Hugging Face: distilgpt2 (draft), gpt2 (target) |
| Датасет коротких запросов (10–20 примеров) | Составить вручную или взять из datasets (например, daily_dialog — первые 20 реплик) |
| Baseline latency | Замерить на этапе 2 |
| Окружение с GPU (min 4GB VRAM) | Локально / Colab / сервер |
| Python 3.10+, PyTorch 2.0+ | Среда разработки |
Если нет реальной GPU — симулируем:
- Использовать CPU и torch.compile(..., backend="eager") для сравнительных замеров — ускорение может быть меньше, но относительная динамика сохранится.
- Для draft-модели взять distilgpt2, для target — gpt2 (обе работают на CPU, но медленно). Основной target — измерение методики, а не абсолютная скорость.
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Модели | Hugging Face transformers, AutoModelForCausalLM | Загрузка draft/target |
| Фреймворк | PyTorch 2.0+ | Тензорные вычисления, torch.compile |
| Кэш | KV cache (ручное управление) | Ускорение верификации |
| Бенчмаркинг | time, torch.cuda.Event | Замер латентности |
| Качество | perplexity (из transformers) или BERTScore | Оценка деградации |
| Оптимизация | torch.compile, batch processing | Дополнительное ускорение |
| Виртуализация | Conda / venv + requirements.txt | Воспроизводимость |
4. Этапы выполнения
Этап 1: Подготовка окружения и данных (30 минут)
Действия
- Установить зависимости:
pip install transformers torch datasets tqdm scipy - Загрузить модели:
from transformers import AutoModelForCausalLM, AutoTokenizer draft_model = AutoModelForCausalLM.from_pretrained("distilgpt2") target_model = AutoModelForCausalLM.from_pretrained("gpt2") tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - Подготовить датасет из 20 коротких промптов (например, вопросы или начала предложений):
prompts = [ "The capital of France is", "In brief, the answer is", # ... ещё 18 ] - Определить длину генерации gen_len = 8 (целевая длина ответа 5–10 токенов).
Ожидаемый результат этапа Загруженные модели, готовый список промптов, фиксированный gen_len.
Этап 2: Baseline — обычная генерация (20 минут)
Действия
- Написать функцию baseline_generate(prompt): использовать target_model.generate с max_new_tokens=gen_len, без дополнительных трюков.
- Измерить latency на каждом промпте (медиана по 5 запускам):
import time latencies = [] for p in prompts: inputs = tokenizer(p, return_tensors="pt") t0 = time.perf_counter() out = target_model.generate(**inputs, max_new_tokens=gen_len, do_sample=False) t1 = time.perf_counter() latencies.append(t1 - t0) baseline_avg = sum(latencies) / len(latencies) - Сохранить сгенерированные ответы для дальнейшего сравнения качества.
- Вычислить baseline perplexity на ответах (например, средняя loss модели на сгенерированных токенах).
Ожидаемый результат этапа Численное значение baseline_avg (в секундах) и baseline качество.
Этап 3: Реализация wave decoding (speculative decoding) (2 часа)
Действия
-
Определить протокол волны (wave):
- Драфт-модель генерирует K = gen_len токенов за один авторегрессивный проход (с KV cache).
- Target-модель верифицирует все K токенов одним forward pass (с маскированием будущих позиций). Если все токены приняты, берётся вся волна; иначе — отбрасываются токены до первого несовпадения, и процесс повторяется.
-
Реализовать класс
WaveDecoderclass WaveDecoder: def __init__(self, draft, target, tokenizer, K=8): self.draft = draft self.target = target self.tokenizer = tokenizer self.K = K def generate(self, input_ids, max_new_tokens): # input_ids: (1, seq_len) # Генерация волнами ...- Внутри: на каждом шаге запускаем draft.generate с max_new_tokens=self.K, получаем
draft_ids. - Подготавливаем вход для target: конкатенация
input_ids+draft_ids(полная последовательность). - Forward pass target с
use_cache=True, получаем логиты для всех позиций. - Применяем rejection sampling: сравниваем токены target с draft-токенами по вероятностям (используем top-1).
(Для простоты можно взять argmax target на последней позиции перед каждым кандидатом и сравнить с соответствующим draft-токеном.) - Принимаем токены, начиная с начала, до первого несовпадения. Как минимум 1 токен принимается (следующий после input_ids).
- Добавляем принятые токены к
input_ids, повторяем до достижения max_new_tokens.
- Внутри: на каждом шаге запускаем draft.generate с max_new_tokens=self.K, получаем
-
Обработка коротких ответов поскольку max_new_tokens ≤ 10, часто одна волна покрывает весь ответ. Настроить
K = 8. -
Тестирование на одном промпте
wd = WaveDecoder(draft_model, target_model, tokenizer, K=8) inputs = tokenizer(prompts[0], return_tensors="pt") out = wd.generate(inputs.input_ids, max_new_tokens=8) print(tokenizer.decode(out[0]))
Ожидаемый результат этапа Рабочий wave decoding, генерирующий осмысленные ответы, сопоставимые с baseline.
Этап 4: Бенчмаркинг и оптимизация (1 час)
Действия
-
Измерить latency wave decoding на всех промптах (аналогично этапу 2):
- Запустить
wd.generateдля каждого промпта, замерить время. - Вычислить среднее
wave_avg.
- Запустить
-
Рассчитать speedup
speedup = baseline_avg / wave_avg -
Оценить качество
- Собрать сгенерированные ответы wave decoding.
- Вычислить perplexity (или BERTScore относительно baseline).
- Убедиться, что деградация ≤ 5% (допустимо небольшое ухудшение при большом ускорении).
-
Оптимизация (если speedup < 2):
- Использовать
torch.compileдля обеих моделей:draft_model = torch.compile(draft_model, mode="reduce-overhead") target_model = torch.compile(target_model, mode="reduce-overhead") - Увеличить batch size (если несколько промптов обрабатывать одновременно) — в данной задаче не требуется, т.к. ответы короткие.
- Подобрать
K(5, 8, 10) — возможно, меньшее K даёт большее ускорение за счёт меньшего числа rejected токенов. - Включить
torch.inference_mode()на прогонах.
- Использовать
-
Повторить замеры после оптимизации, записать финальный speedup.
Ожидаемый результат этапа Таблица с метриками (baseline latency, wave latency, speedup, качество) и финальный speedup ≥2.
Этап 5: Оформление результатов (30 минут)
Действия
- Подготовить отчёт в виде Jupyter Notebook / Python-скрипта с комментариями.
- Включить:
- Написать раздел "Проблемы и решения" (см. п.9).
Ожидаемый результат этапа Оформленный артефакт — wave_decoding_report.ipynb (или .py) с измеримыми результатами.
5. Критерии приемки (Definition of Done)
- Реализован класс
WaveDecoderс методомgenerate. - Baseline latency замерен и зафиксирован.
- Wave decoding latency замерен на том же наборе промптов.
- Speedup ≥ 2.0 (среднее по всем промптам).
- Качество ответов не ухудшилось более чем на 5% по метрике perplexity (или BERTScore >0.95 от baseline).
- Отчёт содержит все замеры, график speedup и описание метода.
- Код запускается воспроизводимо (зафиксированы версии пакетов, seed).
6. Ожидаемый результат
- Файл
wave_decoding_report.ipynb(или.py) со следующим содержанием:- Загрузка моделей и промптов.
- Реализация
WaveDecoder. - Функции бенчмаркинга.
- Таблицы и графики speedup.
- Вывод о достижении цели.
- Дополнительно (опционально):
- Сравнение при разных
K(5, 8, 10). - Использование
torch.compileи его влияние. - Лог rejected токенов (статистика принятия волны).
- Сравнение при разных
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| Wave decoding генерирует мусор (большое количество rejected токенов) | Уменьшить K до 5, или использовать более качественную draft-модель (например, TinyLlama вместо distilgpt2). |
| Speedup меньше 2x из-за накладных расходов Python | Использовать torch.compile, объединять forward'ы в батч, замерить чистое время GPU (игнорировать инициализацию). |
| Разные длины ответов для каждого промпта (gen_len фиксирован) | В ТЗ задан диапазон 5–10; можно взять gen_len=8 как компромисс, либо варьировать и показать среднее. |
| Draft-модель слишком медленная на CPU | Перейти на GPU (даже T4 в Colab). Если нет, можно использовать distilgpt2 и принять меньший speedup (но методика верна). |
| Оценка качества perplexity может быть некорректной из-за разной длины | Нормализовать по количеству токенов или использовать фиксированную длину ответа (pad до gen_len). |
8. Бюджет времени (оценка)
| Этап | Время |
|---|---|
| 1. Подготовка окружения и данных | 30 мин |
| 2. Baseline — обычная генерация | 20 мин |
| 3. Реализация wave decoding | 2 ч |
| 4. Бенчмаркинг и оптимизация | 1 ч |
| 5. Оформление результатов | 30 мин |
| Итого (чистое время) | 4 ч 20 мин |
Примечание Первый раз рекомендуется закладывать +50% (≈6.5 ч) на отладку неожиданных расхождений и изучение механизма rejection sampling.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 15 | Speculative decoding: принцип и реализация |
| 42 | KV cache: как ускорить авторегрессивную генерацию |
| 78 | Blockwise parallel decoding: обзор методов |
| 124 | Метрики скорости генерации: latency, throughput |
| 205 | Rejection sampling в speculative decoding |
| 311 | torch.compile для моделей transformers |
| 409 | Сравнение draft/target моделей (размер, скорость) |
| 522 | Perplexity как метрика качества генерации |
| 678 | Оптимизация batch processing в инференсе |
| 789 | Настройка длины волны в wave decoding |
10. Чек-лист самопроверки
- Я загрузил обе модели (draft и target) и проверил, что они работают по отдельности.
- Я реализовал корректный rejection sampling (сравнение argmax target с draft-токенами).
- Я измерил latency baseline и wave decoding на одних и тех же промптах, усреднил по 5 запускам.
- Я вычислил speedup и убедился, что он ≥2.
- Я проверил качество ответов (визуально или по метрике) — не стало хуже.
- Я зафиксировал seed и версии пакетов для воспроизводимости.