English translation is not available yet. Showing Russian content.

Настроить wave decoding для коротких ответов

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить wave decoding для коротких ответов

1. Цель задачи

Реализовать и протестировать метод wave decoding (вариант speculative/blockwise параллельной генерации) для получения коротких ответов длиной 5–10 токенов. Добиться ускорения генерации не менее чем в 2 раза по сравнению с обычной авторегрессивной генерацией, сохранив качество ответов (perplexity или семантическая близость не хуже baseline).

Ключевой результат Рабочий пайплайн wave decoding, измеренное ускорение ≥2× на коротких запросах, подтверждённое бенчмарком.


2. Исходные данные

Что нужноОткуда взять
Предобученные модели (draft + target)Hugging Face: distilgpt2 (draft), gpt2 (target)
Датасет коротких запросов (10–20 примеров)Составить вручную или взять из datasets (например, daily_dialog — первые 20 реплик)
Baseline latencyЗамерить на этапе 2
Окружение с GPU (min 4GB VRAM)Локально / Colab / сервер
Python 3.10+, PyTorch 2.0+Среда разработки

Если нет реальной GPU — симулируем:

  1. Использовать CPU и torch.compile(..., backend="eager") для сравнительных замеров — ускорение может быть меньше, но относительная динамика сохранится.
  2. Для draft-модели взять distilgpt2, для target — gpt2 (обе работают на CPU, но медленно). Основной target — измерение методики, а не абсолютная скорость.

3. Технологический стек

КомпонентИнструментыНазначение
МоделиHugging Face transformers, AutoModelForCausalLMЗагрузка draft/target
ФреймворкPyTorch 2.0+Тензорные вычисления, torch.compile
КэшKV cache (ручное управление)Ускорение верификации
Бенчмаркингtime, torch.cuda.EventЗамер латентности
Качествоperplexity (из transformers) или BERTScoreОценка деградации
Оптимизацияtorch.compile, batch processingДополнительное ускорение
ВиртуализацияConda / venv + requirements.txtВоспроизводимость

4. Этапы выполнения

Этап 1: Подготовка окружения и данных (30 минут)

Действия

  1. Установить зависимости:
    pip install transformers torch datasets tqdm scipy
    
  2. Загрузить модели:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    draft_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
    target_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
  3. Подготовить датасет из 20 коротких промптов (например, вопросы или начала предложений):
    prompts = [
        "The capital of France is",
        "In brief, the answer is",
        # ... ещё 18
    ]
    
  4. Определить длину генерации gen_len = 8 (целевая длина ответа 5–10 токенов).

Ожидаемый результат этапа Загруженные модели, готовый список промптов, фиксированный gen_len.


Этап 2: Baseline — обычная генерация (20 минут)

Действия

  1. Написать функцию baseline_generate(prompt): использовать target_model.generate с max_new_tokens=gen_len, без дополнительных трюков.
  2. Измерить latency на каждом промпте (медиана по 5 запускам):
    import time
    latencies = []
    for p in prompts:
        inputs = tokenizer(p, return_tensors="pt")
        t0 = time.perf_counter()
        out = target_model.generate(**inputs, max_new_tokens=gen_len, do_sample=False)
        t1 = time.perf_counter()
        latencies.append(t1 - t0)
    baseline_avg = sum(latencies) / len(latencies)
    
  3. Сохранить сгенерированные ответы для дальнейшего сравнения качества.
  4. Вычислить baseline perplexity на ответах (например, средняя loss модели на сгенерированных токенах).

Ожидаемый результат этапа Численное значение baseline_avg (в секундах) и baseline качество.


Этап 3: Реализация wave decoding (speculative decoding) (2 часа)

Действия

  1. Определить протокол волны (wave):

    • Драфт-модель генерирует K = gen_len токенов за один авторегрессивный проход (с KV cache).
    • Target-модель верифицирует все K токенов одним forward pass (с маскированием будущих позиций). Если все токены приняты, берётся вся волна; иначе — отбрасываются токены до первого несовпадения, и процесс повторяется.
  2. Реализовать класс WaveDecoder

    class WaveDecoder:
        def __init__(self, draft, target, tokenizer, K=8):
            self.draft = draft
            self.target = target
            self.tokenizer = tokenizer
            self.K = K
    
        def generate(self, input_ids, max_new_tokens):
            # input_ids: (1, seq_len)
            # Генерация волнами
            ...
    
    • Внутри: на каждом шаге запускаем draft.generate с max_new_tokens=self.K, получаем draft_ids.
    • Подготавливаем вход для target: конкатенация input_ids + draft_ids (полная последовательность).
    • Forward pass target с use_cache=True, получаем логиты для всех позиций.
    • Применяем rejection sampling: сравниваем токены target с draft-токенами по вероятностям (используем top-1).
      (Для простоты можно взять argmax target на последней позиции перед каждым кандидатом и сравнить с соответствующим draft-токеном.)
    • Принимаем токены, начиная с начала, до первого несовпадения. Как минимум 1 токен принимается (следующий после input_ids).
    • Добавляем принятые токены к input_ids, повторяем до достижения max_new_tokens.
  3. Обработка коротких ответов поскольку max_new_tokens ≤ 10, часто одна волна покрывает весь ответ. Настроить K = 8.

  4. Тестирование на одном промпте

    wd = WaveDecoder(draft_model, target_model, tokenizer, K=8)
    inputs = tokenizer(prompts[0], return_tensors="pt")
    out = wd.generate(inputs.input_ids, max_new_tokens=8)
    print(tokenizer.decode(out[0]))
    

