English translation is not available yet. Showing Russian content.

Как работает speculative decoding на уровне логитов, а не токенов?

Краткий тезис

Speculative decoding — это метод ускорения инференса больших языковых моделей, при котором маленькая draft-модель генерирует несколько токенов, а большая target-модель проверяет их параллельно. В варианте «на уровне логитов» draft-модель предсказывает не конкретные токены, а полное распределение вероятностей (логиты) для каждого шага, а target-модель использует эти распределения для rejection sampling, что позволяет точнее корректировать ошибки draft-модели и сохранять качество генерации без дополнительных проходов.


1. Термины: speculative decoding, draft-модель, target-модель, логиты, rejection sampling

  • Speculative decoding — техника ускорения генерации текста, при которой быстрая, но менее точная модель (draft) предлагает последовательность токенов, а медленная, но более точная модель (target) принимает или отвергает их, используя параллельные вычисления.
  • Draft-модель — маленькая (например, 7B параметров) или distillation|дистиллированная модель, которая генерирует несколько токенов за один прямой проход.
  • Target-модельmodel|большая модель (например, 70B), которая оценивает предложенные токены и корректирует распределение.
  • Логиты (logits) — не нормализованные выходные значения последнего слоя модели перед softmax. Они представляют собой «сырые» оценки для каждого токена словаря.
  • Rejection sampling — метод коррекции выборки: если вероятность токена по target-модели ниже, чем по draft-модели, токен может быть отклонён, и генерация перезапускается с исправленным распределением.

2. Проблема: почему генерация больших моделей медленная

Генерация авторегрессионных моделей (например, GPT, LLaMA) происходит последовательно: каждый новый токен зависит от всех предыдущих. Это приводит к линейной задержке относительно длины выхода. Даже с KV-cache (кэшированием ключей и значений) каждый шаг требует одного прямого прохода через все слои модели, что для моделей с десятками миллиардов параметров занимает миллисекунды. При длине ответа в сотни токенов общее время становится неприемлемым для реального времени.

Speculative decoding решает эту проблему, позволяя генерировать несколько токенов за один проход target-модели, но с гарантией, что итоговое распределение не изменится (точность сохраняется).


3. Основная идея speculative decoding (токен-уровень)

Классический speculative decoding (Leviathan et al., 2023; Chen et al., 2023) работает так:

  1. Draft-модель генерирует K токенов авторегрессионно (быстро, так как она маленькая).
  2. Target-модель получает всю последовательность из K токенов и вычисляет логиты для каждой позиции параллельно (благодаря маске внимания).
  3. Для каждой позиции сравниваются вероятности draft и target. Если вероятность target больше или равна вероятности draft, токен принимается. Иначе — с вероятностью 1 - p_target / p_draft токен отклоняется, и генерация перезапускается с этой позиции.

Этот подход гарантирует, что итоговое распределение совпадает с распределением target-модели, но ускорение достигается за счёт того, что target-модель делает только один прямой проход на K токенов вместо K последовательных.


4. Переход к уровню логитов: что меняется

В варианте «на уровне логитов» draft-модель не генерирует конкретные токены, а выдаёт полное распределение логитов для каждого шага. Target-модель затем использует эти логиты для ресемплинга (resampling) — выборки токенов из скорректированного распределения.

Основное отличие:

  • Токен-уровень: draft выдаёт последовательность токенов, target проверяет каждый.
  • Логит-уровень: draft выдаёт матрицу логитов (размер K × V, где V — размер словаря), target вычисляет свои логиты и комбинирует их с draft-логитами через взвешенное суммирование или rejection sampling на уровне распределений.

Этот подход более гибкий: он позволяет target-модели не просто принимать/отвергать токены, а корректировать всё распределение, что особенно полезно, когда draft-модель систематически ошибается в определённых контекстах.


5. Алгоритм speculative decoding на уровне логитов (пошагово)

Пусть у нас есть draft-модель D и target-модель T. Мы хотим сгенерировать следующий токен, но можем предсказать K шагов вперёд.

Шаг 1: Draft-модель генерирует логиты для K шагов

Draft-модель получает текущий контекст x_1, ..., x_t и генерирует последовательность логитов L_draft = [l_1, l_2, ..., l_K], где каждый l_i — вектор размера V. Из этих логитов можно получить распределение p_draft через softmax.

