Что такое Medusa (multiple heads) для speculative decoding?

Краткий тезис

Medusa — это метод speculative decoding, который ускоряет генерацию текста в LLM без использования отдельной draft-модели. Вместо этого к последним слоям target-модели добавляются несколько дополнительных голов (heads), каждая из которых предсказывает следующий токен на фиксированном шаге вперёд (head₁ — токен t+1, head₂ — t+2 и т.д.). Это позволяет генерировать несколько токенов за один forward pass, что даёт ускорение в 2–3 раза, хотя acceptance rate составляет 60–70%. Medusa проще в обучении и требует лишь ~1.2× памяти по сравнению с target-моделью, но не даёт выигрыша от использования маленькой draft-модели.


1. Speculative decoding: общая идея

Speculative decoding — это техника ускорения инференса LLM, основанная на параллельной проверке гипотез. Обычно используется пара моделей:

  • Draft-модель — маленькая, model|быстрая модель, которая генерирует несколько токенов (например, 5–10) за один проход.
  • Target-модель — большая, точная модель, которая проверяет эти токены и принимает или отвергает их с помощью rejection sampling.

Ключевой параметр — acceptance rate (доля принятых]] токенов]]). Если draft-модель хорошо предсказывает target, ускорение может достигать 2–5×. Однако обучение и поддержка отдельной draft-модели требуют дополнительных ресурсов.


2. Проблема отдельной draft-модели

Основные недостатки классического speculative decoding:

  • Обучение draft-модели: нужно подбирать архитектуру, обучать на данных target-модели, что увеличивает время разработки.
  • Дополнительная память: draft-модель занимает место в GPU (даже маленькая — сотни MB).
  • Latency инференса: последовательный вызов двух моделей (сначала draft, потом target) может добавить задержку, особенно если draft-модель не оптимизирована.
  • Сложность развёртывания: нужно управлять двумя моделями, их версиями и кэшами.

Medusa решает эти проблемы, устраняя необходимость в отдельной draft-модели.


3. Medusa: основная идея

Medusa (названа в честь мифической Горгоны с множеством змей-голов) предлагает добавлять дополнительные головы прямо на последние слои target-модели. Каждая голова — это небольшой MLP (обычно 1–2 слоя), который принимает на вход скрытое состояние последнего слоя target-модели и предсказывает распределение вероятностей для токена на определённом шаге вперёд.

  • Head₁ предсказывает токен на позиции t+1 (следующий токен).
  • Head₂ предсказывает токен на позиции t+2 (через один).
  • Head₃ — на t+3 и т.д.

Количество голов обычно 3–5. Все головы обучаются совместно с target-моделью (или после её заморозки) на данных, сгенерированных самой target-моделью.


4. Архитектура Medusa

Архитектурно Medusa состоит из:

  • Base model (target) — обычный трансформер (например, LLaMA, GPT).
  • Medusa heads — несколько MLP-голов, каждая из которых имеет входную размерность d_model (скрытое состояние) и выходную vocab_size (размер словаря). Головы могут быть независимыми или разделять часть весов (например, первую проекцию).

Пример конфигурации для 3 голов:

Head1: Linear(d_model, vocab_size)
Head2: Linear(d_model, vocab_size)
Head3: Linear(d_model, vocab_size)

Память: каждая голова добавляет ~d_model × vocab_size параметров. Для модели с d_model=4096 и vocab_size=32000 это ~131M параметров на голову. Три головы — ~393M, что составляет ~1.2× от базовой модели (например, 7B → 7.4B). Это значительно меньше, чем отдельная draft-модель (обычно 0.5–1B).


5. Обучение Medusa

Обучение голов происходит в два этапа:

  1. Генерация обучающих данных: target-модель (без голов) генерирует последовательности токенов. Для каждого шага t сохраняются:

    • Скрытое состояние последнего слоя h_t.
    • Целевые токены: x_{t+1}, x_{t+2}, ..., x_{t+K} (где K — число голов).
  2. Обучение голов: головы обучаются минимизировать cross-entropy loss для каждого шага предсказания. Функция потерь:

    L = Σ_{k=1..K} CE(Head_k(h_t), x_{t+k})
    

    Головы могут обучаться как с замороженной base model, так и с fine-tuning всей модели (второй вариант даёт лучший acceptance rate, но требует больше ресурсов).

Важно: данные генерируются той же моделью, что и будет использоваться в инференсе, чтобы распределения голов соответствовали реальным.


6. Инференс с Medusa

Процесс генерации с Medusa:

  1. Forward pass target-модели для текущего токена x_t. Получаем скрытое состояние h_t и распределение P_target(x_{t+1}).
  2. Forward pass голов: каждая голова выдаёт распределение P_k(x_{t+k}) для k=1..K.
  3. Формирование гипотез: из распределений голов выбираются top-1 токены (или несколько вариантов, если используется tree attention). Получаем последовательность кандидатов: c₁, c₂, ..., c_K.
  4. Проверка target-моделью: target-модель делает forward pass для каждого кандидата (или для всех сразу с помощью attention mask). Для каждого шага вычисляется вероятность принятия по правилу:
    accept if random_uniform(0,1) < min(1, P_target(c_k) / P_head(c_k))
    
    (rejection sampling).
  5. Принятие/отклонение: принятые токены добавляются к выходу, на первом отклонённом шаге генерируется новый токен из скорректированного распределения target.