Ожидаемый результат этапа Рабочий wave decoding, генерирующий осмысленные ответы, сопоставимые с baseline.


Этап 4: Бенчмаркинг и оптимизация (1 час)

Действия

  1. Измерить latency wave decoding на всех промптах (аналогично этапу 2):

    • Запустить wd.generate для каждого промпта, замерить время.
    • Вычислить среднее wave_avg.
  2. Рассчитать speedup

    speedup = baseline_avg / wave_avg
    
  3. Оценить качество

    • Собрать сгенерированные ответы wave decoding.
    • Вычислить perplexity (или BERTScore относительно baseline).
    • Убедиться, что деградация ≤ 5% (допустимо небольшое ухудшение при большом ускорении).
  4. Оптимизация (если speedup < 2):

    • Использовать torch.compile для обеих моделей:
      draft_model = torch.compile(draft_model, mode="reduce-overhead")
      target_model = torch.compile(target_model, mode="reduce-overhead")
      
    • Увеличить batch size (если несколько промптов обрабатывать одновременно) — в данной задаче не требуется, т.к. ответы короткие.
    • Подобрать K (5, 8, 10) — возможно, меньшее K даёт большее ускорение за счёт меньшего числа rejected токенов.
    • Включить torch.inference_mode() на прогонах.
  5. Повторить замеры после оптимизации, записать финальный speedup.

Ожидаемый результат этапа Таблица с метриками (baseline latency, wave latency, speedup, качество) и финальный speedup ≥2.


Этап 5: Оформление результатов (30 минут)

Действия

  1. Подготовить отчёт в виде Jupyter Notebook / Python-скрипта с комментариями.
  2. Включить:
    • График сравнения latency по каждому промпту (столбчатая диаграмма).
    • Краткое описание реализации wave decoding (алгоритм).
    • Вывод: достигнут ли speedup 2x, и при каких условиях.
  3. Написать раздел "Проблемы и решения" (см. п.9).

Ожидаемый результат этапа Оформленный артефакт — wave_decoding_report.ipynb (или .py) с измеримыми результатами.


5. Критерии приемки (Definition of Done)

  • Реализован класс WaveDecoder с методом generate.
  • Baseline latency замерен и зафиксирован.
  • Wave decoding latency замерен на том же наборе промптов.
  • Speedup ≥ 2.0 (среднее по всем промптам).
  • Качество ответов не ухудшилось более чем на 5% по метрике perplexity (или BERTScore >0.95 от baseline).
  • Отчёт содержит все замеры, график speedup и описание метода.
  • Код запускается воспроизводимо (зафиксированы версии пакетов, seed).

6. Ожидаемый результат

  • Файл wave_decoding_report.ipynb (или .py) со следующим содержанием:
    • Загрузка моделей и промптов.
    • Реализация WaveDecoder.
    • Функции бенчмаркинга.
    • Таблицы и графики speedup.
    • Вывод о достижении цели.
  • Дополнительно (опционально):
    • Сравнение при разных K (5, 8, 10).
    • Использование torch.compile и его влияние.
    • Лог rejected токенов (статистика принятия волны).

7. Возможные сложности и их решение

СложностьРешение
Wave decoding генерирует мусор (большое количество rejected токенов)Уменьшить K до 5, или использовать более качественную draft-модель (например, TinyLlama вместо distilgpt2).
Speedup меньше 2x из-за накладных расходов PythonИспользовать torch.compile, объединять forward'ы в батч, замерить чистое время GPU (игнорировать инициализацию).
Разные длины ответов для каждого промпта (gen_len фиксирован)В ТЗ задан диапазон 5–10; можно взять gen_len=8 как компромисс, либо варьировать и показать среднее.
Draft-модель слишком медленная на CPUПерейти на GPU (даже T4 в Colab). Если нет, можно использовать distilgpt2 и принять меньший speedup (но методика верна).
Оценка качества perplexity может быть некорректной из-за разной длиныНормализовать по количеству токенов или использовать фиксированную длину ответа (pad до gen_len).

8. Бюджет времени (оценка)

ЭтапВремя
1. Подготовка окружения и данных30 мин
2. Baseline — обычная генерация20 мин
3. Реализация wave decoding2 ч
4. Бенчмаркинг и оптимизация1 ч
5. Оформление результатов30 мин
Итого (чистое время)4 ч 20 мин

Примечание Первый раз рекомендуется закладывать +50% (≈6.5 ч) на отладку неожиданных расхождений и изучение механизма rejection sampling.


9. Связанные вопросы из базы знаний

ВопросТема
15Speculative decoding: принцип и реализация
42KV cache: как ускорить авторегрессивную генерацию
78Blockwise parallel decoding: обзор методов
124Метрики скорости генерации: latency, throughput
205Rejection sampling в speculative decoding
311torch.compile для моделей transformers
409Сравнение draft/target моделей (размер, скорость)
522Perplexity как метрика качества генерации
678Оптимизация batch processing в инференсе
789Настройка длины волны в wave decoding

10. Чек-лист самопроверки

  • Я загрузил обе модели (draft и target) и проверил, что они работают по отдельности.
  • Я реализовал корректный rejection sampling (сравнение argmax target с draft-токенами).
  • Я измерил latency baseline и wave decoding на одних и тех же промптах, усреднил по 5 запускам.
  • Я вычислил speedup и убедился, что он ≥2.
  • Я проверил качество ответов (визуально или по метрике) — не стало хуже.
  • Я зафиксировал seed и версии пакетов для воспроизводимости.