Aivaro
  • Оглавление
  • Вопросы
  • Практика
  • Вики
  • Материалы сообщества
  • Тесты
  • Поиск
✈Telegram @ai_varo
RUEN中文
…
Оглавление/Вопросы/#971

Что такое Distillation для LLM? Как обучить маленькую модель (student) на выходах большой (teacher)?

Краткий тезис

Дистилляция (Knowledge Distillation) — техника сжатия нейросетевых моделей, при которой компактная модель-студент (student) обучается имитировать поведение большой модели-учителя (teacher). В контексте LLM дистилляция позволяет получить лёгкую языковую модель, сохраняющую до 80–90% качества большой за счёт обучения на мягких метках (soft labels) — распределении вероятностей учителя. Процесс включает минимизацию расхождения Кульбака–Лейблера между логитами студента и учителя, а также, опционально, потерю на промежуточных представлениях. Ключевое преимущество — десятикратное сокращение числа параметров при незначительной потере точности, что критически важно для развёртывания на edge-устройствах и снижения latency.

2. Потеря: KL divergence между student и teacher

Функция потерь для дистилляции базируется на KL-дивергенции, измеряющей расхождение между двумя распределениями вероятностей.

$$L_{KL} = \sum_{i=1}^{V} p_i^{(T)} \log \frac{p_i^{(T)}}{p_i^{(S)}}$$

где $V$ — размер словаря. На практике для численной стабильности реализуют как логарифм вероятностей student под teacher:

import torch
import torch.nn.functional as F

def distillation_loss(logits_student, logits_teacher, temperature=4.0):
    p_teacher = F.softmax(logits_teacher / temperature, dim=-1)
    p_student = F.log_softmax(logits_student / temperature, dim=-1)
    loss = F.kl_div(p_student, p_teacher, reduction='batchmean')
    return loss * (temperature ** 2)

Особенности при работе с LLM:

  • Потоковый расчёт: teacher логиты предвычисляются один раз (offline) или, если размер teacher невелик (например, 7B → 1B), можно делать online-дистилляцию с обоими моделями на одном батче.
  • Высокая температура (τ=4...10) уменьшает «пики» распределения, заставляя студента обращать внимание на слабые альтернативы — это улучшает обобщение, особенно на редкие токены.
  • В некоторых реализациях используют reverse KL $KL(p_S || p_T)$ или симметричные меры (Jensen-Shannon) для стабилизации, когда учитель сильно уверен.

Обучение студента на KL-дивергенции без hard labels (только soft) называется pure distillation. В таком режиме модель не может исправить ошибки учителя, но зато защищена от «катастрофического забывания» и быстрее сходится.


3. Можно использовать промежуточные слои

Помимо финальных логитов, учитель может передавать представления с внутренних слоёв (hidden states, attention maps). Этот способ известен как representation distillation или feature-based KD.

Схема: выделяем K пар «учитель-студент» слоёв (например, каждый 2-й слой учителя соответствует одному слою студента), и после проекции (матрица $W$) студент минимизирует MSE между нормализованными представлениями:

$$L_{rep} = \sum_{k} || H_k^{(S)} - W_k H_{align(k)}^{(T)} ||_2^2$$

Где $align(k)$ — индекс выбранного слоя учителя. Для LLM часто используют:

  • Замороженный embedding layer: студент учит свои эмбеддинги приближать выход эмбеддингов учителя.
  • Слои Transformer: последние слои (ближе к выходу) передают семантику, средние — синтаксис.
  • Attention maps: студент учится воспроизводить паттерны внимания большого учителя, что улучшает моделирование долгосрочных зависимостей.

Why it helps? Промежуточные потери действуют как регуляризатор — студент получает дополнительный сигнал, который помогает избежать «переобучения на логиты» и улучшает качество при малом объёме данных. Однако это увеличивает время обучения и требует тщательного выбора пар слоёв (можно искать соответствие через DTW или просто линейное выравнивание).

