Что такое Quantization-Aware Training (QAT)? Чем отличается от Post-Training Quantization (PTQ)?
Краткий тезис
Quantization-Aware Training (QAT) — это метод квантования, при котором модель обучается с симуляцией квантования в прямом проходе, что позволяет адаптировать веса к потере точности. В отличие от Post-Training Quantization (PTQ), где квантование применяется после завершения обучения, QAT даёт значительно меньшую деградацию качества при низкобитных представлениях (4, 3 бита), но требует дополнительных вычислительных затрат на этапе обучения.
2. QAT: симулируем квантование во время обучения
QAT — это метод, при котором квантование встраивается в процесс обучения. Ключевые элементы:
- Fake quantization (псевдоквантование): в прямом проходе веса и активации проходят через операцию квантования-деквантования (например,
torch.quantization.fake_quantize). Градиенты при этом вычисляются через Straight-Through Estimator (STE), который аппроксимирует производную округления как 1. - Обучение с нуля или дообучение: QAT может применяться как при обучении с нуля, так и при дообучении предобученной модели. В контексте LLM чаще используется дообучение на небольшом датасете (10–100K токенов).
- Адаптация весов: модель учится компенсировать ошибки квантования, смещая распределение весов так, чтобы после квантования точность была максимальной.
- Результат: после QAT модель квантуется «по-настоящему» (веса округляются до INT4/INT3), но точность остаётся близкой к full-precision.
Схема работы QAT:
Вход → FakeQuant(веса) → FakeQuant(активации) → Выход → Loss → Градиенты (STE)
Пример на PyTorch:
import torch
import torch.nn as nn
import torch.quantization as quant
model = MyModel()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)
for epoch in range(epochs):
for inputs, labels in dataloader:
outputs = model(inputs) # fake quantized forward
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Convert to quantized model
quant.convert(model, inplace=True)
3. QAT лучше для низких бит (4, 3)
Сравнение PTQ и QAT при различных битностях:
| Битность | PTQ (точность, %) | QAT (точность, %) | Разница |
|---|---|---|---|
| INT8 | 98.5 | 98.7 | +0.2% |
| INT4 | 92.1 | 96.8 | +4.7% |
| INT3 | 78.3 | 93.5 | +15.2% |
| INT2 | 45.6 | 85.2 | +39.6% |
Данные ориентировочные, на примере LLaMA-7B на бенчмарке WikiText-2.
Почему QAT выигрывает при низких битах:
- Ошибка квантования растёт экспоненциально с уменьшением битности. PTQ не может её компенсировать, так как веса уже зафиксированы.
- QAT перераспределяет веса: модель учится размещать важные значения в центре диапазона квантования, где ошибка минимальна.
- Адаптация активаций: QAT также симулирует квантование активаций, что критично для моделей с Softmax и LayerNorm (например, в LLM).
4. LoRA + QAT: адаптер обучается поверх квантованной
Сочетание Low-Rank Adaptation (LoRA) с QAT — это эффективный способ дообучения LLM с минимальным потреблением памяти. Основная идея:
- Базовая модель квантуется до INT4/INT3 с помощью QAT (или PTQ, но QAT даёт лучшее качество).
- LoRA-адаптеры (ранг 8–64) обучаются в full-precision (FP16/FP32) поверх квантованных весов.
- Преимущества:
- Память: квантованная модель занимает ~4x меньше места (например, 7B модель ~3.5 ГБ в INT4 вместо ~14 ГБ в FP16).
- Качество: QAT минимизирует потери от квантования, а LoRA добавляет гибкость для тонкой настройки под задачу.
- Скорость: прямое квантование весов ускоряет инференс на CPU/GPU с поддержкой INT4.
Пример с PEFT и bitsandbytes:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
# QAT-квантованная модель (через bitsandbytes)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
"model-name",
quantization_config=bnb_config,
device_map="auto"
)
# LoRA адаптер
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1,
)
model = get_peft_model(model, lora_config)
# Обучение адаптера
trainer = Trainer(model=model, ...)
trainer.train()
5. Пет-проект для закрепления
Задача: Сравнить PTQ и QAT для квантования небольшой LLM (например, GPT-2 или TinyLLaMA) до INT4 и измерить разницу в perplexity на WikiText-2.
Инструменты:
- Python 3.10+
- PyTorch 2.0+
- Transformers
- bitsandbytes (для PTQ)
- Intel Neural Compressor или AIMET (для QAT)
Шаги:
- Загрузите предобученную модель GPT-2 (124M параметров) в FP16.
- Измерьте baseline perplexity на WikiText-2 (validation split).
- Примените PTQ: квантуйте модель до INT4 с помощью
bitsandbytes(NF4). Измерьте perplexity. - Реализуйте QAT:
- Используйте
torch.quantization.FakeQuantizeдля весов и активаций. - Дообучите модель на 10K токенов из WikiText-2 с learning rate 1e-5.
- Конвертируйте в INT4 через
torch.quantization.convert.
- Используйте
- Сравните результаты:
- Baseline: perplexity = 25.3
- PTQ (INT4): perplexity = 28.1
- QAT (INT4): perplexity = 26.2
Ожидаемый результат:
- Вы увидите, что QAT даёт на 2–3 пункта лучший perplexity, чем PTQ, при той же битности.
- Поймёте, что QAT требует ~2–3 часов обучения на одном GPU (T4), в то время как PTQ выполняется за 5 минут.
Связь с другими вопросами
Навигация
- Предыдущий: 971
- Следующий: 973
- Индекс: 00. Индекс разборов zation (PTQ) и как оно работает?|971]]
- Следующий: 973
- Индекс: 00. Индекс разборов