Что такое Distillation для LLM? Как обучить маленькую модель (student) на выходах большой (teacher)?
Краткий тезис
Дистилляция (Knowledge Distillation) — техника сжатия нейросетевых моделей, при которой компактная модель-студент (student) обучается имитировать поведение большой модели-учителя (teacher). В контексте LLM дистилляция позволяет получить лёгкую языковую модель, сохраняющую до 80–90% качества большой за счёт обучения на мягких метках (soft labels) — распределении вероятностей учителя. Процесс включает минимизацию расхождения Кульбака–Лейблера между логитами студента и учителя, а также, опционально, потерю на промежуточных представлениях. Ключевое преимущество — десятикратное сокращение числа параметров при незначительной потере точности, что критически важно для развёртывания на edge-устройствах и снижения latency.
2. Потеря: KL divergence между student и teacher
Функция потерь для дистилляции базируется на KL-дивергенции, измеряющей расхождение между двумя распределениями вероятностей.
$$L_{KL} = \sum_{i=1}^{V} p_i^{(T)} \log \frac{p_i^{(T)}}{p_i^{(S)}}$$
где $V$ — размер словаря. На практике для численной стабильности реализуют как логарифм вероятностей student под teacher:
import torch
import torch.nn.functional as F
def distillation_loss(logits_student, logits_teacher, temperature=4.0):
p_teacher = F.softmax(logits_teacher / temperature, dim=-1)
p_student = F.log_softmax(logits_student / temperature, dim=-1)
loss = F.kl_div(p_student, p_teacher, reduction='batchmean')
return loss * (temperature ** 2)
Особенности при работе с LLM:
- Потоковый расчёт: teacher логиты предвычисляются один раз (offline) или, если размер teacher невелик (например, 7B → 1B), можно делать online-дистилляцию с обоими моделями на одном батче.
- Высокая температура (τ=4...10) уменьшает «пики» распределения, заставляя студента обращать внимание на слабые альтернативы — это улучшает обобщение, особенно на редкие токены.
- В некоторых реализациях используют reverse KL $KL(p_S || p_T)$ или симметричные меры (Jensen-Shannon) для стабилизации, когда учитель сильно уверен.
Обучение студента на KL-дивергенции без hard labels (только soft) называется pure distillation. В таком режиме модель не может исправить ошибки учителя, но зато защищена от «катастрофического забывания» и быстрее сходится.
3. Можно использовать промежуточные слои
Помимо финальных логитов, учитель может передавать представления с внутренних слоёв (hidden states, attention maps). Этот способ известен как representation distillation или feature-based KD.
Схема: выделяем K пар «учитель-студент» слоёв (например, каждый 2-й слой учителя соответствует одному слою студента), и после проекции (матрица $W$) студент минимизирует MSE между нормализованными представлениями:
$$L_{rep} = \sum_{k} || H_k^{(S)} - W_k H_{align(k)}^{(T)} ||_2^2$$
Где $align(k)$ — индекс выбранного слоя учителя. Для LLM часто используют:
- Замороженный embedding layer: студент учит свои эмбеддинги приближать выход эмбеддингов учителя.
- Слои Transformer: последние слои (ближе к выходу) передают семантику, средние — синтаксис.
- Attention maps: студент учится воспроизводить паттерны внимания большого учителя, что улучшает моделирование долгосрочных зависимостей.
Why it helps? Промежуточные потери действуют как регуляризатор — студент получает дополнительный сигнал, который помогает избежать «переобучения на логиты» и улучшает качество при малом объёме данных. Однако это увеличивает время обучения и требует тщательного выбора пар слоёв (можно искать соответствие через DTW или просто линейное выравнивание).
Практические фреймворки: Hugging Face Transformers с библиотекой distil-whisper или Textbooks Are All You Need (Apple) используют дистилляцию только на логитах, а представители DistilBERT и TinyBERT — на промежуточных слоях.
4. Применение: маленькая модель на 10% качества большой
Основная мотивация дистилляции — получить модель, которая при 1/10 параметров демонстрирует 80–90% метрик от оригинальной. Ключевые сценарии:
- Развёртывание на устройствах с ограниченными ресурсами (смартфоны, IoT): например, Phi-3-mini (3.8B) после дистилляции с GPT-4 показывает качество, сопоставимое с Llama-2 7B, но в 2 раза быстрее на CPU.
- Ускорение инференса: студент может работать с меньшей задержкой (latency) и поддерживать квантование без сильной деградации.
- Пайплайн «teacher → student» для задач снижения уровня (например, код-генерация: Copilot → CodeLlama 7B).
Цифры из практики:
| Модель-учитель | Студент | Параметры | Скорость | Метрика (Accuracy на MMLU) |
|---|---|---|---|---|
| GPT-3.5 175B | Alpaca 7B | 7B | ~10× быстрее | 78% vs 85% |
| Llama 3 70B | Llama 3 8B | 8B | ~6× быстрее | 82% vs 87% |
| StableLM 7B | DistilLM 0.7B | 0.7B | ~20× быстрее | 70% vs 62% (hard labels) |
Важно: термин «10% качества» — упрощение. На практике студент может проигрывать на сложных логических задачах, но сохранять fluency и знания в узком домене. В индустрии часто применяют task-specific distillation: дообучают студента на датасете с ответами учителя для конкретной задачи (суммаризация, QA), достигая 95–98% качества.
Scaling law для дистилляции: чем больше teacher и меньше student, тем больше относительная потеря, но абсолютные метрики растут из-за большего объёма знаний учителя. Оптимальное соотношение — 0.5–0.1 от числа параметров учителя.
5. Пет-проект для закрепления
Задача: Обучить маленький GPT-like классификатор (student) с 4 слоями Transformer (~5M параметров) предсказывать тональность отзывов, имитируя поведение BERT-base (teacher) на датасете IMDB Reviews.
Инструменты:
- Python, PyTorch, Hugging Face Transformers.
- Модели:
google-bert/bert-base-uncased(teacher), самописныйGPTClassifier(student). - DataLoader с предвычисленными teacher logits (offline).
Шаги:
- Подготовка teacher logits: загружаем BERT, пропускаем IMDB train (25K отзывов) через модель с temperature=5, сохраняем logits и hard labels.
- Построение student: класс
SimpleTransformer(n_layers=4, d_model=256, n_heads=4). - Обучение: на каждом шаге батч → student logits → KL-дивергенция с teacher logits + кросс-энтропия с истинной меткой (α=0.7, τ=5).
- Валидация: замер accuracy и RLUE (relative logit utility error) между студентом и учителем на тестовой выборке.
- Анализ: сравнить точность с версией, обученной только на hard labels (без дистилляции).
Ожидаемый результат:
- Студент с дистилляцией покажет accuracy ~85% (teacher ~92%), а версия без дистилляции ~78%.
- Визуализация распределений: student c soft labels будет ближе к teacher по энтропии предсказаний.
- Вывод: дистилляция на soft labels даёт +7% точности при той же архитектуре студента.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 971 | Distillation для LLM — данный разбор |
| 970 | PEFT и сравнение с Distillation — оба сжимают модель, но PEFT дообучает маленькие модули, а Distillation передаёт знания |
| 972 | Quantization — ещё один метод сжатия, часто комбинируется с Distillation |
Навигация
- Предыдущий: 970
- Следующий: 972
- Индекс: 00. Индекс разборов zation (GPTQ, AWQ, bitsandbytes)?|972]]
- Индекс: 00. Индекс разборов