Реализовать speculative decoding с draft моделью
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Реализовать speculative decoding с draft моделью
1. Цель задачи
Освоить технику speculative decoding — ускорение инференса большой языковой модели (target) с помощью маленькой быстрой модели (draft). Реализовать pipeline, в котором draft-модель (1B параметров) генерирует последовательность кандидатов, а target-модель (8B) проверяет их параллельно. Добиться практического ускорения в 1.5–2x по сравнению с последовательной генерацией только target-моделью.
Ключевой результат Рабочий скрипт или ноутбук, демонстрирующий speculative decoding на заданном промпте с замером скорости (токенов/сек) и измеренным speedup в диапазоне 1.5–2x.
2. Исходные данные
Перед началом необходимо иметь:
| Что нужно | Откуда взять |
|---|---|
| Draft-модель ~1B параметров | Hugging Face: TinyLlama/TinyLlama-1.1B-Chat-v1.0 или Qwen/Qwen2.5-1.5B-Instruct |
| Target-модель ~8B параметров | Hugging Face: meta-llama/Llama-3.1-8B-Instruct или mistralai/Mistral-7B-Instruct-v0.3 |
| GPU с >=16GB VRAM (для инференса 8B в FP16) | Локальный сервер, Colab Pro / Kaggle GPU (T4/A100) |
| Python 3.10+ | Установлен в окружении |
Если нет реального GPU (только CPU) — симулируем:
- Используем stub для target-модели: симулируем задержку (~0.1 с на токен) и возвращаем детерминированный ответ.
- Draft-модель — любая маленькая (например,
distilgpt2), работает на CPU. - Измеряем логическое ускорение на синтетическом трафике без реального вычисления.
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Инференс LLM | PyTorch + Hugging Face Transformers + accelerate | Загрузка и запуск обеих моделей |
| Speculative decoding | Реализация руками (один из алгоритмов: голый speculative / Medusa) | Ускорение генерации |
| Бенчмаркинг | time.perf_counter, tqdm | Замер времени и скорости |
| Визуализация | matplotlib, pandas | Графики speedup, сравнение |
| Окружение | Jupyter Notebook / Python script | Разработка и демонстрация |
4. Этапы выполнения
Этап 1: Подготовка окружения и загрузка моделей (30-45 мин)
Действия
-
Создать окружение
conda create -n spec_dec python=3.10 conda activate spec_dec pip install torch transformers accelerate sentencepiece tqdm matplotlib pandas -
Загрузить draft-модель (1B)
ИспользоватьAutoModelForCausalLM.from_pretrainedсtorch_dtype=torch.float16(если GPU).
Пример:from transformers import AutoModelForCausalLM, AutoTokenizer draft_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" draft_model = AutoModelForCausalLM.from_pretrained(draft_name, torch_dtype=torch.float16, device_map="auto") draft_tokenizer = AutoTokenizer.from_pretrained(draft_name) -
Загрузить target-модель (8B)
Аналогично, но для экономии VRAM использоватьbitsandbytes4-bit квантизацию:target_name = "meta-llama/Llama-3.1-8B-Instruct" target_model = AutoModelForCausalLM.from_pretrained( target_name, torch_dtype=torch.float16, load_in_4bit=True, # 4-bit экономит память device_map="auto" ) target_tokenizer = AutoTokenizer.from_pretrained(target_name) -
Проверить работоспособность каждой модели по отдельности
Сгенерировать один промпт (например, "What is the capital of France?") и вывести ответ.
Ожидаемый результат этапа
Обе модели загружены, успешно генерируют текст при прямом вызове model.generate().
Этап 2: Реализация базового speculative decoding (2-3 часа)
Действия
-
Написать функцию принятия / отклонения (rejection sampling)
Алгоритм (стандартный speculative decoding):- Draft-модель генерирует
Kтокенов за один шаг (каждый токен выбирается из её распределения ( q )). - Target-модель вычисляет логиты для последнего токена и получает распределение ( p ).
- Для каждого токена ( t_i ) принимаем его с вероятностью ( \min(1, p(t_i)/q(t_i)) ).
- При первом отклонении генерируем токен из скорректированного распределения ( \max(0, p - q) ) и останавливаем цикл.
def speculative_generate(draft_model, target_model, input_ids, max_new_tokens=50, K=5, temperature=1.0): # input_ids: (1, seq_len) accepted = 0 for _ in range(max_new_tokens // K + 1): # 1. Draft: generate K candidates draft_output = draft_model.generate(input_ids, max_new_tokens=K, do_sample=True, temperature=temperature, pad_token_id=draft_tokenizer.eos_token_id) candidates = draft_output[0][input_ids.shape[1]:] # (K,) # 2. Get target logits for last position with torch.no_grad(): target_logits = target_model(input_ids).logits[0, -1, :] # vocab # 3. Rejection sampling loop for i in range(K): q = draft_model output distribution for token i p = target_logits.softmax(-1) ratio = p[candidates[i]] / q[candidates[i]] if torch.rand(1).item() < min(1, ratio): accepted += 1 input_ids = torch.cat([input_ids, candidates[i].unsqueeze(0).unsqueeze(0)], dim=-1) else: # resample from (p - q)+ new_probs = torch.clamp(p - q, min=0) new_probs /= new_probs.sum() new_token = torch.multinomial(new_probs, 1) input_ids = torch.cat([input_ids, new_token], dim=-1) break if len(candidates) == accepted: # all K accepted, continue pass return input_idsПримечание: на практике нужно получать q для каждого шага; проще сохранять логиты draft-модели во время её генерации.
- Draft-модель генерирует
-
Оптимизация: за раз передавать target-модели весь блок принятых токенов
Вместо пошагового forward'a, можно передать всю последовательностьinput_ids + accepted_tokensи получить логиты сразу для всех позиций. Это повышает эффективность. -
Настроить параметры: ( K ) (число кандидатов)
Обычно ( K = 5..10 ). Экспериментально подобрать под задачу.
Ожидаемый результат этапа
Функция speculative_generate корректно выводит текст (визуально похожий на target-модельный) и не падает с ошибками.
Этап 3: Интеграция и тестирование на простом промпте (30-45 мин)
Действия
-
Написать бенчмарк-функцию
Сравнить три режима:- Только target-модель (baseline).
- Только draft-модель (для справки).
- Speculative decoding (draft + target).
Для каждого режима замерить:
-
Запустить на 3-5 промптах длиной 20-50 токенов
Выбрать разнообразные темы (факты, рассуждения, код). -
Вывести таблицу результатов
Пример:
Ожидаемый результат этапа
Чёткое численное подтверждение speedup >1.5x хотя бы на одном промпте.
Этап 4: Оптимизация и тонкая настройка (1-2 часа)
Действия
-
Варьировать ( K )
Попробовать ( K = 3, 5, 7, 10, 15 ). Построить график speedup vs K. Определить оптимальное. -
Включить greedy decoding vs sampling
Проверить, как работает speculative приtemperature=0(greedy). Желательно использовать одинаковый seed для сравнения. -
Оптимизация batch-обработки
Если позволяет память, запустить несколько промптов в батче. -
Проверить влияние промпта на акцепт-рейт (acceptance rate)
Вычислить средний % принятых токенов. Если <60%, возможно, draft-модель плохо аппроксимирует target.
Ожидаемый результат этапа
Оптимальное ( K ), acceptance rate 60-80%, стабильный speedup 1.5-2x.
Этап 5: Документирование и демонстрация (30-45 мин)
Действия
- Оформить ноутбук или скрипт с чёткими ячейками: загрузка, speculative-функция, бенчмарк, графики.
- Добавить комментарии к ключевым шагам алгоритма.
- Подготовить краткое описание (в README или Markdown-ячейке) с таблицей результатов и выводом.
Ожидаемый результат этапа
Готовый артефакт (.ipynb или .py), который можно запустить и воспроизвести ускорение.
5. Критерии приемки (Definition of Done)
- Реализована функция speculative decoding, принимающая две модели и промпт.
- Код корректно работает на GPU (или CPU-симуляции).
- Проведено не менее 3 замеров для каждого режима (baseline, draft, spec).
- Speedup >= 1.5x при ( K = 5 ) на хотя бы одном промпте.
- График зависимости speedup от ( K ) построен и интерпретирован.
- Acceptance rate указан в результатах.
- Ноутбук/скрипт воспроизводим (указаны версии библиотек, seed).
- Написаны комментарии к алгоритму rejection sampling.
6. Ожидаемый результат
Основной артефакт
speculative_decoding.ipynb (Jupyter Notebook) или speculative_decoding.py (Python-скрипт), содержащий:
- Загрузку draft (1B) и target (8B) моделей.
- Функцию
speculative_generate. - Бенчмарк-функцию с измерением времени.
- Таблицу и график speedup.
- Выводы о параметре ( K ) и acceptance rate.
Дополнительно (опционально):
requirements.txtс фиксированными версиями.README.mdс краткой инструкцией.
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| Out‑of‑memory (OOM) при загрузке 8B модели на GPU с <20GB VRAM | Использовать 4-bit квантизацию (load_in_4bit=True) или offload части слоёв на CPU (device_map="sequential"). |
| Низкий acceptance rate (<40%) | Увеличить ( K ), снизить temperature, или заменить draft-модель на более сильную (например, 2.7B). |
| Долгая генерация draft-модели (если она не быстрая) | Взять действительно маленькую модель (0.5-1B). Можно квантовать draft. |
| Несовпадение токенизаторов | Иметь общий токенизатор (можно использовать target-токенизатор для draft, если словарь пересекается). |
| Численная нестабильность при делении на малые вероятности q | Добавить small epsilon (1e-10) при вычислении ratio. |
8. Бюджет времени (оценка)
| Этап | Время |
|---|---|
| Этап 1: Подготовка и загрузка | 30-45 мин |
| Этап 2: Реализация speculative decoding | 2-3 часа |
| Этап 3: Тестирование на промпте | 30-45 мин |
| Этап 4: Оптимизация | 1-2 часа |
| Этап 5: Документирование | 30-45 мин |
| Итого | 5-7 часов |
Примечание: для первого раза может потребоваться больше времени на отладку rejection sampling. Рекомендуется начинать с симуляции на CPU.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 45 | Каковы основные принципы speculative decoding? |
| 67 | Как работает rejection sampling в контексте LLM? |
| 102 | Какие метрики используются для оценки ускорения инференса? |
| 215 | Чем отличается draft-модель от target-модели в speculative decoding? |
| 310 | Как квантизация влияет на скорость и качество генерации? |
| 412 | Что такое acceptance rate и как его повысить? |
| 520 | Какие существуют альтернативы speculative decoding (Medusa, Lookahead)? |
| 630 | Как измерить время генерации без влияния библиотечных накладных расходов? |
| 745 | Какие проблемы возникают при использовании разных токенизаторов? |
| 888 | Как выбрать оптимальное K для speculative decoding? |
10. Чек-лист самопроверки
- Я загрузил обе модели и проверил их отдельную работоспособность.
- Моя функция speculative decoding корректно обрабатывает случай отсутствия принятых токенов.
- Я сравнил вывод speculative decoding с выводом только target-модели — тексты не расходятся катастрофически.
- Я зафиксировал seed для воспроизводимости результатов.
- Я построил график speedup от K и выбрал оптимальное значение на основе acceptance rate.