Таким образом, за один проход target-модели можно получить до K+1 токенов (если все приняты). На практике acceptance rate 60–70% даёт ускорение 2–3×.


7. Acceptance rate и ускорение

Acceptance rate — доля токенов, сгенерированных головами, которые были приняты target-моделью. Для Medusa он ниже, чем для хорошо обученной отдельной draft-модели (которая может достигать 80–90%). Причины:

  • Головы используют то же скрытое состояние, что и target, но предсказывают на несколько шагов вперёд, что сложнее.
  • Ошибки накапливаются: если head₁ ошиблась, head₂ будет предсказывать на основе неверного контекста.

Типичные значения:

МетодAcceptance rateУскорение (при K=3)
Medusa (3 heads)60–70%2.0–2.5×
Отдельная draft-модель (0.5B)75–85%2.5–3.5×

Ускорение вычисляется как среднее количество принятых токенов за один forward pass target. При K=3 и acceptance rate 0.65: среднее = 1 + 0.65 + 0.65² + 0.65³ ≈ 2.3 токена за проход.


8. Преимущества Medusa

  • Нет отдельной draft-модели: экономия памяти (~1.2× против 1.5–2×), упрощение развёртывания.
  • Простота обучения: достаточно сгенерировать данные с target-модели и дообучить головы (можно за 1–2 дня на одном GPU).
  • Совместимость: работает с любой моделью-трансформером, не требует изменения архитектуры base model.
  • Гибкость: количество голов можно менять под задачу (больше голов — выше потенциальное ускорение, но ниже acceptance rate).

9. Недостатки Medusa

  • Нет выгоды от малой draft-модели: если draft-модель очень маленькая (например, 100M), она может быть быстрее, чем дополнительный forward pass голов + проверка target. Medusa всегда использует полный forward pass target.
  • Acceptance rate ограничен: головы не могут предсказывать так же хорошо, как отдельная модель, обученная на тех же данных.
  • Ограничение на количество голов: при K>5 acceptance rate падает ниже 50%, и ускорение перестаёт расти.
  • Дополнительная память: хотя и меньше, чем draft-модель, но всё же ~20% от base model.

10. Варианты и улучшения Medusa

  • Medusa-1: базовая версия с независимыми головами и top-1 выбором.
  • Medusa-2: использует tree attention — головы генерируют несколько вариантов (beam), и target проверяет их в виде дерева, что повышает acceptance rate.
  • Self-speculative decoding: вариант, где головы обучаются на скрытых состояниях самой модели, но без дополнительных данных (используется только loss на предсказание следующих токенов).
  • Medusa + KV cache reuse: оптимизация, при которой кэш ключей-значений для проверки гипотез переиспользуется.

11. Сравнение с другими методами speculative decoding

МетодDraft-модельПамятьAcceptance rateУскорениеСложность
КлассическийОтдельная малая1.5–2×75–90%2–5×Средняя
MedusaНет (головы)1.2×60–70%2–3×Низкая
Lookahead decodingНет (кэш)1.0×50–60%1.5–2×Низкая
EagleОтдельная, но малая1.3×70–80%2–4×Средняя

Medusa занимает нишу «лёгкого ускорения без дополнительной модели» — идеально для быстрого прототипирования и случаев, когда нельзя развёртывать вторую модель.


12. Применение в Agentic RAG

В Agentic RAG агенты часто делают множество последовательных вызовов LLM (планирование, извлечение, генерация ответа). Ускорение каждого вызова на 2–3× критически снижает общее время ответа. Medusa особенно полезна, когда:

  • Агент использует одну и ту же target-модель для всех шагов (головы уже обучены).
  • Не хочется усложнять инфраструктуру второй моделью.
  • Требуется низкая задержка (latency) для интерактивного взаимодействия.

Например, в архитектуре ReAct агент делает 5–10 шагов, каждый с генерацией 50–100 токенов. Medusa может сократить время с 10 секунд до 3–4 секунд.


Пет-проект для закрепления

Задача: Реализовать упрощённую версию Medusa для модели GPT-2 (124M) с двумя головами и сравнить скорость генерации с обычной моделью.

Инструменты:

Шаги:

  1. Загрузить предобученную GPT-2 и заморозить её веса.
  2. Добавить две Medusa-головы: nn.Linear(768, 50257) каждая.
  3. Сгенерировать обучающие данные: для каждого токена в датасете получить скрытое состояние последнего слоя и целевые токены t+1, t+2.
  4. Обучить головы с loss = CE(head1, target1) + CE(head2, target2) (5 эпох, lr=1e-4).
  5. Реализовать инференс с rejection sampling (без tree attention).
  6. Замерить время генерации 100 последовательностей длины 50 токенов с и без Medusa.

Ожидаемый результат:

  • Acceptance rate ~50–60% (для GPT-2).
  • Ускорение ~1.5–2×.
  • Код на GitHub с демонстрацией.

Связь с другими вопросами

ВопросТема
455Speculative decoding (общая концепция)
457Tree-of-thoughts decoding
458Parallel decoding
459KV cache compression
464Inference optimization для LLM
465Agentic RAG (ускорение агентов)

Навигация