Как вы строите двухступенчатый ретривал (fast ANN + slow cross-encoder) в RAG?

Краткий тезис

Двухступенчатый ретривал — это архитектура, в которой первый этап (fast ANN) быстро отбирает top-k кандидатов (например, 100) с помощью приближённого поиска ближайших соседей, а второй этап (slow cross-encoder) дорого, но точно переранжирует эти кандидаты, оставляя top-n (например, 5–10). Такой подход позволяет совместить скорость ANN и точность cross-encoder’а, достигая высокого качества retrieval при приемлемой задержке.


1. Термины: ANN, HNSW, cross-encoder, bi-encoder

ANN (Approximate Nearest Neighbor) — алгоритмы приближённого поиска ближайших соседей. Они жертвуют небольшой точностью ради радикального ускорения. HNSW (Hierarchical Navigable Small World) — один из самых популярных ANN-индексов, строит многоуровневый граф, позволяя находить соседей за O(log N).

Bi-encoder — модель, которая независимо кодирует запрос и документ в один вектор (эмбеддинг). Используется в ANN: эмбеддинги заранее проиндексированы, поиск идёт по косинусной близости. Быстро, но теряет тонкие взаимодействия между запросом и документом.

Cross-encoder — модель, которая принимает на вход пару (запрос, документ) и выдаёт оценку релевантности. Учитывает полное взаимодействие токенов, поэтому точнее, но требует O(N) прямых проходов для N кандидатов — медленно.

Latency — задержка ответа системы. Для RAG критична: пользователь ждёт секунды, а не минуты.


2. Архитектура двухступенчатого ретривала

Pipeline выглядит так:

  1. Пользовательский запрос → эмбеддинг через bi-encoder (например, all-MiniLM-L6-v2).
  2. Fast ANN (HNSW) — поиск top-100 ближайших векторов в индексе.
  3. Cross-encoder reranking — для каждой из 100 пар (запрос, документ) вычисляется оценка релевантности (например, cross-encoder/ms-marco-MiniLM-L-6-v2).
  4. Отбор top-5–10 по оценкам cross-encoder’а → передача в LLM.
# Псевдокод pipeline
def two_stage_retrieve(query, index, bi_encoder, cross_encoder, top_k=100, top_n=5):
    # Stage 1: ANN
    query_emb = bi_encoder.encode(query)
    distances, indices = index.search(query_emb, top_k)  # HNSW
    candidates = [documents[i] for i in indices[0]]
    
    # Stage 2: Cross-encoder reranking
    pairs = [(query, doc) for doc in candidates]
    scores = cross_encoder.predict(pairs)  # список float
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]

3. Fast ANN (HNSW) — первый этап

HNSW строит иерархический граф: на верхних уровнях — длинные связи (быстрый переход к области), на нижних — точные соседи. Поиск начинается с верхнего уровня и спускается вниз.

Параметры HNSW

  • M — количество связей на узел (типично 16–64). Больше M → выше точность, но больше памяти и время построения.
  • ef_construction — размер динамического списка при построении (типично 200–400). Влияет на качество графа.
  • ef_search — размер списка при поиске (типично 50–200). Больше → точнее, но медленнее.

Почему fast поиск в HNSW имеет сложность O(log N), тогда как полный перебор (brute force) — O(N). Для 10⁶ документов ANN находит top-100 за миллисекунды.

Инструменты FAISS (Facebook AI Similarity Search) — де-факто стандарт. Пример создания индекса:

import faiss
dim = 384  # размерность эмбеддинга
index = faiss.IndexHNSWFlat(dim, M=32)
index.hnsw.efConstruction = 200
index.add(all_embeddings)  # numpy array shape (N, dim)

4. Cross-encoder reranking — второй этап

Cross-encoder принимает пару (запрос, документ) как одну последовательность токенов: [CLS] query [SEP] document [SEP]. На выходе — скаляр (логит) релевантности, часто через сигмоиду в [0,1].

Почему slow для каждого кандидата нужно выполнить полный forward pass трансформера. Если top_k=100, то 100 проходов. Даже маленькая модель (6 слоёв) на CPU может занимать 0.5–1 секунду на 100 пар. GPU ускоряет, но всё равно дороже ANN.

Точность: cross-encoder значительно лучше bi-encoder’а, так как видит пересечение токенов. Например, на бенчмарке MS MARCO cross-encoder может давать +10–15% NDCG@10 по сравнению с bi-encoder.

