Как работает speculative decoding на уровне логитов, а не токенов?
Краткий тезис
Speculative decoding — это метод ускорения инференса больших языковых моделей, при котором маленькая draft-модель генерирует несколько токенов, а большая target-модель проверяет их параллельно. В варианте «на уровне логитов» draft-модель предсказывает не конкретные токены, а полное распределение вероятностей (логиты) для каждого шага, а target-модель использует эти распределения для rejection sampling, что позволяет точнее корректировать ошибки draft-модели и сохранять качество генерации без дополнительных проходов.
1. Термины: speculative decoding, draft-модель, target-модель, логиты, rejection sampling
- Speculative decoding — техника ускорения генерации текста, при которой быстрая, но менее точная модель (draft) предлагает последовательность токенов, а медленная, но более точная модель (target) принимает или отвергает их, используя параллельные вычисления.
- Draft-модель — маленькая (например, 7B параметров) или distillation|дистиллированная модель, которая генерирует несколько токенов за один прямой проход.
- Target-модель — model|большая модель (например, 70B), которая оценивает предложенные токены и корректирует распределение.
- Логиты (logits) — не нормализованные выходные значения последнего слоя модели перед softmax. Они представляют собой «сырые» оценки для каждого токена словаря.
- Rejection sampling — метод коррекции выборки: если вероятность токена по target-модели ниже, чем по draft-модели, токен может быть отклонён, и генерация перезапускается с исправленным распределением.
2. Проблема: почему генерация больших моделей медленная
Генерация авторегрессионных моделей (например, GPT, LLaMA) происходит последовательно: каждый новый токен зависит от всех предыдущих. Это приводит к линейной задержке относительно длины выхода. Даже с KV-cache (кэшированием ключей и значений) каждый шаг требует одного прямого прохода через все слои модели, что для моделей с десятками миллиардов параметров занимает миллисекунды. При длине ответа в сотни токенов общее время становится неприемлемым для реального времени.
Speculative decoding решает эту проблему, позволяя генерировать несколько токенов за один проход target-модели, но с гарантией, что итоговое распределение не изменится (точность сохраняется).
3. Основная идея speculative decoding (токен-уровень)
Классический speculative decoding (Leviathan et al., 2023; Chen et al., 2023) работает так:
- Draft-модель генерирует K токенов авторегрессионно (быстро, так как она маленькая).
- Target-модель получает всю последовательность из K токенов и вычисляет логиты для каждой позиции параллельно (благодаря маске внимания).
- Для каждой позиции сравниваются вероятности draft и target. Если вероятность target больше или равна вероятности draft, токен принимается. Иначе — с вероятностью
1 - p_target / p_draftтокен отклоняется, и генерация перезапускается с этой позиции.
Этот подход гарантирует, что итоговое распределение совпадает с распределением target-модели, но ускорение достигается за счёт того, что target-модель делает только один прямой проход на K токенов вместо K последовательных.
4. Переход к уровню логитов: что меняется
В варианте «на уровне логитов» draft-модель не генерирует конкретные токены, а выдаёт полное распределение логитов для каждого шага. Target-модель затем использует эти логиты для ресемплинга (resampling) — выборки токенов из скорректированного распределения.
Основное отличие:
- Токен-уровень: draft выдаёт последовательность токенов, target проверяет каждый.
- Логит-уровень: draft выдаёт матрицу логитов (размер
K × V, где V — размер словаря), target вычисляет свои логиты и комбинирует их с draft-логитами через взвешенное суммирование или rejection sampling на уровне распределений.
Этот подход более гибкий: он позволяет target-модели не просто принимать/отвергать токены, а корректировать всё распределение, что особенно полезно, когда draft-модель систематически ошибается в определённых контекстах.
5. Алгоритм speculative decoding на уровне логитов (пошагово)
Пусть у нас есть draft-модель D и target-модель T. Мы хотим сгенерировать следующий токен, но можем предсказать K шагов вперёд.
Шаг 1: Draft-модель генерирует логиты для K шагов
Draft-модель получает текущий контекст x_1, ..., x_t и генерирует последовательность логитов L_draft = [l_1, l_2, ..., l_K], где каждый l_i — вектор размера V. Из этих логитов можно получить распределение p_draft через softmax.
Шаг 2: Target-модель вычисляет логиты для той же последовательности
Target-модель получает конкатенацию исходного контекста и предложенных draft-токенов (которые были сэмплированы из p_draft на предыдущем шаге, или же draft-модель сразу выдаёт токены). Однако в логит-варианте часто draft-модель не выдаёт токены, а только логиты, поэтому target-модель может использовать предварительно сэмплированные токены из draft-распределения или работать напрямую с логитами через специальный механизм внимания.
На практике чаще используется гибрид: draft-модель генерирует токены, а target-модель вычисляет логиты для всех позиций параллельно. Но в «чистом» логит-варианте target-модель получает логиты draft как дополнительный вход и вычисляет свои логиты, после чего происходит слияние распределений.
Шаг 3: Rejection sampling на уровне распределений
Для каждой позиции i от 1 до K:
- Вычисляется
p_target(x_i | context)иp_draft(x_i | context). - Если
p_target(x_i) >= p_draft(x_i), токен принимается. - Иначе с вероятностью
1 - p_target(x_i)/p_draft(x_i)токен отклоняется.
При отклонении все последующие токены отбрасываются, и генерация возвращается к позиции i. Новый токен сэмплируется из скорректированного распределения:
p_corrected(x) = max(0, p_target(x) - p_draft(x)) / Z
где Z — нормировочная константа. Это гарантирует, что итоговое распределение совпадает с p_target.
Шаг 4: Повторение
Если все K токенов приняты, процесс повторяется с новой позиции.
6. Математическая основа: acceptance rate и корректировка распределения
Ключевой параметр — acceptance rate (доля принятых токенов). Он зависит от того, насколько распределения draft и target похожи. Чем ближе draft к target, тем выше acceptance rate и больше ускорение.
Формально, для каждой позиции вероятность принятия токена x равна:
accept(x) = min(1, p_target(x) / p_draft(x))
Если p_draft(x) > p_target(x), токен может быть отклонён. Вероятность отклонения: 1 - p_target(x)/p_draft(x).
После отклонения новый токен сэмплируется из распределения:
p_resample(x) ∝ max(0, p_target(x) - p_draft(x))
Это эквивалентно rejection sampling с proposal distribution p_draft и target distribution p_target. Такой подход гарантирует, что итоговая выборка неотличима от выборки из p_target.
Преимущество логит-уровня: draft-модель может предоставить target-модели не только токены, но и полные логиты, что позволяет target-модели более точно оценить, какие части распределения стоит корректировать. Например, если draft-модель уверена в неверном токене, target-модель может сильно снизить его вероятность при ресемплинге.
7. Сравнение: токен-уровень vs логит-уровень
| Характеристика | Токен-уровень | Логит-уровень |
|---|---|---|
| Что передаёт draft | Последовательность токенов | Матрица логитов (K × V) |
| Как target проверяет | Сравнивает вероятности для каждого токена | Сравнивает полные распределения, корректирует через ресемплинг |
| Вычислительная нагрузка | Меньше (только forward pass target) | Больше (нужно хранить и обрабатывать логиты draft) |
| Точность коррекции | Высокая, но зависит от качества draft | Потенциально выше, т.к. target видит всё распределение |
| Сложность реализации | Проще | Сложнее (требует модификации архитектуры) |
| Применимость | Любая пара моделей | Требует, чтобы draft-модель могла выдавать логиты (обычно да) |
На практике большинство реализаций speculative decoding используют токен-уровень, так как он проще и даёт достаточное ускорение. Логит-уровень применяется в исследовательских работах или когда draft-модель сильно отличается от target.
8. Преимущества и ограничения
Преимущества
- Ускорение без потери качества — итоговое распределение совпадает с target-моделью.
- Параллелизация — target-модель обрабатывает K токенов за один проход.
- Гибкость — можно использовать любую draft-модель (даже не обученную специально).
Ограничения
- Зависимость от acceptance rate — если draft-модель плоха, ускорение минимально.
- Дополнительные вычисления — draft-модель всё равно требует ресурсов.
- Сложность реализации — особенно для логит-уровня (нужно синхронизировать распределения).
- Не подходит для моделей с жёсткими ограничениями по памяти — хранение логитов draft может быть дорогим.
9. Реализация: псевдокод на Python
import torch
import torch.nn.functional as F
def speculative_decoding_logit_level(draft_model, target_model, context, K, vocab_size):
"""
Draft-модель генерирует логиты для K шагов, target проверяет.
Возвращает сгенерированные токены.
"""
# Шаг 1: Draft генерирует логиты для K шагов (авторегрессионно)
draft_logits = [] # список из K тензоров [1, vocab_size]
draft_tokens = []
current_context = context.clone()
for _ in range(K):
with torch.no_grad():
logits = draft_model(current_context)[0, -1, :] # последний токен
draft_logits.append(logits)
probs = F.softmax(logits, dim=-1)
token = torch.multinomial(probs, 1)
draft_tokens.append(token)
current_context = torch.cat([current_context, token], dim=-1)
# Шаг 2: Target вычисляет логиты для всей последовательности
full_context = torch.cat([context, torch.cat(draft_tokens).unsqueeze(0)], dim=-1)
with torch.no_grad():
target_logits_all = target_model(full_context)[0] # [1, len(context)+K, vocab_size]
target_logits = target_logits_all[0, -K:, :] # только последние K позиций
# Шаг 3: Rejection sampling для каждой позиции
accepted_tokens = []
for i in range(K):
p_draft = F.softmax(draft_logits[i], dim=-1)
p_target = F.softmax(target_logits[i], dim=-1)
# Вероятность принятия текущего токена
token = draft_tokens[i]
accept_prob = min(1.0, (p_target[0, token] / p_draft[0, token]).item())
if torch.rand(1).item() < accept_prob:
accepted_tokens.append(token)
else:
# Ресемплинг из скорректированного распределения
correction = torch.clamp(p_target - p_draft, min=0)
correction = correction / correction.sum()
new_token = torch.multinomial(correction, 1)
accepted_tokens.append(new_token)
# Отбрасываем все последующие токены draft
break
return torch.cat(accepted_tokens, dim=-1)
Этот код — упрощённая иллюстрация. На практике используются оптимизации: параллельный forward pass target, кэширование KV, батчинг.
10. Применение в Agentic RAG и других системах
Speculative decoding особенно полезен в Agentic RAG, где агент должен быстро генерировать несколько раундов рассуждений (chain-of-thought) и обращаться к инструментам. Ускорение инференса позволяет агенту:
- Быстрее отвечать пользователю.
- Делать больше итераций поиска и анализа за то же время.
- Использовать большие модели (например, 70B) в реальном времени.
Кроме того, speculative decoding применяется в:
- LLM-сервисах (ChatGPT, Claude) для снижения задержки.
- Мобильных и edge-устройствах, где маленькая draft-модель работает локально, а target — на сервере.
- Дистилляции знаний — draft-модель может быть дистиллированной версией target.
11. Пет-проект для закрепления
Задача: Реализовать speculative decoding на уровне логитов для пары моделей (например, TinyLLaMA как draft и LLaMA-7B как target) и сравнить скорость генерации с обычным авторегрессионным методом.
Инструменты: PyTorch, Hugging Face Transformers, библиотека для бенчмаркинга (timeit).
Шаги:
- Загрузить две модели: draft (например,
TinyLlama/TinyLlama-1.1B-Chat-v1.0) и target (meta-llama/Llama-2-7b-chat-hf). - Реализовать функцию
speculative_generateпо псевдокоду выше. - Реализовать обычную генерацию (autoregressive) для той же target-модели.
- Запустить обе функции на нескольких промптах (длина 50-100 токенов) с K=5, K=10.
- Замерить время генерации и проверить, что распределения выходов совпадают (например, с помощью статистического теста Колмогорова-Смирнова на логитах).
Ожидаемый результат: Ускорение в 2-4 раза при K=5-10, при этом качество (perplexity) не ухудшается.
12. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 288 | Как работает KV-cache и как он ускоряет инференс? |
| 290 | Какие методы ускорения генерации LLM вы знаете? |
| 287 | Что такое дистилляция моделей и как она применяется? |
| 291 | Как работает параллельная генерация (например, Medusa)? |
| 292 | В чём разница между speculative decoding и early exiting? |
Навигация
- Предыдущий: 288
- Следующий: 290
- Индекс: 00. Индекс разборов