Шаг 2: Target-модель вычисляет логиты для той же последовательности

Target-модель получает конкатенацию исходного контекста и предложенных draft-токенов (которые были сэмплированы из p_draft на предыдущем шаге, или же draft-модель сразу выдаёт токены). Однако в логит-варианте часто draft-модель не выдаёт токены, а только логиты, поэтому target-модель может использовать предварительно сэмплированные токены из draft-распределения или работать напрямую с логитами через специальный механизм внимания.

На практике чаще используется гибрид: draft-модель генерирует токены, а target-модель вычисляет логиты для всех позиций параллельно. Но в «чистом» логит-варианте target-модель получает логиты draft как дополнительный вход и вычисляет свои логиты, после чего происходит слияние распределений.

Шаг 3: Rejection sampling на уровне распределений

Для каждой позиции i от 1 до K:

  • Вычисляется p_target(x_i | context) и p_draft(x_i | context).
  • Если p_target(x_i) >= p_draft(x_i), токен принимается.
  • Иначе с вероятностью 1 - p_target(x_i)/p_draft(x_i) токен отклоняется.

При отклонении все последующие токены отбрасываются, и генерация возвращается к позиции i. Новый токен сэмплируется из скорректированного распределения:

p_corrected(x) = max(0, p_target(x) - p_draft(x)) / Z

где Z — нормировочная константа. Это гарантирует, что итоговое распределение совпадает с p_target.

Шаг 4: Повторение

Если все K токенов приняты, процесс повторяется с новой позиции.


6. Математическая основа: acceptance rate и корректировка распределения

Ключевой параметр — acceptance rate (доля принятых токенов). Он зависит от того, насколько распределения draft и target похожи. Чем ближе draft к target, тем выше acceptance rate и больше ускорение.

Формально, для каждой позиции вероятность принятия токена x равна:

accept(x) = min(1, p_target(x) / p_draft(x))

Если p_draft(x) > p_target(x), токен может быть отклонён. Вероятность отклонения: 1 - p_target(x)/p_draft(x).

После отклонения новый токен сэмплируется из распределения:

p_resample(x) ∝ max(0, p_target(x) - p_draft(x))

Это эквивалентно rejection sampling с proposal distribution p_draft и target distribution p_target. Такой подход гарантирует, что итоговая выборка неотличима от выборки из p_target.

Преимущество логит-уровня: draft-модель может предоставить target-модели не только токены, но и полные логиты, что позволяет target-модели более точно оценить, какие части распределения стоит корректировать. Например, если draft-модель уверена в неверном токене, target-модель может сильно снизить его вероятность при ресемплинге.


7. Сравнение: токен-уровень vs логит-уровень

ХарактеристикаТокен-уровеньЛогит-уровень
Что передаёт draftПоследовательность токеновМатрица логитов (K × V)
Как target проверяетСравнивает вероятности для каждого токенаСравнивает полные распределения, корректирует через ресемплинг
Вычислительная нагрузкаМеньше (только forward pass target)Больше (нужно хранить и обрабатывать логиты draft)
Точность коррекцииВысокая, но зависит от качества draftПотенциально выше, т.к. target видит всё распределение
Сложность реализацииПрощеСложнее (требует модификации архитектуры)
ПрименимостьЛюбая пара моделейТребует, чтобы draft-модель могла выдавать логиты (обычно да)

На практике большинство реализаций speculative decoding используют токен-уровень, так как он проще и даёт достаточное ускорение. Логит-уровень применяется в исследовательских работах или когда draft-модель сильно отличается от target.


8. Преимущества и ограничения

Преимущества

  • Ускорение без потери качества — итоговое распределение совпадает с target-моделью.
  • Параллелизация — target-модель обрабатывает K токенов за один проход.
  • Гибкость — можно использовать любую draft-модель (даже не обученную специально).

Ограничения

  • Зависимость от acceptance rate — если draft-модель плоха, ускорение минимально.
  • Дополнительные вычисления — draft-модель всё равно требует ресурсов.
  • Сложность реализации — особенно для логит-уровня (нужно синхронизировать распределения).
  • Не подходит для моделей с жёсткими ограничениями по памяти — хранение логитов draft может быть дорогим.

