EAGLE-3 vs Medusa-2 vs Hydra: сравнение speculative decoding методов?

Краткий тезис

Speculative decoding — это техника ускорения инференса LLM без потери качества, основанная на генерации нескольких токенов «черновиком» (draft model) и их верификации целевой моделью (target model). Medusa-2, EAGLE-3 и Hydra — три современных подхода, различающиеся архитектурой draft-модели, использованием скрытых представлений и стратегией построения дерева кандидатов. EAGLE-3 обеспечивает наилучший баланс между качеством и ускорением, Hydra — максимальный acceptance rate ценой сложности, Medusa-2 — компромисс с минимальным потреблением памяти.

1. Термин: Speculative Decoding

Speculative decoding (decoding|спекулятивное декодирование) — метод ускорения генерации LLM, при котором маленькая и быстрая draft model (модель-черновик) генерирует последовательность из K токенов, а большая target model (целевая модель) верифицирует их параллельно за один forward pass. Если target model принимает токен, он фиксируется; если отвергает — происходит откат и генерация продолжается с правильного токена.

Ключевые метрики

Почему это работает LLM часто генерируют «лёгкие» токены (артикли, предлоги, союзы), которые легко предсказать. Draft model берёт на себя эти токены, а target model фокусируется на «трудных» токенах.

2. Medusa-2: Головы на плечах target модели

Medusa-2 — метод, при котором к последним слоям target model добавляются несколько Medusa heads (дополнительные линейные слои). Каждая голова предсказывает следующий токен, но со сдвигом: первая голова — токен t+1, вторая — t+2 и т.д.

Архитектура

  • Нет отдельной draft model — heads встроены в target model.
  • Heads обучаются совместно с target model (fine-tuning).
  • Дерево кандидатов строится из top-k предсказаний каждой головы.

Характеристики

Плюсы минимальный overhead, простота деплоя (одна модель). Минусы более низкий acceptance rate, необходимость дообучения.

3. EAGLE-3: Feature-Aware Speculative Decoding

EAGLE-3 (Efficient AGgregation of Latent Embeddings) — метод, использующий feature-aware подход: draft model получает на вход не только последний токен, но и скрытые представления (hidden states) из target model.

Архитектура

  • Draft model — небольшая transformer-подобная сеть (обычно 1–2 слоя).
  • На каждом шаге draft model получает:
    • Эмбеддинг последнего токена.
    • Hidden state последнего слоя target model для этого токена.
  • Это позволяет draft model «видеть» внутренние представления target model и лучше предсказывать её поведение.

Характеристики

  • Memory overhead ~1.5x (небольшая draft model + кэш hidden states).
  • Acceptance rate: 78–82%.
  • Quality лучший среди трёх методов (почти неотличим от target model).
  • Training требуется обучение draft model, но target model остаётся frozen.

Плюсы высокий acceptance rate, отличное качество, не требует fine-tuning target model. Минусы overhead на хранение hidden states, сложность реализации.

4. Hydra: Множественные головы и дерево кандидатов

Hydra — метод, использующий несколько draft models (или heads), каждая из которых генерирует свою последовательность кандидатов. Все кандидаты объединяются в дерево (tree of candidates), которое верифицируется target model за один проход.

Архитектура

  • M draft models (или heads), каждая генерирует K токенов.
  • Все M × K кандидатов формируют дерево с ветвлением.
  • Target model верифицирует все пути дерева параллельно.

Характеристики

  • Memory overhead ~2x (M моделей/heads).
  • Acceptance rate: до 85% (максимальный среди трёх).
  • Quality хорошее, но может страдать от конфликтов между draft models.
  • Training сложное обучение (нужно синхронизировать M моделей).

Плюсы максимальный acceptance rate, гибкость (можно комбинировать разные draft models). Минусы высокий memory overhead, сложный деплой и обучение.

5. Сравнительная таблица