Практические фреймворки: Hugging Face Transformers с библиотекой distil-whisper или Textbooks Are All You Need (Apple) используют дистилляцию только на логитах, а представители DistilBERT и TinyBERT — на промежуточных слоях.


4. Применение: маленькая модель на 10% качества большой

Основная мотивация дистилляции — получить модель, которая при 1/10 параметров демонстрирует 80–90% метрик от оригинальной. Ключевые сценарии:

  • Развёртывание на устройствах с ограниченными ресурсами (смартфоны, IoT): например, Phi-3-mini (3.8B) после дистилляции с GPT-4 показывает качество, сопоставимое с Llama-2 7B, но в 2 раза быстрее на CPU.
  • Ускорение инференса: студент может работать с меньшей задержкой (latency) и поддерживать квантование без сильной деградации.
  • Пайплайн «teacher → student» для задач снижения уровня (например, код-генерация: Copilot → CodeLlama 7B).

Цифры из практики:

Модель-учительСтудентПараметрыСкоростьМетрика (Accuracy на MMLU)
GPT-3.5 175BAlpaca 7B7B~10× быстрее78% vs 85%
Llama 3 70BLlama 3 8B8B~6× быстрее82% vs 87%
StableLM 7BDistilLM 0.7B0.7B~20× быстрее70% vs 62% (hard labels)

Важно: термин «10% качества» — упрощение. На практике студент может проигрывать на сложных логических задачах, но сохранять fluency и знания в узком домене. В индустрии часто применяют task-specific distillation: дообучают студента на датасете с ответами учителя для конкретной задачи (суммаризация, QA), достигая 95–98% качества.

Scaling law для дистилляции: чем больше teacher и меньше student, тем больше относительная потеря, но абсолютные метрики растут из-за большего объёма знаний учителя. Оптимальное соотношение — 0.5–0.1 от числа параметров учителя.


5. Пет-проект для закрепления

Задача: Обучить маленький GPT-like классификатор (student) с 4 слоями Transformer (~5M параметров) предсказывать тональность отзывов, имитируя поведение BERT-base (teacher) на датасете IMDB Reviews.

Инструменты:

  • Python, PyTorch, Hugging Face Transformers.
  • Модели: google-bert/bert-base-uncased (teacher), самописный GPTClassifier (student).
  • DataLoader с предвычисленными teacher logits (offline).

Шаги:

  1. Подготовка teacher logits: загружаем BERT, пропускаем IMDB train (25K отзывов) через модель с temperature=5, сохраняем logits и hard labels.
  2. Построение student: класс SimpleTransformer(n_layers=4, d_model=256, n_heads=4).
  3. Обучение: на каждом шаге батч → student logits → KL-дивергенция с teacher logits + кросс-энтропия с истинной меткой (α=0.7, τ=5).
  4. Валидация: замер accuracy и RLUE (relative logit utility error) между студентом и учителем на тестовой выборке.
  5. Анализ: сравнить точность с версией, обученной только на hard labels (без дистилляции).

Ожидаемый результат:

  • Студент с дистилляцией покажет accuracy ~85% (teacher ~92%), а версия без дистилляции ~78%.
  • Визуализация распределений: student c soft labels будет ближе к teacher по энтропии предсказаний.
  • Вывод: дистилляция на soft labels даёт +7% точности при той же архитектуре студента.

Связь с другими вопросами

ВопросТема
971Distillation для LLM — данный разбор
970PEFT и сравнение с Distillation — оба сжимают модель, но PEFT дообучает маленькие модули, а Distillation передаёт знания
972Quantization — ещё один метод сжатия, часто комбинируется с Distillation

Навигация

  • Предыдущий: 970
  • Следующий: 972
  • Индекс: 00. Индекс разборов zation (GPTQ, AWQ, bitsandbytes)?|972]]
  • Индекс: 00. Индекс разборов