Популярные модели cross-encoder/ms-marco-MiniLM-L-6-v2 (быстрый), cross-encoder/ms-marco-electra-base (точнее, но медленнее).

Batch processing: можно обрабатывать пары батчами (например, batch_size=32) для ускорения на GPU.

from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
scores = model.predict([(query, doc) for doc in candidates])

5. Trade-off качество/латенси

ЭтапВремя (пример)Точность (Recall@10)Комментарий
Только ANN (top-10)2–5 мс70–80%Быстро, но может пропустить релевантные документы
Только cross-encoder (по всем документам)10+ секунд95%+Неприемлемо для реального времени
Двухступенчатый (ANN 100 + cross-encoder 5)50–200 мс90–95%Хороший баланс

Ключевой компромисс размер top_k на первом этапе. Чем больше top_k, тем выше вероятность, что cross-encoder найдёт лучшие документы, но растёт latency. Обычно выбирают top_k = 50–200.

Дополнительные оптимизации

  • Использовать кэширование результатов cross-encoder для частых запросов.
  • Distillation — обучить bi-encoder имитировать cross-encoder (как в ColBERT-v2).
  • Адаптивный top_k — для простых запросов брать меньше кандидатов, для сложных — больше.

6. Практическая реализация с FAISS и Hugging Face

Полный пример на Python:

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder

# Инициализация моделей
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Документы (пример)
documents = ["Документ 1 ...", "Документ 2 ...", ...]
doc_embeddings = bi_encoder.encode(documents, show_progress_bar=True)

# Построение HNSW индекса
dim = doc_embeddings.shape[1]
index = faiss.IndexHNSWFlat(dim, M=32)
index.hnsw.efConstruction = 200
index.add(doc_embeddings.astype('float32'))

def retrieve(query, top_k=100, top_n=5):
    # Stage 1
    q_emb = bi_encoder.encode([query]).astype('float32')
    distances, indices = index.search(q_emb, top_k)
    candidates = [documents[i] for i in indices[0]]
    
    # Stage 2
    pairs = [(query, doc) for doc in candidates]
    scores = cross_encoder.predict(pairs)
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]

# Использование
result = retrieve("Как работает HNSW?")

7. Когда использовать двухступенчатый ретривал

  • Высокие требования к качеству (юридические, медицинские RAG).
  • Большая база знаний (миллионы документов) — ANN обязателен.
  • Допустимая задержка до 500 мс — cross-encoder успевает обработать 100–200 кандидатов.
  • Альтернативы если latency критична (<50 мс), можно использовать ColBERT (late interaction) или обученный bi-encoder с hard negative mining. Если качество важнее всего — многоступенчатый ретривал (ANN → lightweight reranker → cross-encoder).

8. Метрики для оценки двухступенчатого ретривала

  • Recall@k (на первом этапе) — сколько релевантных документов попало в top_k.
  • NDCG@n (после реранжинга) — качество упорядочивания top_n.
  • Latency p50/p99 — время выполнения pipeline.
  • Trade-off curve — график Recall@10 vs latency при разных top_k.

Пет-проект для закрепления

Задача Реализовать двухступенчатый ретривал для датасета научных статей (например, ArXiv abstracts). Сравнить качество и скорость с одноступенчатым ANN.

Инструменты Python, FAISS, SentenceTransformers, Hugging Face CrossEncoder, Streamlit (для демо).

Шаги:

  1. Загрузить 10k–100k абстрактов (например, через Hugging Face Datasets).
  2. Вычислить эмбеддинги bi-encoder’ом, построить HNSW индекс.
  3. Реализовать pipeline с cross-encoder реранжингом.
  4. Создать тестовый набор запросов с ручной разметкой релевантности (или использовать готовый датасет).
  5. Измерить Recall@10, NDCG@5, latency для разных top_k (50, 100, 200).
  6. Визуализировать trade-off.

Ожидаемый результат Вы увидите, что двухступенчатый подход даёт прирост NDCG на 5–15% при увеличении latency в 10–20 раз по сравнению с чистым ANN. Научитесь настраивать top_k под свои latency-бюджеты.


Связь с другими вопросами

ВопросТема
371Как выбрать алгоритм ANN для RAG?
373Как обучать cross-encoder для реранжинга?
374Что такое ColBERT и как он сочетает bi-encoder и cross-encoder?
375Как оценивать качество retrieval в RAG?
380Как уменьшить latency RAG-системы?

Навигация