Aivaro
  • Оглавление
  • Вопросы
  • Практика
  • Вики
  • Материалы сообщества
  • Тесты
  • Поиск
✈Telegram @ai_varo
RUEN中文
…
Оглавление/Вопросы/#972

Что такое Quantization-Aware Training (QAT)? Чем отличается от Post-Training Quantization (PTQ)?

Краткий тезис

Quantization-Aware Training (QAT) — это метод квантования, при котором модель обучается с симуляцией квантования в прямом проходе, что позволяет адаптировать веса к потере точности. В отличие от Post-Training Quantization (PTQ), где квантование применяется после завершения обучения, QAT даёт значительно меньшую деградацию качества при низкобитных представлениях (4, 3 бита), но требует дополнительных вычислительных затрат на этапе обучения.

2. QAT: симулируем квантование во время обучения

QAT — это метод, при котором квантование встраивается в процесс обучения. Ключевые элементы:

  • Fake quantization (псевдоквантование): в прямом проходе веса и активации проходят через операцию квантования-деквантования (например, torch.quantization.fake_quantize). Градиенты при этом вычисляются через Straight-Through Estimator (STE), который аппроксимирует производную округления как 1.
  • Обучение с нуля или дообучение: QAT может применяться как при обучении с нуля, так и при дообучении предобученной модели. В контексте LLM чаще используется дообучение на небольшом датасете (10–100K токенов).
  • Адаптация весов: модель учится компенсировать ошибки квантования, смещая распределение весов так, чтобы после квантования точность была максимальной.
  • Результат: после QAT модель квантуется «по-настоящему» (веса округляются до INT4/INT3), но точность остаётся близкой к full-precision.

Схема работы QAT:

Вход → FakeQuant(веса) → FakeQuant(активации) → Выход → Loss → Градиенты (STE)

Пример на PyTorch:

import torch
import torch.nn as nn
import torch.quantization as quant

model = MyModel()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)

for epoch in range(epochs):
    for inputs, labels in dataloader:
        outputs = model(inputs)  # fake quantized forward
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Convert to quantized model
quant.convert(model, inplace=True)

3. QAT лучше для низких бит (4, 3)

Сравнение PTQ и QAT при различных битностях:

БитностьPTQ (точность, %)QAT (точность, %)Разница
INT898.598.7+0.2%
INT492.196.8+4.7%
INT378.393.5+15.2%
INT245.685.2+39.6%

Данные ориентировочные, на примере LLaMA-7B на бенчмарке WikiText-2.

Почему QAT выигрывает при низких битах:

  • Ошибка квантования растёт экспоненциально с уменьшением битности. PTQ не может её компенсировать, так как веса уже зафиксированы.
  • QAT перераспределяет веса: модель учится размещать важные значения в центре диапазона квантования, где ошибка минимальна.
  • Адаптация активаций: QAT также симулирует квантование активаций, что критично для моделей с Softmax и LayerNorm (например, в LLM).

4. LoRA + QAT: адаптер обучается поверх квантованной

Сочетание Low-Rank Adaptation (LoRA) с QAT — это эффективный способ дообучения LLM с минимальным потреблением памяти. Основная идея:

  • Базовая модель квантуется до INT4/INT3 с помощью QAT (или PTQ, но QAT даёт лучшее качество).
  • LoRA-адаптеры (ранг 8–64) обучаются в full-precision (FP16/FP32) поверх квантованных весов.
  • Преимущества:
    • Память: квантованная модель занимает ~4x меньше места (например, 7B модель ~3.5 ГБ в INT4 вместо ~14 ГБ в FP16).
    • Качество: QAT минимизирует потери от квантования, а LoRA добавляет гибкость для тонкой настройки под задачу.
    • Скорость: прямое квантование весов ускоряет инференс на CPU/GPU с поддержкой INT4.

Пример с PEFT и bitsandbytes:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# QAT-квантованная модель (через bitsandbytes)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    quantization_config=bnb_config,
    device_map="auto"
)

# LoRA адаптер
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
)
model = get_peft_model(model, lora_config)

# Обучение адаптера
trainer = Trainer(model=model, ...)
trainer.train()

5. Пет-проект для закрепления

Задача: Сравнить PTQ и QAT для квантования небольшой LLM (например, GPT-2 или TinyLLaMA) до INT4 и измерить разницу в perplexity на WikiText-2.

Инструменты:

  • Python 3.10+
  • PyTorch 2.0+
  • Transformers
  • bitsandbytes (для PTQ)
  • Intel Neural Compressor или AIMET (для QAT)

Шаги:

  1. Загрузите предобученную модель GPT-2 (124M параметров) в FP16.
  2. Измерьте baseline perplexity на WikiText-2 (validation split).
  3. Примените PTQ: квантуйте модель до INT4 с помощью bitsandbytes (NF4). Измерьте perplexity.
  4. Реализуйте QAT:
    • Используйте torch.quantization.FakeQuantize для весов и активаций.
    • Дообучите модель на 10K токенов из WikiText-2 с learning rate 1e-5.
    • Конвертируйте в INT4 через torch.quantization.convert.
  5. Сравните результаты:
    • Baseline: perplexity = 25.3
    • PTQ (INT4): perplexity = 28.1
    • QAT (INT4): perplexity = 26.2

Ожидаемый результат:

  • Вы увидите, что QAT даёт на 2–3 пункта лучший perplexity, чем PTQ, при той же битности.
  • Поймёте, что QAT требует ~2–3 часов обучения на одном GPU (T4), в то время как PTQ выполняется за 5 минут.

Связь с другими вопросами

ВопросТема
443Основы квантования, типы (INT8, FP8, NF4)

Навигация

  • Предыдущий: 971
  • Следующий: 973
  • Индекс: 00. Индекс разборов zation (PTQ) и как оно работает?|971]]
  • Следующий: 973
  • Индекс: 00. Индекс разборов