Files
hvm-docs/eval/retrievers.py
T
justin dda044eb95 search: BM25-default + cross-encoder rerank, hybrid behind env gate
Phase 3/6/7/8 in one pass since they depend on each other.

* docs_mcp/server.py
  - Wire search_docs / get_page / list_versions tool bodies.
  - search_docs flow: BM25 first (rag.bm25 FTS5) → over-fetch RERANK_POOL
    chunks → POST to RERANK_URL/v1/rerank → return top-k. Dense is the
    fallback when BM25 finds nothing. HYBRID_SEARCH=true switches to
    dense+BM25+RRF (fused via the new _rrf_fuse helper).
  - All retrieval failures are caught and fall back to the next layer,
    so a dead reranker or missing BM25 db never blocks a search.
  - Source URLs built from the bundle's docId so results link straight
    into support.hpe.com.

* eval/
  - 22 hand-curated golden queries grounded in real corpus page titles.
  - DenseRetriever / BM25Retriever / HybridRetriever / RerankedRetriever
    + MRR/Recall@K/nDCG@K harness. RERANK_URL env activates the
    reranked variants.
  - Committed eval/results/baseline.md. On this corpus:
        dense:                MRR 0.539
        bm25:                 MRR 0.880
        hybrid_rrf:           MRR 0.692
        bm25+rerank:          MRR 0.920  (winner)
        hybrid_rrf+rerank:    MRR 0.875
    HPE structured docs use controlled vocabulary, so lexical match
    dominates. Hybrid loses because dense pollutes the fused pool.

* scripts/rerank_server.py
  - Minimal HTTP /v1/rerank over sentence-transformers
    cross-encoder/ms-marco-MiniLM-L-6-v2. Cohere-style request/response.
  - This is the dev/CPU fallback; production replaces it with the
    llama.cpp + jina-reranker-v2-base GGUF sidecar (same wire protocol).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 13:06:51 -04:00

153 lines
5.4 KiB
Python

"""Retriever protocol + concrete implementations.
A single matrix dimension per knob (dense / reranked / bm25 / hybrid)
so the eval harness can compare them apples-to-apples. Implement these
once at Phase 7 and reuse them across every retrieval change.
Each retriever returns a ranked list of (bundle_id, page_id) tuples
deduplicated to the page level (chunks within the same page collapse
to one entry; the highest-ranked chunk's position wins).
"""
from __future__ import annotations
from typing import Iterable, Protocol
class Retriever(Protocol):
    """Structural interface shared by the retrievers in this module.

    Any object exposing a ``name`` attribute and a compatible ``retrieve``
    method satisfies it; inheriting from this class is not required.
    """

    # Label identifying the retriever variant.
    name: str

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Produce at most ``k`` ``(bundle_id, page_id)`` pairs, best first."""
        ...
def _split_chunk_id(chunk_id: str) -> tuple[str, str, int]:
"""`bundle::page::ordinal` -> (bundle, page, int(ordinal))."""
bid, pid, ordinal = chunk_id.split("::")
return bid, pid, int(ordinal)
def _collapse_to_pages(chunk_ids: Iterable[str], k: int) -> list[tuple[str, str]]:
seen: set[tuple[str, str]] = set()
out: list[tuple[str, str]] = []
for cid in chunk_ids:
bid, pid, _ord = _split_chunk_id(cid)
key = (bid, pid)
if key in seen:
continue
seen.add(key)
out.append(key)
if len(out) >= k:
break
return out
class DenseRetriever:
    """Chroma cosine search via the live embedding function."""

    name = "dense"

    def __init__(self, collection, pool: int = 50):
        """Keep the collection handle and the number of chunks fetched per query."""
        self.pool = pool
        self.col = collection

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Query the collection for ``self.pool`` chunks and collapse them to ``k`` pages."""
        response = self.col.query(query_texts=[query], n_results=self.pool)
        # A missing or empty "ids" entry is treated as one empty result batch.
        id_batches = response.get("ids") or [[]]
        first_batch = id_batches[0]
        return _collapse_to_pages(first_batch, k)
