中文翻译暂不可用,显示俄语原文。
Что такое Medusa (multiple heads) для speculative decoding?
Краткий тезис
Medusa — это метод speculative decoding, который ускоряет генерацию текста в LLM без использования отдельной draft-модели. Вместо этого к последним слоям target-модели добавляются несколько дополнительных голов (heads), каждая из которых предсказывает следующий токен на фиксированном шаге вперёд (head₁ — токен t+1, head₂ — t+2 и т.д.). Это позволяет генерировать несколько токенов за один forward pass, что даёт ускорение в 2–3 раза, хотя acceptance rate составляет 60–70%. Medusa проще в обучении и требует лишь ~1.2× памяти по сравнению с target-моделью, но не даёт выигрыша от использования маленькой draft-модели.
1. Speculative decoding: общая идея
Speculative decoding — это техника ускорения инференса LLM, основанная на параллельной проверке гипотез. Обычно используется пара моделей:
- Draft-модель — маленькая, model|быстрая модель, которая генерирует несколько токенов (например, 5–10) за один проход.
- Target-модель — большая, точная модель, которая проверяет эти токены и принимает или отвергает их с помощью rejection sampling.
Ключевой параметр — acceptance rate (доля принятых]] токенов]]). Если draft-модель хорошо предсказывает target, ускорение может достигать 2–5×. Однако обучение и поддержка отдельной draft-модели требуют дополнительных ресурсов.
2. Проблема отдельной draft-модели
Основные недостатки классического speculative decoding:
- Обучение draft-модели: нужно подбирать архитектуру, обучать на данных target-модели, что увеличивает время разработки.
- Дополнительная память: draft-модель занимает место в GPU (даже маленькая — сотни MB).
- Latency инференса: последовательный вызов двух моделей (сначала draft, потом target) может добавить задержку, особенно если draft-модель не оптимизирована.
- Сложность развёртывания: нужно управлять двумя моделями, их версиями и кэшами.
Medusa решает эти проблемы, устраняя необходимость в отдельной draft-модели.
3. Medusa: основная идея
Medusa (названа в честь мифической Горгоны с множеством змей-голов) предлагает добавлять дополнительные головы прямо на последние слои target-модели. Каждая голова — это небольшой MLP (обычно 1–2 слоя), который принимает на вход скрытое состояние последнего слоя target-модели и предсказывает распределение вероятностей для токена на определённом шаге вперёд.
- Head₁ предсказывает токен на позиции t+1 (следующий токен).
- Head₂ предсказывает токен на позиции t+2 (через один).
- Head₃ — на t+3 и т.д.
Количество голов обычно 3–5. Все головы обучаются совместно с target-моделью (или после её заморозки) на данных, сгенерированных самой target-моделью.
4. Архитектура Medusa
Архитектурно Medusa состоит из:
- Base model (target) — обычный трансформер (например, LLaMA, GPT).
- Medusa heads — несколько MLP-голов, каждая из которых имеет входную размерность d_model (скрытое состояние) и выходную vocab_size (размер словаря). Головы могут быть независимыми или разделять часть весов (например, первую проекцию).
Пример конфигурации для 3 голов:
Head1: Linear(d_model, vocab_size)
Head2: Linear(d_model, vocab_size)
Head3: Linear(d_model, vocab_size)
Память: каждая голова добавляет ~d_model × vocab_size параметров. Для модели с d_model=4096 и vocab_size=32000 это ~131M параметров на голову. Три головы — ~393M, что составляет ~1.2× от базовой модели (например, 7B → 7.4B). Это значительно меньше, чем отдельная draft-модель (обычно 0.5–1B).
5. Обучение Medusa
Обучение голов происходит в два этапа:
-
Генерация обучающих данных: target-модель (без голов) генерирует последовательности токенов. Для каждого шага t сохраняются:
- Скрытое состояние последнего слоя h_t.
- Целевые токены: x_{t+1}, x_{t+2}, ..., x_{t+K} (где K — число голов).
-
Обучение голов: головы обучаются минимизировать cross-entropy loss для каждого шага предсказания. Функция потерь:
L = Σ_{k=1..K} CE(Head_k(h_t), x_{t+k})Головы могут обучаться как с замороженной base model, так и с fine-tuning всей модели (второй вариант даёт лучший acceptance rate, но требует больше ресурсов).
Важно: данные генерируются той же моделью, что и будет использоваться в инференсе, чтобы распределения голов соответствовали реальным.
6. Инференс с Medusa
Процесс генерации с Medusa:
- Forward pass target-модели для текущего токена x_t. Получаем скрытое состояние h_t и распределение P_target(x_{t+1}).
- Forward pass голов: каждая голова выдаёт распределение P_k(x_{t+k}) для k=1..K.
- Формирование гипотез: из распределений голов выбираются top-1 токены (или несколько вариантов, если используется tree attention). Получаем последовательность кандидатов: c₁, c₂, ..., c_K.
- Проверка target-моделью: target-модель делает forward pass для каждого кандидата (или для всех сразу с помощью attention mask). Для каждого шага вычисляется вероятность принятия по правилу:
(rejection sampling).accept if random_uniform(0,1) < min(1, P_target(c_k) / P_head(c_k)) - Принятие/отклонение: принятые токены добавляются к выходу, на первом отклонённом шаге генерируется новый токен из скорректированного распределения target.
Таким образом, за один проход target-модели можно получить до K+1 токенов (если все приняты). На практике acceptance rate 60–70% даёт ускорение 2–3×.
7. Acceptance rate и ускорение
Acceptance rate — доля токенов, сгенерированных головами, которые были приняты target-моделью. Для Medusa он ниже, чем для хорошо обученной отдельной draft-модели (которая может достигать 80–90%). Причины:
- Головы используют то же скрытое состояние, что и target, но предсказывают на несколько шагов вперёд, что сложнее.
- Ошибки накапливаются: если head₁ ошиблась, head₂ будет предсказывать на основе неверного контекста.
Типичные значения:
| Метод | Acceptance rate | Ускорение (при K=3) |
|---|---|---|
| Medusa (3 heads) | 60–70% | 2.0–2.5× |
| Отдельная draft-модель (0.5B) | 75–85% | 2.5–3.5× |
Ускорение вычисляется как среднее количество принятых токенов за один forward pass target. При K=3 и acceptance rate 0.65: среднее = 1 + 0.65 + 0.65² + 0.65³ ≈ 2.3 токена за проход.
8. Преимущества Medusa
- Нет отдельной draft-модели: экономия памяти (~1.2× против 1.5–2×), упрощение развёртывания.
- Простота обучения: достаточно сгенерировать данные с target-модели и дообучить головы (можно за 1–2 дня на одном GPU).
- Совместимость: работает с любой моделью-трансформером, не требует изменения архитектуры base model.
- Гибкость: количество голов можно менять под задачу (больше голов — выше потенциальное ускорение, но ниже acceptance rate).
9. Недостатки Medusa
- Нет выгоды от малой draft-модели: если draft-модель очень маленькая (например, 100M), она может быть быстрее, чем дополнительный forward pass голов + проверка target. Medusa всегда использует полный forward pass target.
- Acceptance rate ограничен: головы не могут предсказывать так же хорошо, как отдельная модель, обученная на тех же данных.
- Ограничение на количество голов: при K>5 acceptance rate падает ниже 50%, и ускорение перестаёт расти.
- Дополнительная память: хотя и меньше, чем draft-модель, но всё же ~20% от base model.
10. Варианты и улучшения Medusa
- Medusa-1: базовая версия с независимыми головами и top-1 выбором.
- Medusa-2: использует tree attention — головы генерируют несколько вариантов (beam), и target проверяет их в виде дерева, что повышает acceptance rate.
- Self-speculative decoding: вариант, где головы обучаются на скрытых состояниях самой модели, но без дополнительных данных (используется только loss на предсказание следующих токенов).
- Medusa + KV cache reuse: оптимизация, при которой кэш ключей-значений для проверки гипотез переиспользуется.
11. Сравнение с другими методами speculative decoding
| Метод | Draft-модель | Память | Acceptance rate | Ускорение | Сложность |
|---|---|---|---|---|---|
| Классический | Отдельная малая | 1.5–2× | 75–90% | 2–5× | Средняя |
| Medusa | Нет (головы) | 1.2× | 60–70% | 2–3× | Низкая |
| Lookahead decoding | Нет (кэш) | 1.0× | 50–60% | 1.5–2× | Низкая |
| Eagle | Отдельная, но малая | 1.3× | 70–80% | 2–4× | Средняя |
Medusa занимает нишу «лёгкого ускорения без дополнительной модели» — идеально для быстрого прототипирования и случаев, когда нельзя развёртывать вторую модель.
12. Применение в Agentic RAG
В Agentic RAG агенты часто делают множество последовательных вызовов LLM (планирование, извлечение, генерация ответа). Ускорение каждого вызова на 2–3× критически снижает общее время ответа. Medusa особенно полезна, когда:
- Агент использует одну и ту же target-модель для всех шагов (головы уже обучены).
- Не хочется усложнять инфраструктуру второй моделью.
- Требуется низкая задержка (latency) для интерактивного взаимодействия.
Например, в архитектуре ReAct агент делает 5–10 шагов, каждый с генерацией 50–100 токенов. Medusa может сократить время с 10 секунд до 3–4 секунд.
Пет-проект для закрепления
Задача: Реализовать упрощённую версию Medusa для модели GPT-2 (124M) с двумя головами и сравнить скорость генерации с обычной моделью.
Инструменты:
- Python, PyTorch, Hugging Face Transformers
- Датасет: небольшой корпус текстов (например, WikiText-2)
Шаги:
- Загрузить предобученную GPT-2 и заморозить её веса.
- Добавить две Medusa-головы:
nn.Linear(768, 50257)каждая. - Сгенерировать обучающие данные: для каждого токена в датасете получить скрытое состояние последнего слоя и целевые токены t+1, t+2.
- Обучить головы с loss = CE(head1, target1) + CE(head2, target2) (5 эпох, lr=1e-4).
- Реализовать инференс с rejection sampling (без tree attention).
- Замерить время генерации 100 последовательностей длины 50 токенов с и без Medusa.
Ожидаемый результат:
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 455 | Speculative decoding (общая концепция) |
| 457 | Tree-of-thoughts decoding |
| 458 | Parallel decoding |
| 459 | KV cache compression |
| 464 | Inference optimization для LLM |
| 465 | Agentic RAG (ускорение агентов) |
Навигация
- Предыдущий: 455
- Следующий: 457
- Индекс: 00. Индекс разборов