"""Retriever protocol + concrete implementations for seed-mcp eval. Each retriever returns a ranked list of chunk_ids. The eval harness in ``run_eval.py`` measures each retriever against the golden ``queries.jsonl`` set across MRR / Recall@K / nDCG@K. Four named configurations, matching the four switches in ``docs_mcp/server.py``: dense — Chroma dense retrieval alone bm25 — SQLite FTS5 BM25 alone hybrid — dense + bm25 fused via RRF hybrid_rerank — hybrid pool → cross-encoder rerank Each retriever takes ``filters`` (the same dict shape ``_build_where`` accepts in server.py) so trial-specific facets (data_type, state, year, crop) work consistently across the four configurations. """ from __future__ import annotations import os import sys from pathlib import Path from typing import Protocol # Add repo root so we can import docs_mcp and rag from here. sys.path.insert(0, str(Path(__file__).resolve().parents[1])) class Retriever(Protocol): name: str def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: """Return up to k chunk_ids in rank order.""" ... def _build_where(filters: dict | None) -> dict | None: """Mirror of docs_mcp.server._build_where but accepts the eval's looser shape.""" if not filters: return None conds: list[dict] = [] if filters.get("data_type"): conds.append({"data_type": filters["data_type"]}) if filters.get("crop"): conds.append({"crop": filters["crop"].lower()}) if filters.get("brand"): conds.append({"brand": filters["brand"].upper()}) if filters.get("state"): s = filters["state"] conds.append({"state": s.upper() if len(s) <= 3 else s}) if filters.get("year"): conds.append({"year": int(filters["year"])}) if not conds: return None if len(conds) == 1: return conds[0] return {"$and": conds} class DenseRetriever: name = "dense" def __init__(self, collection, pool: int = 50): self.col = collection self.pool = pool def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: where = _build_where(filters) try: r = self.col.query( query_texts=[query], n_results=max(k, self.pool), where=where, ) except Exception: return [] return (r.get("ids") or [[]])[0][:k] class BM25Retriever: name = "bm25" def __init__(self, bm25, pool: int = 50): self.bm25 = bm25 self.pool = pool def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: where = _build_where(filters) hits = self.bm25.query(query, n=max(k, self.pool), where=where) return [cid for cid, _ in hits[:k]] class HybridRetriever: """Dense + BM25 fused via RRF — same fusion the server uses.""" name = "hybrid" def __init__(self, collection, bm25, pool: int = 50, rrf_k: int = 60): self.col = collection self.bm25 = bm25 self.pool = pool self.rrf_k = rrf_k def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: where = _build_where(filters) try: d = self.col.query(query_texts=[query], n_results=self.pool, where=where) dense_ids = (d.get("ids") or [[]])[0] except Exception: dense_ids = [] bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)] scores: dict[str, float] = {} for ranking in (dense_ids, bm25_ids): for rank, cid in enumerate(ranking): scores[cid] = scores.get(cid, 0.0) + 1.0 / (self.rrf_k + rank + 1) fused = sorted(scores, key=lambda d: scores[d], reverse=True) return fused[:k] class HybridRerankRetriever: """Hybrid pool → cross-encoder rerank via the llama.cpp endpoint.""" name = "hybrid+rerank" def __init__(self, collection, bm25, rerank_url: str, pool: int = 50, rerank_pool: int = 50, rrf_k: int = 60, doc_max_chars: int = 2000, timeout: float = 30.0): self.col = collection self.bm25 = bm25 self.rerank_url = rerank_url.rstrip("/") self.pool = pool self.rerank_pool = rerank_pool self.rrf_k = rrf_k self.doc_max_chars = doc_max_chars self.timeout = timeout def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: where = _build_where(filters) try: d = self.col.query( query_texts=[query], n_results=self.pool, where=where, include=["documents"], ) dense_ids = (d.get("ids") or [[]])[0] dense_docs = (d.get("documents") or [[]])[0] id_to_doc = dict(zip(dense_ids, dense_docs)) except Exception: dense_ids = [] id_to_doc = {} bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)] # Fuse to a hybrid pool scores: dict[str, float] = {} for ranking in (dense_ids, bm25_ids): for rank, cid in enumerate(ranking): scores[cid] = scores.get(cid, 0.0) + 1.0 / (self.rrf_k + rank + 1) fused = sorted(scores, key=lambda d: scores[d], reverse=True) # Fetch docs for any BM25-only ids in the rerank pool missing = [cid for cid in fused[: self.rerank_pool] if cid not in id_to_doc] if missing: try: extra = self.col.get(ids=missing, include=["documents"]) for cid, doc in zip(extra.get("ids") or [], extra.get("documents") or []): id_to_doc[cid] = doc except Exception: pass # Rerank pool_ids = fused[: self.rerank_pool] docs = [(id_to_doc.get(cid, "") or "")[: self.doc_max_chars] for cid in pool_ids] try: import httpx with httpx.Client(timeout=self.timeout) as c: r = c.post( f"{self.rerank_url}/v1/rerank", json={"model": "rerank", "query": query, "documents": docs}, ) r.raise_for_status() results = r.json().get("results") or [] if not results: return fused[:k] ordered = sorted(results, key=lambda x: -x.get("relevance_score", float("-inf"))) reranked = [pool_ids[x["index"]] for x in ordered if 0 <= x.get("index", -1) < len(pool_ids)] return reranked[:k] except Exception: return fused[:k] def build_all_retrievers(collection, bm25, rerank_url: str | None = None) -> list[Retriever]: """Return the four named retrievers ready to evaluate.""" out: list[Retriever] = [ DenseRetriever(collection), BM25Retriever(bm25), HybridRetriever(collection, bm25), ] if rerank_url: out.append(HybridRerankRetriever(collection, bm25, rerank_url)) return out