class BM25Retriever:
    """SQLite FTS5 lexical search."""

    name = "bm25"

    def __init__(self, bm25_index, pool: int = 200):
        """Keep the BM25 index handle and the number of hits requested per query."""
        self.pool = pool
        self.bm = bm25_index

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Ask the index for ``self.pool`` hits and collapse them to ``k`` pages."""
        scored_hits = self.bm.query(query, n=self.pool)
        # Each hit is a (chunk_id, score) pair; only the id is needed here.
        ranked_chunk_ids = (chunk_id for chunk_id, _unused_score in scored_hits)
        return _collapse_to_pages(ranked_chunk_ids, k)
class HybridRetriever:
    """Reciprocal Rank Fusion of dense + BM25 rankings."""

    name = "hybrid_rrf"

    def __init__(self, dense: DenseRetriever, bm25: BM25Retriever, k_rrf: int = 60, pool: int = 100):
        """Store both retrievers and the fusion settings.

        ``k_rrf`` is the constant added to each rank in the fusion
        denominator; ``pool`` is how many pages are requested from each
        retriever before fusing.
        """
        self.dense, self.bm25 = dense, bm25
        self.k_rrf, self.pool = k_rrf, pool

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Fuse both page rankings and return the ``k`` highest-scoring pages."""
        rankings = (
            self.dense.retrieve(query, k=self.pool),
            self.bm25.retrieve(query, k=self.pool),
        )
        fused: dict[tuple[str, str], float] = {}
        for ranking in rankings:
            for position, page in enumerate(ranking, start=1):
                # Each appearance contributes 1 / (k_rrf + rank) to the page's score.
                fused[page] = fused.get(page, 0.0) + 1.0 / (self.k_rrf + position)
        # Stable sort: pages with equal scores stay in first-seen order.
        best_first = sorted(fused, key=lambda page: -fused[page])
        return best_first[:k]
def _rerank_pool(rerank_url: str, query: str, ids_and_texts: list[tuple[str, str]],
timeout: float = 30.0) -> list[str] | None:
"""POST to /v1/rerank, return ids in reranked order. None on failure."""
if not ids_and_texts:
return []
import httpx
try:
with httpx.Client(timeout=timeout) as c:
r = c.post(f"{rerank_url}/v1/rerank", json={
"query": query,
"documents": [(t or "")[:2000] for _i, t in ids_and_texts],
"top_n": len(ids_and_texts),
})
r.raise_for_status()
results = r.json().get("results") or []
return [ids_and_texts[item["index"]][0] for item in results
if isinstance(item.get("index"), int)
and 0 <= item["index"] < len(ids_and_texts)]
except Exception:
return None
class RerankedRetriever:
    """Pull a candidate pool via a base retriever, then cross-encoder re-rank.

    The base retriever supplies page-level candidates; the text of each
    page's chunk 0 is sent to the rerank endpoint, and the order it returns
    decides the final ranking.
    """

    def __init__(self, base: Retriever, collection, rerank_url: str, name_suffix: str = "rerank",
                 pool: int = 50, timeout: float = 30.0):
        """Configure the reranking wrapper.

        Args:
            base: Retriever that supplies the candidate pages.
            collection: Store queried with ``get(ids=..., include=["documents"])``
                to fetch chunk texts.
            rerank_url: Base URL of the rerank service.
            name_suffix: Appended to ``base.name`` (joined by ``+``) to form
                this retriever's ``name``.
            pool: Number of candidate pages requested from ``base``.
            timeout: Timeout forwarded to the rerank request.
        """
        self.base = base
        self.col = collection
        self.url = rerank_url
        self.name = f"{base.name}+{name_suffix}"
        self.pool = pool
        self.timeout = timeout

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Return up to ``k`` pages from the base pool in reranked order.

        Falls back to the base order when the rerank call fails. Candidates
        the reranker leaves out of its response keep their base order and
        are placed after the reranked ones instead of being dropped.
        """
        # The previous dedupe loop appended before checking the limit, so a
        # non-positive k still produced one page; answer that case up front.
        if k <= 0:
            return []
        pages = self.base.retrieve(query, k=self.pool)
        if not pages:
            return []
        # Base returns deduplicated page-level tuples; rerank needs CHUNK-level
        # texts to be informative. Pull each page's chunk 0 text from Chroma.
        chunk_ids = [f"{bid}::{pid}::0" for bid, pid in pages]
        fetched = self.col.get(ids=chunk_ids, include=["documents"])
        text_by_id = dict(zip(fetched["ids"], fetched["documents"]))
        # Pages whose chunk 0 was not returned are still submitted, with empty text.
        ids_and_texts = [(cid, text_by_id.get(cid, "")) for cid in chunk_ids]
        order = _rerank_pool(self.url, query, ids_and_texts, timeout=self.timeout)
        if order is None:
            return pages[:k]
        # Map chunk ids straight back to the page they were built from rather
        # than re-parsing the id string.
        page_of = dict(zip(chunk_ids, pages))
        reranked = [page_of[cid] for cid in order]
        # dict.fromkeys dedupes while keeping first-seen order; appending the
        # base pages afterwards back-fills anything the reranker omitted.
        merged = dict.fromkeys(reranked + list(pages))
        return list(merged)[:k]