ПараметрMedusa-2EAGLE-3Hydra
Draft modelHeads на target modelОтдельная small transformerM отдельных моделей
Использование hidden statesНетДа (feature-aware)Нет (только токены)
Memory overhead~1.2x~1.5x~2x
Acceptance rate60–70%78–82%до 85%
Quality (match with target)ХорошееОтличноеХорошее
Training complexityСредняя (fine-tuning)Средняя (train draft)Высокая (M моделей)
Deployment complexityНизкая (одна модель)Средняя (2 модели)Высокая (M+1 моделей)
Wall-clock speedup1.5–2x2–2.5x2–3x

6. Когда что выбирать

СценарийРекомендуемый методПричина
Минимальный memory budgetMedusa-2Наименьший overhead
Максимальное качествоEAGLE-3Лучший acceptance rate при хорошем качестве
Максимальное ускорениеHydraНаивысший acceptance rate
Frozen target modelEAGLE-3Не требует fine-tuning target
Простота деплояMedusa-2Одна модель, простая архитектура

7. Реализация на Python (упрощённая)

import torch
import torch.nn as nn

# Упрощённая реализация speculative decoding
class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, method='eagle3'):
        self.target = target_model
        self.draft = draft_model
        self.method = method
        
    def generate(self, input_ids, max_new_tokens=100, k=5):
        """Генерация с speculative decoding"""
        generated = input_ids.clone()
        hidden_states = None
        
        while len(generated[0]) < max_new_tokens:
            # Шаг 1: Draft model генерирует K кандидатов
            if self.method == 'eagle3':
                # Feature-aware: передаём hidden states
                draft_tokens, hidden_states = self.draft.generate(
                    generated, 
                    hidden_states=hidden_states,
                    num_candidates=k
                )
            else:
                draft_tokens = self.draft.generate(generated, num_candidates=k)
            
            # Шаг 2: Target model верифицирует
            with torch.no_grad():
                target_logits = self.target(draft_tokens)
            
            # Шаг 3: Accept/reject
            accepted = []
            for i in range(k):
                if self._accept_token(target_logits[:, i], draft_tokens[:, i]):
                    accepted.append(draft_tokens[:, i])
                else:
                    # Reject: генерируем правильный токен
                    correct_token = torch.argmax(target_logits[:, i], dim=-1)
                    accepted.append(correct_token)
                    break
            
            generated = torch.cat([generated, torch.stack(accepted)], dim=-1)
        
        return generated
    
    def _accept_token(self, logits, token):
        """Rejection sampling: принимаем токен с вероятностью min(1, q/p)"""
        p = torch.softmax(logits, dim=-1)
        q = torch.softmax(self.draft.logits, dim=-1)
        acceptance_prob = torch.min(torch.tensor(1.0), q[token] / p[token])
        return torch.rand(1) < acceptance_prob

8. Пет-проект для закрепления

Задача Реализовать и сравнить Medusa-2, EAGLE-3 и Hydra на небольшой LLM (например, GPT-2).

Инструменты

Шаги:

  1. Подготовка Загрузите GPT-2 small (124M параметров) как target model.
  2. Medusa-2 Добавьте 4 линейных слоя (heads) к последнему слою GPT-2. Обучите heads на 1000 примерах из WikiText.
  3. EAGLE-3 Создайте draft model (2 слоя transformer, 4 heads). Обучите её предсказывать токены, используя hidden states GPT-2.
  4. Hydra Создайте 3 draft models (каждая — маленький GPT-2). Реализуйте дерево кандидатов.
  5. Тестирование На 100 запросах измерьте:
  6. Визуализация Постройте графики зависимости acceptance rate от K (числа кандидатов).

Ожидаемый результат

9. Связь с другими вопросами

ВопросТема
440Speculative decoding: общие принципы
442KV-cache compression
443Flash Attention
444Quantization (GPTQ, AWQ)
445Pruning и distillation
446Continuous batching

10. Навигация


Навигация