Как вы строите двухступенчатый ретривал (fast ANN + slow cross-encoder) в RAG?
Краткий тезис
Двухступенчатый ретривал — это архитектура, в которой первый этап (fast ANN) быстро отбирает top-k кандидатов (например, 100) с помощью приближённого поиска ближайших соседей, а второй этап (slow cross-encoder) дорого, но точно переранжирует эти кандидаты, оставляя top-n (например, 5–10). Такой подход позволяет совместить скорость ANN и точность cross-encoder’а, достигая высокого качества retrieval при приемлемой задержке.
1. Термины: ANN, HNSW, cross-encoder, bi-encoder
ANN (Approximate Nearest Neighbor) — алгоритмы приближённого поиска ближайших соседей. Они жертвуют небольшой точностью ради радикального ускорения. HNSW (Hierarchical Navigable Small World) — один из самых популярных ANN-индексов, строит многоуровневый граф, позволяя находить соседей за O(log N).
Bi-encoder — модель, которая независимо кодирует запрос и документ в один вектор (эмбеддинг). Используется в ANN: эмбеддинги заранее проиндексированы, поиск идёт по косинусной близости. Быстро, но теряет тонкие взаимодействия между запросом и документом.
Cross-encoder — модель, которая принимает на вход пару (запрос, документ) и выдаёт оценку релевантности. Учитывает полное взаимодействие токенов, поэтому точнее, но требует O(N) прямых проходов для N кандидатов — медленно.
Latency — задержка ответа системы. Для RAG критична: пользователь ждёт секунды, а не минуты.
2. Архитектура двухступенчатого ретривала
Pipeline выглядит так:
- Пользовательский запрос → эмбеддинг через bi-encoder (например, all-MiniLM-L6-v2).
- Fast ANN (HNSW) — поиск top-100 ближайших векторов в индексе.
- Cross-encoder reranking — для каждой из 100 пар (запрос, документ) вычисляется оценка релевантности (например,
cross-encoder/ms-marco-MiniLM-L-6-v2). - Отбор top-5–10 по оценкам cross-encoder’а → передача в LLM.
# Псевдокод pipeline
def two_stage_retrieve(query, index, bi_encoder, cross_encoder, top_k=100, top_n=5):
# Stage 1: ANN
query_emb = bi_encoder.encode(query)
distances, indices = index.search(query_emb, top_k) # HNSW
candidates = [documents[i] for i in indices[0]]
# Stage 2: Cross-encoder reranking
pairs = [(query, doc) for doc in candidates]
scores = cross_encoder.predict(pairs) # список float
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
return [doc for doc, _ in ranked[:top_n]]
3. Fast ANN (HNSW) — первый этап
HNSW строит иерархический граф: на верхних уровнях — длинные связи (быстрый переход к области), на нижних — точные соседи. Поиск начинается с верхнего уровня и спускается вниз.
Параметры HNSW
M— количество связей на узел (типично 16–64). Больше M → выше точность, но больше памяти и время построения.- ef_construction — размер динамического списка при построении (типично 200–400). Влияет на качество графа.
- ef_search — размер списка при поиске (типично 50–200). Больше → точнее, но медленнее.
Почему fast поиск в HNSW имеет сложность O(log N), тогда как полный перебор (brute force) — O(N). Для 10⁶ документов ANN находит top-100 за миллисекунды.
Инструменты FAISS (Facebook AI Similarity Search) — де-факто стандарт. Пример создания индекса:
import faiss
dim = 384 # размерность эмбеддинга
index = faiss.IndexHNSWFlat(dim, M=32)
index.hnsw.efConstruction = 200
index.add(all_embeddings) # numpy array shape (N, dim)
4. Cross-encoder reranking — второй этап
Cross-encoder принимает пару (запрос, документ) как одну последовательность токенов: [CLS] query [SEP] document [SEP]. На выходе — скаляр (логит) релевантности, часто через сигмоиду в [0,1].
Почему slow для каждого кандидата нужно выполнить полный forward pass трансформера. Если top_k=100, то 100 проходов. Даже маленькая модель (6 слоёв) на CPU может занимать 0.5–1 секунду на 100 пар. GPU ускоряет, но всё равно дороже ANN.
Точность: cross-encoder значительно лучше bi-encoder’а, так как видит пересечение токенов. Например, на бенчмарке MS MARCO cross-encoder может давать +10–15% NDCG@10 по сравнению с bi-encoder.
Популярные модели cross-encoder/ms-marco-MiniLM-L-6-v2 (быстрый), cross-encoder/ms-marco-electra-base (точнее, но медленнее).
Batch processing: можно обрабатывать пары батчами (например, batch_size=32) для ускорения на GPU.
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
scores = model.predict([(query, doc) for doc in candidates])
5. Trade-off качество/латенси
| Этап | Время (пример) | Точность (Recall@10) | Комментарий |
|---|---|---|---|
| Только ANN (top-10) | 2–5 мс | 70–80% | Быстро, но может пропустить релевантные документы |
| Только cross-encoder (по всем документам) | 10+ секунд | 95%+ | Неприемлемо для реального времени |
| Двухступенчатый (ANN 100 + cross-encoder 5) | 50–200 мс | 90–95% | Хороший баланс |
Ключевой компромисс размер top_k на первом этапе. Чем больше top_k, тем выше вероятность, что cross-encoder найдёт лучшие документы, но растёт latency. Обычно выбирают top_k = 50–200.
Дополнительные оптимизации
- Использовать кэширование результатов cross-encoder для частых запросов.
- Distillation — обучить bi-encoder имитировать cross-encoder (как в ColBERT-v2).
- Адаптивный top_k — для простых запросов брать меньше кандидатов, для сложных — больше.
6. Практическая реализация с FAISS и Hugging Face
Полный пример на Python:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
# Инициализация моделей
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Документы (пример)
documents = ["Документ 1 ...", "Документ 2 ...", ...]
doc_embeddings = bi_encoder.encode(documents, show_progress_bar=True)
# Построение HNSW индекса
dim = doc_embeddings.shape[1]
index = faiss.IndexHNSWFlat(dim, M=32)
index.hnsw.efConstruction = 200
index.add(doc_embeddings.astype('float32'))
def retrieve(query, top_k=100, top_n=5):
# Stage 1
q_emb = bi_encoder.encode([query]).astype('float32')
distances, indices = index.search(q_emb, top_k)
candidates = [documents[i] for i in indices[0]]
# Stage 2
pairs = [(query, doc) for doc in candidates]
scores = cross_encoder.predict(pairs)
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
return [doc for doc, _ in ranked[:top_n]]
# Использование
result = retrieve("Как работает HNSW?")
7. Когда использовать двухступенчатый ретривал
- Высокие требования к качеству (юридические, медицинские RAG).
- Большая база знаний (миллионы документов) — ANN обязателен.
- Допустимая задержка до 500 мс — cross-encoder успевает обработать 100–200 кандидатов.
- Альтернативы если latency критична (<50 мс), можно использовать ColBERT (late interaction) или обученный bi-encoder с hard negative mining. Если качество важнее всего — многоступенчатый ретривал (ANN → lightweight reranker → cross-encoder).
8. Метрики для оценки двухступенчатого ретривала
- Recall@k (на первом этапе) — сколько релевантных документов попало в top_k.
- NDCG@n (после реранжинга) — качество упорядочивания top_n.
- Latency p50/p99 — время выполнения pipeline.
- Trade-off curve — график Recall@10 vs latency при разных top_k.
Пет-проект для закрепления
Задача Реализовать двухступенчатый ретривал для датасета научных статей (например, ArXiv abstracts). Сравнить качество и скорость с одноступенчатым ANN.
Инструменты Python, FAISS, SentenceTransformers, Hugging Face CrossEncoder, Streamlit (для демо).
Шаги:
- Загрузить 10k–100k абстрактов (например, через Hugging Face Datasets).
- Вычислить эмбеддинги bi-encoder’ом, построить HNSW индекс.
- Реализовать pipeline с cross-encoder реранжингом.
- Создать тестовый набор запросов с ручной разметкой релевантности (или использовать готовый датасет).
- Измерить Recall@10, NDCG@5, latency для разных top_k (50, 100, 200).
- Визуализировать trade-off.
Ожидаемый результат Вы увидите, что двухступенчатый подход даёт прирост NDCG на 5–15% при увеличении latency в 10–20 раз по сравнению с чистым ANN. Научитесь настраивать top_k под свои latency-бюджеты.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 371 | Как выбрать алгоритм ANN для RAG? |
| 373 | Как обучать cross-encoder для реранжинга? |
| 374 | Что такое ColBERT и как он сочетает bi-encoder и cross-encoder? |
| 375 | Как оценивать качество retrieval в RAG? |
| 380 | Как уменьшить latency RAG-системы? |
Навигация
- Предыдущий: 371
- Следующий: 373
- Индекс: 00. Индекс разборов