Aivaro
  • Оглавление
  • Вопросы
  • Практика
  • Вики
  • Материалы сообщества
  • Тесты
  • Поиск
✈Telegram @ai_varo
RUEN中文
…
Оглавление/Вопросы/#969

Как вы fine-tune embedding модель под свой домен? (sentence-transformers, SimCSE, MultipleNegativesRankingLoss)

Краткий тезис

Fine-tuning embedding модели под целевой домен критически повышает качество RAG и семантического поиска. Основные подходы: MultipleNegativesRankingLoss (обучение на парах «запрос – релевантный документ»), SimCSE (контрастивное обучение на одном тексте с дропаутом) и дообучение через SentenceTransformer.train(). Оценка проводится через recall@k на удержанном наборе.

2. SimCSE: contrastive learning на одном тексте

SimCSE (SimCSE: Simple Contrastive Learning of Sentence Embeddings) – метод без учителя, где из одного предложения с помощью dropout генерируются две копии (с разными dropout‑масками), которые учатся быть ближе друг к другу, отталкиваясь от других предложений батча.

  • Без учителя (unsupervised SimCSE):
    • Каждое предложение (x_i) дважды пропускается через transformer с dropout → получаем (h_i^{z1}, h_i^{z2}).
    • Loss = InfoNCE, где положительная пара – два представления одного предложения, остальные – негативы.
  • С учителем (supervised SimCSE):
    • Используются пары (anchor, positive) из NLI или доменных данных, плюс hard negatives.
  • Почему это релевантно для своего домена:
    • Если нет размеченных пар запросов и документов, unsupervised SimCSE на корпусе доменных текстов даёт сильный буст к качеству эмбеддингов.
    • Легко интегрируется с sentence-transformers:
      model = SentenceTransformer('bert-base-uncased')
      model.fit(train_objectives=[(train_dataloader, losses.SimCSE(model))], epochs=1)
      
  • Важный параметр: температура (\tau) (обычно 0.05) и размер батча (чем больше, тем лучше).

3. Sentence-transformers: SentenceTransformer.train()

sentence-transformers предоставляет высокоуровневый API для fine-tuning эмбеддингов.

  • Базовые шаги:

    1. Загрузка предобученной модели: model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens').
    2. Подготовка данных: список InputExample с текстами (для MNRL – два текста, для SimCSE – один).
    3. Создание даталоадера.
    4. Выбор функции потерь (MultipleNegativesRankingLoss, SimCSE, ContrastiveLoss, TripletLoss).
    5. Вызов model.fit() или model.train() (слова устаревшие, но концепт тот же).
  • Пример полного цикла для домена (MNRL):

    from sentence_transformers import SentenceTransformer, InputExample, losses
    from torch.utils.data import DataLoader
    
    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    train_samples = [InputExample(texts=[q, d]) for q, d in domain_pairs]
    dataloader = DataLoader(train_samples, batch_size=64, shuffle=True)
    loss = losses.MultipleNegativesRankingLoss(model)
    model.fit(train_objectives=[(dataloader, loss)], epochs=3, warmup_steps=100)
    
  • Тонкости:

    • Используйте model.tokenizer для обработки длинных текстов (max_length=512).
    • Заморозка нижних слоёв энкодера (optional) ускоряет обучение.
    • После fine-tuning обязательно сохраните модель и протестируйте.

4. Оценка: recall@k на удержанном

Ключевая метрика для эмбеддингов в RAG – Recall@k: как часто релевантный документ попадает в топ-k ближайших соседей.

  • Процесс валидации:

    1. Удержать 10–20% пар (запрос – релевантный документ).
    2. Закодировать все документы (корпус) эмбеддингами.
    3. Для каждого запроса найти k ближайших эмбеддингов (FAISS, косинусное расстояние).
    4. Проверить, есть ли среди них релевантный: recall@k = #успешных запросов / всего запросов.
  • Пример кода:

    from sentence_transformers.util import semantic_search
    
    corpus_embeddings = model.encode(corpus_docs, convert_to_tensor=True)
    query_embeddings = model.encode(queries, convert_to_tensor=True)
    hits = semantic_search(query_embeddings, corpus_embeddings, top_k=10)
    recall = sum(1 for i, h in enumerate(hits) if relevant_idx[i] in [r['corpus_id'] for r in h]) / len(hits)
    
  • Интерпретация: recall@1, 5, 10. Для RAG обычно достаточно recall@10 > 0.9.

  • Совет: оценку можно проводить на отдельном датасете из настоящих логов поиска.


N. Пет-проект для закрепления

Задача: Дообучить all-MiniLM-L6-v2 на синтетических данных из вашего профильного форума (например, вопросы‑ответы с Stack Overflow по Python) с помощью MultipleNegativesRankingLoss, проверить recall@5 до/после.

Инструменты:

  • Python 3.9+, sentence-transformers 2.2+.
  • FAISS для быстрого поиска.

Шаги:

  1. Собрать 2000 пар (вопрос, лучший ответ) с api.stackexchange.com.
  2. Разделить 80/20 на train/val.
  3. Обучить модель: model.fit(..., loss=MultipleNegativesRankingLoss), 2 эпохи.
  4. Закодировать валидационные вопросы и корпус ответов, вычислить recall@5.
  5. Сравнить с базовой моделью (без дообучения).

Ожидаемый результат:

  • Прирост recall@5 на 5–15% (зависит от близости предобученных данных к домену).
  • Понимание, как пайплайн fine-tuning встраивается в реальный RAG.

Связь с другими вопросами

ВопросТема
35Метрики оценки эмбеддингов

Навигация

  • Предыдущий: 968
  • Следующий: 970
  • Индекс: 00. Индекс разборов