Как speculative decoding ускоряет inference? (детально)
Краткий тезис
Speculative decoding (decoding|спекулятивное декодирование) — это метод ускорения авторегрессивной генерации больших языковых моделей без потери качества. Он использует маленькую «draft» модель (1–3B параметров) для быстрой генерации нескольких токенов-кандидатов, а затем большая «target» модель (70B+) проверяет их параллельно за один forward pass. Принимается наиболее длинный префикс, совпадающий с распределением target. Это даёт ускорение в 1.5–3x, но требует дополнительной памяти для KV cache обеих моделей и точности draft.
1. Проблема: почему inference LLM медленный?
Большие языковые модели (LLM) генерируют текст авторегрессивно: каждый новый токен зависит от всех предыдущих. Это означает последовательные вызовы модели, где каждый шаг требует полного forward pass через десятки миллиардов параметров. Для моделей размером 70B и более время генерации одного токена может составлять десятки миллисекунд, а при длине ответа в 1000 токенов общее время становится неприемлемым для интерактивных сценариев.
Ключевые факторы медлительности
- Последовательная природа: no parallelism across tokens.
- Огромное число параметров и вычислений на шаг.
- Ограничения пропускной способности памяти (memory bandwidth) при загрузке весов модели.
Speculative decoding обходит это, позволяя генерировать несколько токенов за один forward pass большой модели.
2. Основная идея speculative decoding
Speculative decoding (спекулятивное декодирование) — метод, при котором маленькая модель (draft) быстро набрасывает несколько предполагаемых токенов, а большая модель (target) верифицирует их параллельно, используя технику rejection sampling (выборка с отклонением). Результат — ускорение генерации при сохранении точности target модели.
Идея в том, что большая модель может выполнить один forward pass на последовательности из K токенов (а не на одном) благодаря parallellable forward pass (параллельный проход по всем позициям). Это не нарушает авторегрессивность, так как draft модель уже сгенерировала токены, а target только корректирует и принимает решение.
3. Детали механизма: шаг за шагом
Рассмотрим один раунд speculative decoding:
- Draft генерация Маленькая draft модель (например, 1.3B параметров) генерирует K токенов авторегрессивно, используя текущий контекст. Это быстро, так как модель мала.
- Целевой forward pass Большая target модель принимает на вход все K токенов одновременно (как последовательность) и вычисляет логиты для каждой позиции за один forward pass. Используется свой KV cache для хранения ключей и значений внимания.
- Проверка (rejection sampling): Для каждой позиции i (от 1 до K) сравнивается вероятность токена, предложенного draft, с вероятностью, которую даёт target. Формальное правило: если ( q_{[text](/wiki/text){target}}(t_i \mid [text](/wiki/text){history}) / q_{[text](/wiki/text){draft}}(t_i \mid [text](/wiki/text){history}) \ge 1 ), токен принимается; иначе — принимается с вероятностью ( q_{[text](/wiki/text){target}}/q_{[text](/wiki/text){draft}} ). При отклонении генерируется новый токен из распределения target.
- Принятие максимального префикса Находим самый длинный префикс из предложенных draft токенов, удовлетворяющий критерию acceptance. Все токены префикса принимаются. Если префикс пуст (первый токен не принят), генерируем первый токен из target.
- Повтор Сдвигаем контекст на принятые токены и повторяем раунд.
Acceptance rate (коэффициент принятия) — средняя доля токенов из K, которые были приняты target моделью. На практике α = 0.5–0.8.
4. Формула ускорения и факторы
Ускорение (speedup) оценивается как:
[ [text](/wiki/text){Speedup} \approx \frac{[alpha](/wiki/alpha) K}{1 + \frac{t_{[text](/wiki/text){draft}}}{t_{[text](/wiki/text){target}}}} ]
где:
- ( t_{[text](/wiki/text){draft}} ) — время одного forward pass draft модели,
- ( t_{[text](/wiki/text){target}} ) — время одного forward pass target модели,
- ( K ) — число предложенных токенов (draft length),
- ( [alpha](/wiki/alpha) ) — acceptance rate.
Если ( t_{[text](/wiki/text){draft}} \ll t_{[text](/wiki/text){target}} ), то Speedup ≈ αK. При α = 0.6 и K = 5 получаем ~3x. На практике из-за накладных расходов (переключение моделей, синхронизация) реальное ускорение 1.5–2x.
Факторы, влияющие на acceptance rate:
- Качество draft модели (чем точнее, тем выше α).
- Сложность задачи (генерация из узких доменов может снижать α).
- Температура выборки (при низкой температуре acceptance выше).
- Размер K (слишком большое K — α падает).
5. Trade-offs и вызовы
| Аспект | Плюс | Минус |
|---|---|---|
| Скорость | Ускорение в 1.5–3x без потери качества | Не всегда достигается, зависит от задачи |
| Память | – | Требуется 2x KV cache (draft + target) – значительный overhead |
| Качество | Идентично target модели (теоретически) | На практике может быть лёгкое смещение из-за реализации sampling |
| Сложность внедрения | Относительно простой алгоритм | Нужна подходящая draft модель; может требовать обучения draft модели под target |
Дополнительные сложности
- Синхронизация KV cache между раундами.
- Разный batch size для draft и target.
- Выбор оптимального K: слишком малое K — мало ускорение; слишком большое K — падает acceptance rate.
6. Сравнение с другими методами ускорения
| Метод | Суть | Потеря качества? | Overhead |
|---|---|---|---|
| Квантование (Quantization) | Уменьшение точности весов (int8, int4) | Небольшая (или нулевая при careful quantization) | Нет дополнительной модели |
| Прунинг (Pruning) | Удаление «неважных» параметров | Умеренная | Нет дополнительной модели |
| Distillation | Обучение маленькой модели имитировать большую | Небольшая, но модель отдельная | Требуется обучение |
| Speculative decoding | Маленькая + большая, параллельная проверка | Нет (теоретически) | Две модели в памяти |
| Flash Attention | Оптимизация внимания | Нет | Только change kernel |
Speculative decoding уникален тем, что не требует дообучения target и сохраняет её точность, но добавляет overhead памяти.
7. Варианты speculative decoding
- Стандартный (Leviathan et al., 2022): Отдельная draft модель, обученная на том же корпусе.
- Medusa (Cai et al., 2024): Добавляет «головы» (heads) для предсказания нескольких токенов сразу, без отдельной модели.
- Eagle (Zhou et al., 2024): Использует динамический draft без отдельной модели, предсказывая следующий токен на основе скрытых состояний.
- Self-speculative decoding Target модель сама себе draft через упрощённые слои.
8. Когда speculative decoding наиболее эффективен?
- Онлайн-сервисы с низкой latency: чат-боты, ассистенты.
- Большие модели (70B+), где время одного forward pass доминирует.
- Сценарии с длинной генерацией (500+ токенов).
- Задачи с высокой предсказуемостью (код, формулы), где draft может быть точным.
Неэффективен: для маленьких моделей (<7B), для очень коротких ответов, при недостатке памяти GPU.
9. Пример кода (псевдокод)
def speculative_decoding(draft_model, target_model, prompt, K=5):
context = tokenize(prompt)
accepted_tokens = []
while not finished:
# 1. Draft генерация K токенов
draft_tokens = draft_model.generate(context, max_new_tokens=K)
# 2. Один forward pass target на все K токенов
target_logits = target_model.forward(torch.cat([context, draft_tokens]))
# 3. Rejection sampling для позиций i=0..K-1
accepted_prefix = []
for i in range(K):
q_draft = draft_model.logits(context + draft_tokens[:i], draft_tokens[i])
q_target = target_logits[..., i, draft_tokens[i]]
r = random()
if r <= min(1, q_target / q_draft):
accepted_prefix.append(draft_tokens[i])
else:
new_token = sample_from(target_logits[:, i, :])
accepted_prefix.append(new_token)
break
accepted_tokens.extend(accepted_prefix)
context = torch.cat([context, accepted_prefix])
return decode(accepted_tokens)
Примечание: В реальных реализациях часто используют более эффективные методы, например, принятие целого префикса без break внутри цикла.
10. Практические рекомендации
- Выбор draft модели Она должна быть быстрой (малое время forward pass) и достаточно точной (высокий acceptance rate). Обычно выбирают модель того же семейства, в 10–50x меньше по размеру.
- Обучение draft модели Можно обучить дистиллированную версию target на её же выходах, чтобы повысить acceptance rate.
- Оптимизация KV cache Используйте PagedAttention (vLLM) для эффективного кэширования.
- Динамический K Настраивайте K в зависимости от acceptance rate на лету.
Пет-проект для закрепления
Задача Реализовать простой speculative decoding с использованием двух предобученных моделей из библиотеки Hugging Face (например, distilgpt2 как draft и gpt2 как target).
Инструменты Python, PyTorch, Hugging Face Transformers.
Шаги:
- Загрузить модели и токенизаторы.
- Написать функцию draft_generate, которая авторегрессивно генерирует K токенов с маленькой моделью.
- Написать функцию target_verify, которая выполняет один forward pass на конкатенированном контексте и возвращает логиты.
- Реализовать rejection sampling по правилу из теории.
- Сравнить время генерации 100 токенов с speculative decoding против обычного генератора target модели.
- Построить график зависимости ускорения от K и acceptance rate.
Ожидаемый результат Вы увидите ускорение примерно в 1.5–2x для небольшой задачи; анализ acceptance rate покажет, как он меняется с длиной.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 212 | Basic speculative decoding (обзор) |
| 213 | KV cache management и оптимизация памяти |
| 214 | Квантование моделей для ускорения inference |
| 215 | Flash Attention и оптимизация внимания |
| 216 | Distillation и Purning для ускорения |
| 217 | Auto-regressive generation и его ограничения |
Навигация
- Предыдущий: 837
- Следующий: 839
- Индекс: 00. Индекс разборов