9. Реализация: псевдокод на Python

import torch
import torch.nn.functional as F

def speculative_decoding_logit_level(draft_model, target_model, context, K, vocab_size):
    """
    Draft-модель генерирует логиты для K шагов, target проверяет.
    Возвращает сгенерированные токены.
    """
    # Шаг 1: Draft генерирует логиты для K шагов (авторегрессионно)
    draft_logits = []  # список из K тензоров [1, vocab_size]
    draft_tokens = []
    current_context = context.clone()
    for _ in range(K):
        with torch.no_grad():
            logits = draft_model(current_context)[0, -1, :]  # последний токен
        draft_logits.append(logits)
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, 1)
        draft_tokens.append(token)
        current_context = torch.cat([current_context, token], dim=-1)
    
    # Шаг 2: Target вычисляет логиты для всей последовательности
    full_context = torch.cat([context, torch.cat(draft_tokens).unsqueeze(0)], dim=-1)
    with torch.no_grad():
        target_logits_all = target_model(full_context)[0]  # [1, len(context)+K, vocab_size]
    target_logits = target_logits_all[0, -K:, :]  # только последние K позиций
    
    # Шаг 3: Rejection sampling для каждой позиции
    accepted_tokens = []
    for i in range(K):
        p_draft = F.softmax(draft_logits[i], dim=-1)
        p_target = F.softmax(target_logits[i], dim=-1)
        
        # Вероятность принятия текущего токена
        token = draft_tokens[i]
        accept_prob = min(1.0, (p_target[0, token] / p_draft[0, token]).item())
        
        if torch.rand(1).item() < accept_prob:
            accepted_tokens.append(token)
        else:
            # Ресемплинг из скорректированного распределения
            correction = torch.clamp(p_target - p_draft, min=0)
            correction = correction / correction.sum()
            new_token = torch.multinomial(correction, 1)
            accepted_tokens.append(new_token)
            # Отбрасываем все последующие токены draft
            break
    
    return torch.cat(accepted_tokens, dim=-1)

Этот код — упрощённая иллюстрация. На практике используются оптимизации: параллельный forward pass target, кэширование KV, батчинг.


10. Применение в Agentic RAG и других системах

Speculative decoding особенно полезен в Agentic RAG, где агент должен быстро генерировать несколько раундов рассуждений (chain-of-thought) и обращаться к инструментам. Ускорение инференса позволяет агенту:

  • Быстрее отвечать пользователю.
  • Делать больше итераций поиска и анализа за то же время.
  • Использовать большие модели (например, 70B) в реальном времени.

Кроме того, speculative decoding применяется в:

  • LLM-сервисах (ChatGPT, Claude) для снижения задержки.
  • Мобильных и edge-устройствах, где маленькая draft-модель работает локально, а target — на сервере.
  • Дистилляции знаний — draft-модель может быть дистиллированной версией target.

11. Пет-проект для закрепления

Задача: Реализовать speculative decoding на уровне логитов для пары моделей (например, TinyLLaMA как draft и LLaMA-7B как target) и сравнить скорость генерации с обычным авторегрессионным методом.

Инструменты: PyTorch, Hugging Face Transformers, библиотека для бенчмаркинга (timeit).

Шаги:

  1. Загрузить две модели: draft (например, TinyLlama/TinyLlama-1.1B-Chat-v1.0) и target (meta-llama/Llama-2-7b-chat-hf).
  2. Реализовать функцию speculative_generate по псевдокоду выше.
  3. Реализовать обычную генерацию (autoregressive) для той же target-модели.
  4. Запустить обе функции на нескольких промптах (длина 50-100 токенов) с K=5, K=10.
  5. Замерить время генерации и проверить, что распределения выходов совпадают (например, с помощью статистического теста Колмогорова-Смирнова на логитах).

Ожидаемый результат: Ускорение в 2-4 раза при K=5-10, при этом качество (perplexity) не ухудшается.


12. Связь с другими вопросами

ВопросТема
288Как работает KV-cache и как он ускоряет инференс?
290Какие методы ускорения генерации LLM вы знаете?
287Что такое дистилляция моделей и как она применяется?
291Как работает параллельная генерация (например, Medusa)?
292В чём разница между speculative decoding и early exiting?

Навигация