Как вы избегаете переобучения при fine-tuning на маленьком датасете?
Краткий тезис
Переобучение на малом датасете — главный враг fine-tuning. Чтобы его избежать, применяют комбинацию из LoRA (или других PEFT-методов), агрессивной регуляризации (dropout, weight decay), early stopping по валидационной метрике, data augmentation и тщательного контроля гиперпараметров (низкий learning rate, маленький batch size). Ключевой индикатор — расхождение между train loss и validation loss: если val loss перестаёт снижаться или растёт, а train loss продолжает падать, это явный сигнал переобучения.
1. Мониторинг разницы train/validation loss
Первый и самый важный шаг — не просто смотреть на финальную метрику, а отслеживать динамику лоссов на каждой эпохе.
- train loss — ошибка на обучающей выборке.
- validation loss — ошибка на отложенной валидационной выборке (не участвовавшей в обучении).
Паттерны
- Оба лосса снижаются → всё хорошо, модель учится.
- train loss падает, validation loss перестаёт падать или начинает расти → переобучение.
- Оба не падают → проблема с обучением (недообучение, неправильные гиперпараметры, плохие данные).
Практический совет используйте TensorBoard или Weights & Biases для визуализации лоссов. Если val loss не улучшился в течение 3–5 эпох, можно остановить обучение.
2. LoRA (Low-Rank Adaptation) — параметрически эффективное обучение
При fine-tuning полной модели (например, LLaMA-7B) мы обновляем все 7 миллиардов параметров. На маленьком датасете это практически гарантирует переобучение.
Решение — LoRA (и другие PEFT-методы: AdaLoRA, DoRA, Prefix Tuning). LoRA фиксирует исходные веса и вставляет маленькие обучаемые матрицы ранга ( r ) (обычно 8–16) в слои attention. Количество обучаемых параметров сокращается в 100–10 000 раз.
Почему это помогает избежать переобучения
- Малое число параметров → меньше ёмкость модели → модель не может выучить шум.
- Исходные веса остаются нетронутыми — сохраняется общее знание.
Пример конфигурации LoRA
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj"], # только query и value
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, lora_config)
3. Регуляризация: dropout и weight decay
Даже при LoRA необходима регуляризация.
Dropout — случайное обнуление нейронов во время обучения. Для fine-tuning LLM обычно используют dropout 0.1–0.2 в LoRA-слоях и иногда в полносвязных слоях модели. При инференсе dropout отключается.
Weight decay (L2-регуляризация) — добавляет к loss штраф за большие веса: ( \mathcal{L} = \mathcal{L}_{[text](/wiki/text){orig}} + \lambda \sum w^2 ). Это заставляет веса оставаться маленькими и предотвращает запоминание outliers. Типичные значения: 0.01–0.1. Для LoRA weight decay применяется только к адаптерным матрицам.
Сравнение
| Метод | Механизм | Типичное значение | Эффект |
|---|---|---|---|
| Dropout | Случайное обнуление | 0.1–0.2 | Декоррелирует признаки, снижает ко-адаптацию |
| Weight decay | Штраф за большие веса | 0.01–0.1 | Ограничивает норму весов, сглаживает поверхность лосса |
4. Early stopping
Early stopping — остановка обучения, когда метрика на валидации перестаёт улучшаться в течение заданного числа шагов (patience).
Как настроить
- Разделить датасет на train/validation (например, 80/20).
- После каждой эпохи вычислять val loss (или метрику, например accuracy).
- Если val loss не уменьшился за patience=3 эпох → остановка и восстановление лучшей модели (с наименьшим val loss).
Код (фрагмент):
best_val_loss = float('inf')
patience_counter = 0
for epoch in range(max_epochs):
train_loss = train_one_epoch(model, train_loader)
val_loss = evaluate(model, val_loader)
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
save_checkpoint(model, "best_model.pt")
else:
patience_counter += 1
if patience_counter >= 3:
print("Early stopping triggered")
break
load_checkpoint(model, "best_model.pt")
5. Data augmentation (аугментация данных)
Маленький датасет можно расширить, генерируя синтетические или модифицированные примеры, сохраняющие смысл.
Способы для NLP / LLM
- Back-translation перевести текст на другой язык и обратно (например, EN → DE → EN).
- Синонимическая замена замена слов на синонимы (используя WordNet, тезаурус или LLM).
- Маскинг и заполнение случайно маскировать часть токенов и генерировать варианты.
- Парафразирование с помощью LLM (GPT, LLaMA) создать альтернативные формулировки того же запроса или ответа.
- Шум добавление орфографических ошибок (аккуратно — может испортить задачу).
Пример для инструктивного датасета
# Допустим, у нас есть пара (инструкция, ответ)
# Используем GPT-3.5 для генерации 5 парафразов инструкции
prompt = f"Перефразируй инструкцию, сохранив задачу:\nИнструкция: {instruction}"
variants = [call_llm(prompt) for _ in range(5)]
# Добавляем все варианты в датасет с тем же ответом
Важно не переусердствовать — аугментация должна быть семантически релевантной, иначе зашумляет данные.
6. Кросс-валидация (k-fold)
При экстремально малых датасетах (менее 500 примеров) обычное разделение на train/val может быть нестабильным. k-fold cross-validation позволяет более объективно оценить качество.
- Делим датасет на ( k ) фолдов (например, 5).
- На каждом шаге используем ( k-1 ) фолдов для обучения, один для валидации.
- Усредняем метрики по всем фолдам.
Минус дорого, требует обучения ( k ) моделей. Но для маленьких датасетов время приемлемо.
Рекомендация после кросс-валидации обучаем финальную модель на всех данных, чтобы не терять примеры.
7. Уменьшение learning rate и планировщики
На маленьком датасете высокий learning rate быстро приводит к переобучению, потому что градиенты шумные и веса «перепрыгивают» в минимумы, соответствующие шуму.
- Начальный LR 1e-5 или ниже (для LLM). Для PEFT-адаптеров можно чуть выше (1e-4).
- Планировщики cosine annealing, linear warmup + linear decay. Warmup помогает стабилизировать первые шаги.
- Batch size: маленький (2–8) для больших сетей, чтобы уменьшить ёмкость обновлений.
Пример с Hugging Face Trainer
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
learning_rate=1e-5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
weight_decay=0.01,
warmup_ratio=0.1, # 10% шагов на warmup
lr_scheduler_type="cosine",
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
max_steps=1000, # ограничить общее количество шагов
)
8. Transfer learning: начинаем с хорошей предобученной модели
Чем больше знает модель изначально, тем меньше ей нужно учиться на маленьком датасете.
- Используйте LLaMA 3.1 8B или Mistral 7B, а не обучайте с нуля.
- Fine-tune только верхние слои или с помощью PEFT.
- Если доступны, используйте domain-adapted чекпоинты (например, BioMedLM для биомедицины).
Эффект предобученное представление уже содержит общие паттерны, и fine-tuning лишь подстраивает их под конкретную задачу, а не учит с нуля.
9. Dropout и ensemble на инференсе
Можно дополнительно снизить variance модели, применяя Monte Carlo Dropout (оставлять dropout включённым на инференсе) или обучать несколько моделей и усреднять их предсказания (ensemble).
- MC Dropout: дешёво, но часто даёт скромный прирост.
- Ensemble: обучаем 3–5 моделей с разными инициализациями или random seed, усредняем логиты. На маленьких датасетах ensemble может существенно улучшить стабильность.
def predict_with_mc(model, input_ids, n_samples=10):
model.train() # оставляем dropout активным
predictions = []
for _ in range(n_samples):
with torch.no_grad():
logits = model(input_ids)
predictions.append(F.softmax(logits, dim=-1))
return torch.stack(predictions).mean(dim=0)
10. Практический чек-лист
При fine-tuning на маленьком датасете я последовательно применяю:
| # | Действие | Почему |
|---|---|---|
| 1 | Использовать PEFT (LoRA, rank≤16) | Резко сокращает ёмкость модели |
| 2 | Разделить train/val (80/20 или 70/30) | Для контроля переобучения |
| 3 | Установить drop-out 0.1–0.2 в адаптерах | Регуляризация |
| 4 | Weight decay 0.01–0.05 | Ограничение нормы весов |
| 5 | Learning rate 1e-5 (cosine scheduler) | Плавное обучение |
| 6 | Early stopping (patience 3) | Остановка при ухудшении |
| 7 | Data augmentation (back-translation / paraphrase) | Увеличение объёма |
| 8 | Кросс-валидация (5-fold) для выбора гиперпараметров | Стабильная оценка |
| 9 | Мониторинг train vs val loss каждую эпоху | Раннее обнаружение |
| 10 | Включить MC Dropout или ensemble | Дополнительная стабильность |
Пет-проект для закрепления
Задача Fine-tune DistilBERT на задаче классификации тональности (SST-2) с использованием всего 200 обучающих примеров (вместо 67k). Не допустить переобучения.
Инструменты PyTorch, Hugging Face Transformers, PEFT (LoRA), sklearn (метрики).
Шаги:
- Взять предобученный
distilbert-base-uncased. - Создать подвыборку из SST-2 (200 train, 50 validation).
- Применить LoRA (r=8, target_modules=['q_lin','k_lin','v_lin']).
- Установить гиперпараметры: lr=5e-5, batch=8, dropout=0.2, weight decay=0.01.
- Обучить с early stopping (patience=3).
- Записать train loss и validation loss в конце каждой эпохи.
- Построить график лоссов; убедиться, что val loss начинает расти после 3–4 эпох.
- Сравнить с full fine-tuning (без LoRA) — вторая модель должна переобучиться быстрее.
Ожидаемый результат LoRA-модель покажет меньшее расхождение train/val loss, а полный fine-tuning — переобучение уже на 2-й эпохе (train loss ~0, val loss высокий). Вы научитесь визуально определять переобучение и использовать PEFT.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 30. Как вы выбираете learning rate для fine-tuning? | Настройка гиперпараметров |
| 31. Что такое LoRA и как она работает? | Параметрически эффективное обучение |
| 33. Какие регуляризационные техники вы используете при fine-tuning? | Регуляризация |
| 34. Как вы оцениваете качество fine-tuning на маленьком датасете? | Оценка переобучения |
| 35. Как вы подбираете размер датасета для fine-tuning? | Минимальный объём данных |
| 36. Что такое data augmentation для LLM? | Аугментация текстов |
Навигация
- Предыдущий: 36
- Следующий: 38
- Индекс: 00. Индекс разборов