Как вы fine-tune embedding модель под свой домен? (sentence-transformers, SimCSE, MultipleNegativesRankingLoss)
Краткий тезис
Fine-tuning embedding модели под целевой домен критически повышает качество RAG и семантического поиска. Основные подходы: MultipleNegativesRankingLoss (обучение на парах «запрос – релевантный документ»), SimCSE (контрастивное обучение на одном тексте с дропаутом) и дообучение через SentenceTransformer.train(). Оценка проводится через recall@k на удержанном наборе.
2. SimCSE: contrastive learning на одном тексте
SimCSE (SimCSE: Simple Contrastive Learning of Sentence Embeddings) – метод без учителя, где из одного предложения с помощью dropout генерируются две копии (с разными dropout‑масками), которые учатся быть ближе друг к другу, отталкиваясь от других предложений батча.
- Без учителя (unsupervised SimCSE):
- Каждое предложение (x_i) дважды пропускается через transformer с dropout → получаем (h_i^{z1}, h_i^{z2}).
- Loss = InfoNCE, где положительная пара – два представления одного предложения, остальные – негативы.
- С учителем (supervised SimCSE):
- Используются пары (anchor, positive) из NLI или доменных данных, плюс hard negatives.
- Почему это релевантно для своего домена:
- Если нет размеченных пар запросов и документов, unsupervised SimCSE на корпусе доменных текстов даёт сильный буст к качеству эмбеддингов.
- Легко интегрируется с sentence-transformers:
model = SentenceTransformer('bert-base-uncased') model.fit(train_objectives=[(train_dataloader, losses.SimCSE(model))], epochs=1)
- Важный параметр: температура (\tau) (обычно 0.05) и размер батча (чем больше, тем лучше).
3. Sentence-transformers: SentenceTransformer.train()
sentence-transformers предоставляет высокоуровневый API для fine-tuning эмбеддингов.
-
Базовые шаги:
- Загрузка предобученной модели:
model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens'). - Подготовка данных: список
InputExampleс текстами (для MNRL – два текста, для SimCSE – один). - Создание даталоадера.
- Выбор функции потерь (MultipleNegativesRankingLoss, SimCSE, ContrastiveLoss, TripletLoss).
- Вызов
model.fit()илиmodel.train()(слова устаревшие, но концепт тот же).
- Загрузка предобученной модели:
-
Пример полного цикла для домена (MNRL):
from sentence_transformers import SentenceTransformer, InputExample, losses from torch.utils.data import DataLoader model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2') train_samples = [InputExample(texts=[q, d]) for q, d in domain_pairs] dataloader = DataLoader(train_samples, batch_size=64, shuffle=True) loss = losses.MultipleNegativesRankingLoss(model) model.fit(train_objectives=[(dataloader, loss)], epochs=3, warmup_steps=100) -
Тонкости:
- Используйте
model.tokenizerдля обработки длинных текстов (max_length=512). - Заморозка нижних слоёв энкодера (optional) ускоряет обучение.
- После fine-tuning обязательно сохраните модель и протестируйте.
- Используйте
4. Оценка: recall@k на удержанном
Ключевая метрика для эмбеддингов в RAG – Recall@k: как часто релевантный документ попадает в топ-k ближайших соседей.
-
Процесс валидации:
-
Пример кода:
from sentence_transformers.util import semantic_search corpus_embeddings = model.encode(corpus_docs, convert_to_tensor=True) query_embeddings = model.encode(queries, convert_to_tensor=True) hits = semantic_search(query_embeddings, corpus_embeddings, top_k=10) recall = sum(1 for i, h in enumerate(hits) if relevant_idx[i] in [r['corpus_id'] for r in h]) / len(hits) -
Интерпретация: recall@1, 5, 10. Для RAG обычно достаточно recall@10 > 0.9.
-
Совет: оценку можно проводить на отдельном датасете из настоящих логов поиска.
N. Пет-проект для закрепления
Задача: Дообучить all-MiniLM-L6-v2 на синтетических данных из вашего профильного форума (например, вопросы‑ответы с Stack Overflow по Python) с помощью MultipleNegativesRankingLoss, проверить recall@5 до/после.
Инструменты:
- Python 3.9+, sentence-transformers 2.2+.
- FAISS для быстрого поиска.
Шаги:
- Собрать 2000 пар (вопрос, лучший ответ) с api.stackexchange.com.
- Разделить 80/20 на train/val.
- Обучить модель:
model.fit(..., loss=MultipleNegativesRankingLoss), 2 эпохи. - Закодировать валидационные вопросы и корпус ответов, вычислить recall@5.
- Сравнить с базовой моделью (без дообучения).
Ожидаемый результат:
- Прирост recall@5 на 5–15% (зависит от близости предобученных данных к домену).
- Понимание, как пайплайн fine-tuning встраивается в реальный RAG.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 35 | Метрики оценки эмбеддингов |
Навигация
- Предыдущий: 968
- Следующий: 970
- Индекс: 00. Индекс разборов