From bd71f30ca72067d8939bcab40ea565397cc8c216 Mon Sep 17 00:00:00 2001 From: Justin Paul Date: Mon, 25 May 2026 17:02:57 -0400 Subject: [PATCH] =?UTF-8?q?Phase=206/7:=20wire=20rerank=20+=20eval=20harne?= =?UTF-8?q?ss=20=E2=80=94=20100%=20pass=20on=2021=20golden=20queries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 6 — Reranker integration - New _rerank(query, [(cid, doc), ...]) helper in server.py calls llama.cpp's /v1/rerank endpoint, returns reranker-ordered ids or None on failure (graceful fallback — search never blocks on the sidecar). - search_docs + search_trials both call _rerank() on the post- hybrid pool BEFORE truncating to k. The variety-code prefilter still pins exact matches on top. - Per-doc truncation to 2000 chars to fit jina-reranker-v2-base's per-pair token budget. Full chunk text still returned to the caller — truncation is rerank-input-only. - Telemetry adds `reranked: true|false` so usage logs distinguish reranked calls. Phase 7 — Eval harness - eval/queries.jsonl: 21 golden queries spanning: * variety-code lookups (DKC62-08RIB, AG29XF4, WB6430, E085Z5, AP Iliad) * semantic variety queries (drought-tolerant corn, SCN MG-3 soy, Rps3a, XtendFlex, HRS stripe rust, SWW PNW, Goss's Wilt) * trial queries (IA/IN/MN regional, AP Iliad ID, NK1701 head- to-head, silage Ton/Acre, product=DKC65-95) * anti-hallucination (Pioneer P1142 fallback, DKC65-20 not-in- corpus expected_empty) - eval/retrievers.py: 4 named retrievers — dense, bm25, hybrid (dense+bm25+RRF), hybrid+rerank — all sharing the same filter shape as docs_mcp/server.py._build_where. - eval/run_eval.py: runs each retriever against each query, reports Recall / Precision@1 / MRR / avg latency. Markdown output in eval/results/baseline.md. 
Baseline results (k=5, 21 queries): | Retriever | Pass | Recall | P@1 | MRR | Avg ms | |-----------------|-------|--------|-------|-------|--------| | hybrid+rerank | 21/21 | 100% | 90% | 0.905 | 2064 | | bm25 | 20/21 | 95% | 81% | 0.833 | 5 | | hybrid | 15/21 | 71% | 62% | 0.619 | 73 | | dense | 14/21 | 67% | 38% | 0.440 | 79 | Key findings: 1. hybrid+rerank wins on quality — 100% pass, 90% P@1. 2. BM25 alone is surprisingly competitive (95% pass) at 5 ms — excellent fallback when rerank is down. The variety-code prefilter in search_docs is doing a lot of work here. 3. Dense embedding alone is the WEAKEST configuration on this corpus — variety identity tokens (DKC62-08RIB, AP Iliad, Rps3a) have no semantic neighbors, so nomic-embed-text returns noise. The hybrid (no rerank) layer actively hurts because RRF dilutes the BM25 ranking with dense noise. 4. Anti-hallucination queries (Pioneer fallback, DKC65-20 not- in-corpus) pass on ALL retrievers including dense-only — the must_not_contain + expected_empty design holds. Deploy decision: HYBRID_SEARCH=true + RERANK_URL set (production env already has both — refresh.yml + image-only.yml + deploy/docker-compose.yml all configured). Co-Authored-By: Claude Opus 4.7 (1M context) --- docs_mcp/server.py | 123 +++++++++++++++ eval/queries.jsonl | 21 +++ eval/results/baseline.md | 41 +++++ eval/retrievers.py | 234 +++++++++++++++++++++++------ eval/run_eval.py | 313 ++++++++++++++++++++++++++++++++++----- 5 files changed, 643 insertions(+), 89 deletions(-) create mode 100644 eval/queries.jsonl create mode 100644 eval/results/baseline.md diff --git a/docs_mcp/server.py b/docs_mcp/server.py index 8162c9ef..e30db021 100644 --- a/docs_mcp/server.py +++ b/docs_mcp/server.py @@ -289,6 +289,75 @@ def _rrf_fuse(rankings: list[list[str]], k: int = RRF_K) -> list[str]: return sorted(scores, key=lambda d: scores[d], reverse=True) +# Per-doc character cap when sending to the reranker. 
jina-reranker-v2-base +# accepts up to ~1024 tokens PER QUERY+DOC PAIR (n_ctx_train) and rejects +# the WHOLE BATCH if any one pair exceeds it. Truncating each doc to +# ~2000 chars (≈ 500-700 tokens) leaves headroom for the query + chat +# template overhead. The truncation is reranking-only — full chunk text +# still goes back to the LLM caller. +RERANK_DOC_MAX_CHARS = 2000 + + +def _rerank(query: str, candidates: list[tuple[str, str]]) -> list[str] | None: + """Call the llama.cpp /v1/rerank endpoint and return the candidate + chunk ids in reranker-preferred order. + + Args: + query: the user's natural-language query + candidates: list of ``(chunk_id, chunk_text)`` to rerank. + + Returns: + A list of chunk_ids ordered best-first by reranker score, OR + ``None`` if reranking is disabled, the endpoint is unreachable, + or any other error. The caller treats ``None`` as "fall back to + the input ranking" — rerank failures must NEVER block a search. + + Anti-hallucination: rerank only reorders chunks the retrievers + already surfaced. It cannot introduce content not in the corpus. + """ + if not RERANK_URL or not candidates: + return None + try: + import httpx + except ImportError: + return None + + # Truncate each doc to fit the per-pair token budget. jina-reranker + # rejects the entire batch on any oversize doc. 
+ docs = [(text[:RERANK_DOC_MAX_CHARS] if text else "") for _cid, text in candidates] + ids = [cid for cid, _ in candidates] + + try: + with httpx.Client(timeout=RERANK_TIMEOUT) as c: + r = c.post( + f"{RERANK_URL}/v1/rerank", + json={ + "model": "rerank", # llama.cpp ignores this; jina passes through + "query": query, + "documents": docs, + }, + ) + r.raise_for_status() + payload = r.json() + except Exception as exc: # noqa: BLE001 + log.warning("rerank request failed (%s) — falling back to input order", exc) + return None + + results = payload.get("results") or [] + if not results: + log.warning("rerank returned empty results — falling back to input order") + return None + + # llama.cpp returns results as [{"index": int, "relevance_score": float}, ...] + # Higher relevance_score = better; sort descending. + try: + ordered = sorted(results, key=lambda r: -r.get("relevance_score", float("-inf"))) + return [ids[r["index"]] for r in ordered if 0 <= r.get("index", -1) < len(ids)] + except (KeyError, IndexError, TypeError) as exc: + log.warning("rerank response malformed (%s) — falling back to input order", exc) + return None + + def _structured_ratings_block(sidecar: dict) -> str: """Render the sidecar's grouped characteristics + identity as a fact-checkable block, with the source URL pinned at top. @@ -534,6 +603,34 @@ def search_docs( else: fuzzy_ids = dense_ids + # Optional reranker pass over the fuzzy pool BEFORE truncating + # to k. The cross-encoder is much more accurate at the query/ + # doc pairing than dense embedding alone, especially when the + # query mentions specific ag terms that share-token-cosine + # might miss. Skipped if RERANK_URL is unset or the call + # fails — search is never blocked on the sidecar. + used_rerank = False + if RERANK_URL and fuzzy_ids: + # Need docs to rerank — fetch any missing. 
+ need = [i for i in fuzzy_ids if i not in id_to_doc] + if need: + try: + extra = col.get(ids=need[:RERANK_POOL], include=["documents", "metadatas"]) + for cid, doc, meta in zip( + extra.get("ids") or [], + extra.get("documents") or [], + extra.get("metadatas") or [], + ): + id_to_doc[cid] = doc + id_to_meta[cid] = meta + except Exception as exc: # noqa: BLE001 + log.warning("pre-rerank get-by-id failed: %s", exc) + pool = [(cid, id_to_doc.get(cid, "")) for cid in fuzzy_ids[:RERANK_POOL]] + reranked = _rerank(query, pool) + if reranked: + fuzzy_ids = reranked + [c for c in fuzzy_ids if c not in set(reranked)] + used_rerank = True + # Pin exact-code matches at top, then fill remainder from fuzzy # retrieval (deduped). Pinned matches are deterministic and # high-confidence; they should never lose to a fuzzy match. @@ -566,6 +663,7 @@ def search_docs( _call.set( hits_returned=len(final_ids), hybrid=used_hybrid, + reranked=used_rerank, pool_size=pool_size, ) @@ -885,6 +983,30 @@ def search_trials( else: fuzzy_ids = dense_ids + # Optional reranker pass over the fuzzy pool — same shape as + # in search_docs. Skipped silently if RERANK_URL is unset or + # the rerank call fails. 
+ used_rerank = False + if RERANK_URL and fuzzy_ids: + need = [i for i in fuzzy_ids if i not in id_to_doc] + if need: + try: + extra = col.get(ids=need[:RERANK_POOL], include=["documents", "metadatas"]) + for cid, doc, meta in zip( + extra.get("ids") or [], + extra.get("documents") or [], + extra.get("metadatas") or [], + ): + id_to_doc[cid] = doc + id_to_meta[cid] = meta + except Exception as exc: # noqa: BLE001 + log.warning("pre-rerank get-by-id failed: %s", exc) + pool = [(cid, id_to_doc.get(cid, "")) for cid in fuzzy_ids[:RERANK_POOL]] + reranked = _rerank(full_query, pool) + if reranked: + fuzzy_ids = reranked + [c for c in fuzzy_ids if c not in set(reranked)] + used_rerank = True + # Optional product-substring post-filter: if user supplied # ``product``, require the chunk to actually contain the # token. This re-checks the bytes since BM25 only sees stems. @@ -931,6 +1053,7 @@ def search_trials( _call.set( hits_returned=len(final_ids), hybrid=used_hybrid, + reranked=used_rerank, pool_size=pool_size, data_type="trial", ) diff --git a/eval/queries.jsonl b/eval/queries.jsonl new file mode 100644 index 00000000..756182ea --- /dev/null +++ b/eval/queries.jsonl @@ -0,0 +1,21 @@ +{"query": "DKC62-08RIB ratings", "tool": "search_docs", "expected_source_keys": ["dekalb-dkc62-08rib"], "tags": ["variety-code", "exact"]} +{"query": "AG29XF4 disease ratings", "tool": "search_docs", "expected_source_keys": ["asgrow-ag29xf4"], "tags": ["variety-code", "exact"]} +{"query": "WB6430 westbred wheat", "tool": "search_docs", "expected_source_keys": ["westbred-wb6430"], "tags": ["variety-code", "exact"]} +{"query": "E085Z5 corn", "tool": "search_docs", "expected_source_keys": ["golden_harvest-e085z5"], "tags": ["variety-code", "exact", "golden-harvest"]} +{"query": "AP Iliad wheat performance", "tool": "search_docs", "expected_source_keys": ["agripro-ap-iliad"], "tags": ["variety-code", "agripro"]} +{"query": "drought tolerant corn for sandy soil short season Iowa", "tool": 
"search_docs", "expected_metadata": {"crop": "corn"}, "expected_substrings": ["drought"], "k": 5, "tags": ["semantic", "corn"]} +{"query": "soybean cyst nematode SCN resistant variety", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["soybean cyst nematode", "SCN"], "k": 5, "tags": ["semantic", "soybean", "disease"]} +{"query": "Phytophthora resistance Rps3a soybean", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["Rps3a"], "k": 5, "tags": ["semantic", "soybean", "gene"]} +{"query": "XtendFlex soybean Northern Plains", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["XF", "XtendFlex"], "k": 5, "tags": ["trait", "soybean"]} +{"query": "Hard Red Spring wheat stripe rust resistance", "tool": "search_docs", "expected_metadata": {"crop": "wheat"}, "expected_substrings": ["stripe rust", "Hard Red Spring", "HRS"], "k": 5, "tags": ["wheat", "disease", "class"]} +{"query": "Soft White Winter wheat Pacific Northwest", "tool": "search_docs", "expected_metadata": {"crop": "wheat"}, "expected_substrings": ["Soft White Winter", "SWW"], "k": 5, "tags": ["wheat", "class"]} +{"query": "Goss's Wilt resistance corn", "tool": "search_docs", "expected_metadata": {"crop": "corn"}, "expected_substrings": ["Goss"], "k": 5, "tags": ["corn", "disease"]} +{"query": "best corn 2024 Iowa", "tool": "search_trials", "filters": {"crop": "corn", "state": "IA", "year": 2024}, "expected_metadata": {"crop": "corn", "year": 2024, "state": "IA"}, "k": 5, "tags": ["trial", "regional"]} +{"query": "Indiana corn yield comparison 2024", "tool": "search_trials", "filters": {"crop": "corn", "state": "IN", "year": 2024}, "expected_metadata": {"crop": "corn", "year": 2024, "state": "IN"}, "k": 5, "tags": ["trial", "regional"]} +{"query": "AP Iliad Idaho wheat trial", "tool": "search_trials", "filters": {"crop": "wheat"}, "expected_substrings": ["AP Iliad", "Idaho"], "k": 3, "tags": 
["trial", "wheat", "agripro"]} +{"query": "DKC65-95 corn yield in trials", "tool": "search_trials", "filters": {"crop": "corn", "product": "DKC65-95"}, "expected_substrings": ["DKC65-95"], "k": 3, "tags": ["trial", "product-filter"]} +{"query": "NK1701 corn trials head to head", "tool": "search_trials", "filters": {"crop": "corn", "product": "NK1701"}, "expected_substrings": ["NK1701"], "k": 3, "tags": ["trial", "cross-vendor", "product-filter"]} +{"query": "silage corn high milk per acre dairy", "tool": "search_trials", "filters": {"crop": "silage"}, "expected_metadata": {"crop": "silage"}, "expected_substrings": ["Milk Per Acre", "Ton/Acre"], "k": 5, "tags": ["trial", "silage"]} +{"query": "soybean 2025 Minnesota top performers", "tool": "search_trials", "filters": {"crop": "soybeans", "state": "MN", "year": 2025}, "expected_metadata": {"crop": "soybeans", "state": "MN", "year": 2025}, "k": 5, "tags": ["trial", "regional"]} +{"query": "Pioneer P1142 hybrid recommendation", "tool": "search_docs", "must_not_contain_source_keys": ["pioneer", "p1142"], "expect_lessons_call": true, "tags": ["pioneer-fallback", "anti-hallucination"]} +{"query": "DKC65-20 yield Alabama trial", "tool": "search_trials", "filters": {"crop": "corn", "product": "DKC65-20"}, "expected_empty": true, "tags": ["trial", "not-in-corpus", "anti-hallucination"]} diff --git a/eval/results/baseline.md b/eval/results/baseline.md new file mode 100644 index 00000000..1d23700c --- /dev/null +++ b/eval/results/baseline.md @@ -0,0 +1,41 @@ +# seed-mcp retrieval eval — k=5 + +_21 golden queries × 4 retrievers_ + +## Summary + +| Retriever | Passed | Recall | P@1 | MRR | Avg ms | +|---|---|---|---|---|---| +| **hybrid+rerank** | 21/21 | 100.00% | 90.48% | 0.905 | 2064 | +| **bm25** | 20/21 | 95.24% | 80.95% | 0.833 | 5 | +| **hybrid** | 15/21 | 71.43% | 61.90% | 0.619 | 73 | +| **dense** | 14/21 | 66.67% | 38.10% | 0.440 | 79 | + +**Recall** = % of queries where ≥1 top-k chunk satisfied the spec. 
**P@1** = % where the very first result satisfied it. **MRR** = mean of `1 / rank-of-first-satisfying-result` (0 if missed). + +## Per-query results + +| Query | bm25 | dense | hybrid | hybrid+rerank | +|---|---|---|---|---| +| `DKC62-08RIB ratings` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `AG29XF4 disease ratings` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `WB6430 westbred wheat` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `E085Z5 corn` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `AP Iliad wheat performance` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `drought tolerant corn for sandy soil short season Iowa` | ✅ #2 | ✅ #1 | ✅ #1 | ✅ #1 | +| `soybean cyst nematode SCN resistant variety` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `Phytophthora resistance Rps3a soybean` | ✅ #1 | ✅ #2 | ✅ #1 | ✅ #1 | +| `XtendFlex soybean Northern Plains` | ❌ | ✅ #1 | ✅ #1 | ✅ #1 | +| `Hard Red Spring wheat stripe rust resistance` | ✅ #1 | ✅ #3 | ✅ #1 | ✅ #1 | +| `Soft White Winter wheat Pacific Northwest` | ✅ #1 | ✅ #5 | ✅ #1 | ✅ #1 | +| `Goss's Wilt resistance corn` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `best corn 2024 Iowa` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `Indiana corn yield comparison 2024` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `AP Iliad Idaho wheat trial` | ✅ #1 | ✅ #5 | ✅ #1 | ✅ #1 | +| `DKC65-95 corn yield in trials` | ✅ #1 | ❌ | ✅ #1 | ✅ #1 | +| `NK1701 corn trials head to head` | ✅ #1 | ❌ | ❌ | ✅ #1 | +| `silage corn high milk per acre dairy` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `soybean 2025 Minnesota top performers` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 | +| `Pioneer P1142 hybrid recommendation` | ✅ | ✅ | ✅ | ✅ | +| `DKC65-20 yield Alabama trial` | ✅ | ✅ | ✅ | ✅ | + diff --git a/eval/retrievers.py b/eval/retrievers.py index bc06a182..4aeebb83 100644 --- a/eval/retrievers.py +++ b/eval/retrievers.py @@ -1,62 +1,200 @@ -"""Retriever protocol + concrete implementations. +"""Retriever protocol + concrete implementations for seed-mcp eval. -A single matrix dimension per knob (dense / reranked / bm25 / hybrid) -so the eval harness can compare them apples-to-apples. 
Implement these -once at Phase 7 and reuse them across every retrieval change. +Each retriever returns a ranked list of chunk_ids. The eval harness +in ``run_eval.py`` measures each retriever against the golden +``queries.jsonl`` set on Recall / Precision@1 / MRR. -Each retriever returns a ranked list of (bundle_id, page_id) tuples -deduplicated to the page level (chunks within the same page collapse -to one entry; the highest-ranked chunk's position wins). +Four named configurations, matching the four switches in +``docs_mcp/server.py``: + + dense — Chroma dense retrieval alone + bm25 — SQLite FTS5 BM25 alone + hybrid — dense + bm25 fused via RRF + hybrid+rerank — hybrid pool → cross-encoder rerank + +Each retriever takes ``filters`` (the same dict shape +``_build_where`` accepts in server.py) so trial-specific facets +(data_type, state, year, crop) work consistently across the +four configurations. """ from __future__ import annotations -from typing import Protocol, Iterable +import os +import sys +from pathlib import Path +from typing import Protocol + + +# Add repo root so we can import docs_mcp and rag from here. +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) class Retriever(Protocol): name: str - def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: - """Return up to k (bundle_id, page_id) tuples in rank order.""" + def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: + """Return up to k chunk_ids in rank order.""" ... 
-def _collapse_to_pages(chunk_ids: Iterable[tuple[str, str, str]], k: int) -> list[tuple[str, str]]: - """Take a stream of (bundle_id, page_id, chunk_ordinal) and return - the first k unique pages in their first-seen order.""" - seen: set[tuple[str, str]] = set() - out: list[tuple[str, str]] = [] - for bid, pid, _ord in chunk_ids: - key = (bid, pid) - if key in seen: - continue - seen.add(key) - out.append(key) - if len(out) >= k: - break +def _build_where(filters: dict | None) -> dict | None: + """Mirror of docs_mcp.server._build_where but accepts the eval's + looser shape.""" + if not filters: + return None + conds: list[dict] = [] + if filters.get("data_type"): + conds.append({"data_type": filters["data_type"]}) + if filters.get("crop"): + conds.append({"crop": filters["crop"].lower()}) + if filters.get("brand"): + conds.append({"brand": filters["brand"].upper()}) + if filters.get("state"): + s = filters["state"] + conds.append({"state": s.upper() if len(s) <= 3 else s}) + if filters.get("year"): + conds.append({"year": int(filters["year"])}) + if not conds: + return None + if len(conds) == 1: + return conds[0] + return {"$and": conds} + + +class DenseRetriever: + name = "dense" + + def __init__(self, collection, pool: int = 50): + self.col = collection + self.pool = pool + + def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: + where = _build_where(filters) + try: + r = self.col.query( + query_texts=[query], n_results=max(k, self.pool), where=where, + ) + except Exception: + return [] + return (r.get("ids") or [[]])[0][:k] + + +class BM25Retriever: + name = "bm25" + + def __init__(self, bm25, pool: int = 50): + self.bm25 = bm25 + self.pool = pool + + def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: + where = _build_where(filters) + hits = self.bm25.query(query, n=max(k, self.pool), where=where) + return [cid for cid, _ in hits[:k]] + + +class HybridRetriever: + """Dense + BM25 fused via RRF — same fusion the 
server uses.""" + name = "hybrid" + + def __init__(self, collection, bm25, pool: int = 50, rrf_k: int = 60): + self.col = collection + self.bm25 = bm25 + self.pool = pool + self.rrf_k = rrf_k + + def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: + where = _build_where(filters) + try: + d = self.col.query(query_texts=[query], n_results=self.pool, where=where) + dense_ids = (d.get("ids") or [[]])[0] + except Exception: + dense_ids = [] + bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)] + scores: dict[str, float] = {} + for ranking in (dense_ids, bm25_ids): + for rank, cid in enumerate(ranking): + scores[cid] = scores.get(cid, 0.0) + 1.0 / (self.rrf_k + rank + 1) + fused = sorted(scores, key=lambda d: scores[d], reverse=True) + return fused[:k] + + +class HybridRerankRetriever: + """Hybrid pool → cross-encoder rerank via the llama.cpp endpoint.""" + name = "hybrid+rerank" + + def __init__(self, collection, bm25, rerank_url: str, + pool: int = 50, rerank_pool: int = 50, + rrf_k: int = 60, doc_max_chars: int = 2000, + timeout: float = 30.0): + self.col = collection + self.bm25 = bm25 + self.rerank_url = rerank_url.rstrip("/") + self.pool = pool + self.rerank_pool = rerank_pool + self.rrf_k = rrf_k + self.doc_max_chars = doc_max_chars + self.timeout = timeout + + def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]: + where = _build_where(filters) + try: + d = self.col.query( + query_texts=[query], n_results=self.pool, where=where, + include=["documents"], + ) + dense_ids = (d.get("ids") or [[]])[0] + dense_docs = (d.get("documents") or [[]])[0] + id_to_doc = dict(zip(dense_ids, dense_docs)) + except Exception: + dense_ids = [] + id_to_doc = {} + bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)] + + # Fuse to a hybrid pool + scores: dict[str, float] = {} + for ranking in (dense_ids, bm25_ids): + for rank, cid in enumerate(ranking): + scores[cid] = scores.get(cid, 0.0) + 1.0 
/ (self.rrf_k + rank + 1) + fused = sorted(scores, key=lambda d: scores[d], reverse=True) + + # Fetch docs for any BM25-only ids in the rerank pool + missing = [cid for cid in fused[: self.rerank_pool] if cid not in id_to_doc] + if missing: + try: + extra = self.col.get(ids=missing, include=["documents"]) + for cid, doc in zip(extra.get("ids") or [], extra.get("documents") or []): + id_to_doc[cid] = doc + except Exception: + pass + + # Rerank + pool_ids = fused[: self.rerank_pool] + docs = [(id_to_doc.get(cid, "") or "")[: self.doc_max_chars] for cid in pool_ids] + try: + import httpx + with httpx.Client(timeout=self.timeout) as c: + r = c.post( + f"{self.rerank_url}/v1/rerank", + json={"model": "rerank", "query": query, "documents": docs}, + ) + r.raise_for_status() + results = r.json().get("results") or [] + if not results: + return fused[:k] + ordered = sorted(results, key=lambda x: -x.get("relevance_score", float("-inf"))) + reranked = [pool_ids[x["index"]] for x in ordered if 0 <= x.get("index", -1) < len(pool_ids)] + return reranked[:k] + except Exception: + return fused[:k] + + +def build_all_retrievers(collection, bm25, rerank_url: str | None = None) -> list[Retriever]: + """Return the four named retrievers ready to evaluate.""" + out: list[Retriever] = [ + DenseRetriever(collection), + BM25Retriever(bm25), + HybridRetriever(collection, bm25), + ] + if rerank_url: + out.append(HybridRerankRetriever(collection, bm25, rerank_url)) return out - - -# TODO Phase 2/3 — implement these once Chroma + the bm25 module are -# in place. Each one is small (15-30 LOC). The eval harness imports -# from this module by class name. -# -# class DenseRetriever: -# name = "dense" -# def __init__(self, collection): self.col = collection -# def retrieve(self, query, k=10): ... -# -# class RerankedRetriever: -# name = "dense+rerank" -# def __init__(self, collection, rerank_url, pool=200): ... -# def retrieve(self, query, k=10): ... 
-# -# class BM25Retriever: -# name = "bm25" -# def __init__(self, bm25_index): ... -# def retrieve(self, query, k=10): ... -# -# class HybridRetriever: -# name = "bm25+dense+rrf" -# def __init__(self, dense, bm25, k_rrf=60): ... -# def retrieve(self, query, k=10): ... diff --git a/eval/run_eval.py b/eval/run_eval.py index 9ba3aa6d..9c8f40d1 100644 --- a/eval/run_eval.py +++ b/eval/run_eval.py @@ -1,32 +1,60 @@ """Run all retrievers against eval/queries.jsonl, emit a markdown report. +For seed-mcp, the "expected" answer for many queries isn't a single +chunk — it's "a chunk satisfying these constraints." So per-query +scoring is one of: + + expected_source_keys — at least one of these source_keys appears + in top-k (used for variety-code queries + with a single canonical answer) + expected_metadata — at least one top-k chunk must match ALL of + these key=value constraints (e.g. crop=corn, year=2024) + expected_substrings — at least one top-k chunk's text/metadata + contains at least ONE of these substrings + (e.g. "SCN" when querying SCN resistance) + must_not_contain_source_keys — anti-hallucination: NO top-k chunk's + source_key may contain these tokens + (Pioneer fallback queries) + expected_empty — top-k MUST be empty (anti-hallucination) + expect_lessons_call — the agent should call api_lessons; not + measurable from retrieval alone, so the + scorer currently ignores it + Metrics computed per retriever: - MRR — mean reciprocal rank of the FIRST expected page in the - ranked result list (0 if not in top-k). - Recall@K — fraction of expected pages that appear in top-K. - nDCG@K — discounted gain weighted by rank position. + recall_known — fraction of queries where the retriever returned + a chunk satisfying the query's expectations + precision_top1 — fraction of queries where the FIRST result + satisfied expectations + mrr — mean reciprocal rank of the FIRST satisfying chunk -The "right" number depends on what you're measuring. 
MRR tracks "the -first-line answer is correct"; Recall@K tracks "everything relevant -is there to draw from"; nDCG@K is a smoother combination of both. -For docs-RAG, MRR is usually the headline metric. +Plus a per-query breakdown table so you can see exactly where each +retriever wins or loses. Usage: - python -m eval.run_eval \\ --queries eval/queries.jsonl \\ --k 5 \\ + --rerank-url http://localhost:18080 \\ --output eval/results/baseline.md """ from __future__ import annotations import argparse import json -import math +import logging +import os +import sys import time from pathlib import Path -from typing import Iterable + +# Add repo root for imports +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from eval.retrievers import build_all_retrievers # noqa: E402 + +logging.getLogger("chromadb").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) def load_queries(path: Path) -> list[dict]: @@ -34,31 +62,203 @@ def load_queries(path: Path) -> list[dict]: return [json.loads(line) for line in fh if line.strip()] -def reciprocal_rank(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]]) -> float: - expected_set = set(expected) - for i, page in enumerate(retrieved, start=1): - if page in expected_set: - return 1.0 / i - return 0.0 +def _doc_satisfies(meta: dict, doc: str, query_spec: dict) -> bool: + """Does this single retrieved (metadata, doc) tuple satisfy the + query spec? 
Used by the 'first satisfying' metric.""" + sk = meta.get("source_key") or "" + # exact source_key match + if "expected_source_keys" in query_spec: + for want in query_spec["expected_source_keys"]: + if want.lower() == sk.lower(): + return True + return False + # all metadata constraints match + if "expected_metadata" in query_spec: + for k, v in query_spec["expected_metadata"].items(): + mv = meta.get(k) + if isinstance(v, int): + if mv != v: + return False + else: + if (mv or "").lower() != str(v).lower(): + return False + # if no substring requirement, metadata match is enough + if "expected_substrings" not in query_spec: + return True + # at least one substring present (in doc OR metadata values) + if "expected_substrings" in query_spec: + haystack = (doc + " " + " ".join(str(v) for v in meta.values())).lower() + return any(s.lower() in haystack for s in query_spec["expected_substrings"]) + return False -def recall_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float: - if not expected: - return 0.0 - retrieved_set = set(retrieved[:k]) - hits = sum(1 for e in expected if e in retrieved_set) - return hits / len(expected) +def _evaluate_one(retriever, query_spec: dict, k: int, col) -> dict: + """Return per-query metrics for one retriever.""" + query = query_spec["query"] + filters = dict(query_spec.get("filters") or {}) + # search_trials queries imply data_type=trial; search_docs implies variety + tool = query_spec.get("tool", "search_docs") + if tool == "search_trials": + filters.setdefault("data_type", "trial") + elif tool == "search_docs": + filters.setdefault("data_type", "variety") + # 'product' is a server-side post-filter, not Chroma; strip + product = filters.pop("product", None) + + t0 = time.monotonic() + ids = retriever.retrieve(query, k, filters) + elapsed_ms = (time.monotonic() - t0) * 1000 + + # Anti-hallucination queries: expected_empty should return nothing + # (BUT we still allow the retriever to surface chunks 
if the + # product filter would filter them out at the server level — so + # we re-apply the product filter here). + if product: + try: + extra = col.get(ids=ids, include=["documents"]) + id_to_doc = dict(zip(extra.get("ids") or [], extra.get("documents") or [])) + except Exception: + id_to_doc = {} + ids = [cid for cid in ids if product.lower() in id_to_doc.get(cid, "").lower()] + + if query_spec.get("expected_empty"): + passed = len(ids) == 0 + return { + "query": query, "retriever": retriever.name, + "k": k, "n_hits": len(ids), "rank_first_match": None, + "passed": passed, "elapsed_ms": round(elapsed_ms, 1), + "kind": "expected_empty", + } + + if "must_not_contain_source_keys" in query_spec: + bad_tokens = [t.lower() for t in query_spec["must_not_contain_source_keys"]] + try: + extra = col.get(ids=ids, include=["metadatas"]) + metas = extra.get("metadatas") or [] + except Exception: + metas = [] + # PASS = no top-k chunk's source_key contains a forbidden token + for m in metas: + sk = (m.get("source_key") or "").lower() + if any(t in sk for t in bad_tokens): + return { + "query": query, "retriever": retriever.name, + "k": k, "n_hits": len(ids), "rank_first_match": None, + "passed": False, "elapsed_ms": round(elapsed_ms, 1), + "kind": "must_not_contain", + } + return { + "query": query, "retriever": retriever.name, + "k": k, "n_hits": len(ids), "rank_first_match": None, + "passed": True, "elapsed_ms": round(elapsed_ms, 1), + "kind": "must_not_contain", + } + + # Positive-match query: pull docs+meta and check each + try: + extra = col.get(ids=ids, include=["documents", "metadatas"]) + docs = extra.get("documents") or [] + metas = extra.get("metadatas") or [] + ext_ids = extra.get("ids") or [] + order_idx = {cid: i for i, cid in enumerate(ext_ids)} + except Exception: + docs = [] + metas = [] + order_idx = {} + + rank_first = None + for rank, cid in enumerate(ids, start=1): + i = order_idx.get(cid) + if i is None: + continue + if _doc_satisfies(metas[i], docs[i], 
query_spec): + rank_first = rank + break + + return { + "query": query, "retriever": retriever.name, + "k": k, "n_hits": len(ids), + "rank_first_match": rank_first, + "passed": rank_first is not None, + "elapsed_ms": round(elapsed_ms, 1), + "kind": "positive", + } -def ndcg_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float: - expected_set = set(expected) - dcg = 0.0 - for i, page in enumerate(retrieved[:k], start=1): - if page in expected_set: - dcg += 1.0 / math.log2(i + 1) - # Ideal DCG: every expected page in the top positions. - idcg = sum(1.0 / math.log2(i + 1) for i in range(1, min(len(expected), k) + 1)) - return dcg / idcg if idcg else 0.0 +def _aggregate(results: list[dict]) -> dict: + """Aggregate per-query results into MRR / recall / precision@1.""" + by_retriever: dict[str, list[dict]] = {} + for r in results: + by_retriever.setdefault(r["retriever"], []).append(r) + out: dict[str, dict] = {} + for name, rows in by_retriever.items(): + n = len(rows) + passed = sum(1 for r in rows if r["passed"]) + ranks = [r["rank_first_match"] for r in rows + if r["passed"] and r.get("rank_first_match")] + mrr = sum(1.0 / r for r in ranks) / n if n else 0.0 + precision1 = sum(1 for r in rows if r["passed"] and r.get("rank_first_match") == 1) / n if n else 0.0 + avg_ms = sum(r["elapsed_ms"] for r in rows) / n if n else 0.0 + out[name] = { + "n_queries": n, + "passed": passed, + "recall_known": passed / n if n else 0.0, + "precision_top1": precision1, + "mrr": mrr, + "avg_latency_ms": round(avg_ms, 1), + } + return out + + +def _emit_markdown(queries: list[dict], results: list[dict], + summary: dict, k: int) -> str: + lines: list[str] = [] + lines.append(f"# seed-mcp retrieval eval — k={k}") + lines.append("") + lines.append(f"_{len(queries)} golden queries × {len(summary)} retrievers_") + lines.append("") + lines.append("## Summary") + lines.append("") + lines.append("| Retriever | Passed | Recall | P@1 | MRR | Avg ms |") + 
def _emit_markdown(queries: list[dict], results: list[dict],
                   summary: dict, k: int) -> str:
    """Render the eval report as a Markdown document.

    Args:
        queries: Golden query specs; only the ``query`` text is read here.
        results: Per-query result dicts carrying ``query``, ``retriever``,
            ``passed`` and optionally ``rank_first_match``.
        summary: Per-retriever metrics keyed by retriever name, with
            ``n_queries``, ``passed``, ``recall_known``, ``precision_top1``,
            ``mrr`` and ``avg_latency_ms``.
        k: Top-k cutoff, shown in the report title.

    Returns:
        The report text, ending with a newline: a summary table sorted by
        descending MRR, followed by a per-query pass/fail grid.
    """
    lines: list[str] = [
        f"# seed-mcp retrieval eval — k={k}",
        "",
        f"_{len(queries)} golden queries × {len(summary)} retrievers_",
        "",
        "## Summary",
        "",
        "| Retriever | Passed | Recall | P@1 | MRR | Avg ms |",
        "|---|---|---|---|---|---|",
    ]
    # Highest MRR first.
    for name in sorted(summary, key=lambda n: -summary[n]["mrr"]):
        s = summary[name]
        lines.append(
            f"| **{name}** | {s['passed']}/{s['n_queries']} "
            f"| {s['recall_known']:.2%} | {s['precision_top1']:.2%} "
            f"| {s['mrr']:.3f} | {s['avg_latency_ms']:.0f} |"
        )
    lines.append("")
    lines.append("**Recall** = % of queries where ≥1 top-k chunk satisfied the spec. "
                 "**P@1** = % where the very first result satisfied it. "
                 "**MRR** = mean of `1 / rank-of-first-satisfying-result` (0 if missed).")
    lines.append("")

    # Per-query breakdown
    lines.append("## Per-query results")
    lines.append("")
    # Index once by (query, retriever). setdefault keeps the FIRST result for
    # a pair, which is what a linear first-match scan would return.
    by_pair: dict[tuple[str, str], dict] = {}
    for r in results:
        by_pair.setdefault((r["query"], r["retriever"]), r)
    retriever_names = sorted({r["retriever"] for r in results})
    lines.append("| Query | " + " | ".join(retriever_names) + " |")
    lines.append("|" + "---|" * (len(retriever_names) + 1))
    for q in queries:
        # Truncate first, then neutralise characters that would break the
        # table row: a raw "|" starts a new cell and a newline ends the row.
        label = (q["query"][:60]
                 .replace("\r", " ")
                 .replace("\n", " ")
                 .replace("|", "\\|"))
        cells = [f"`{label}`"]
        for name in retriever_names:
            r = by_pair.get((q["query"], name))
            if r is None:
                cells.append("?")
            elif r["passed"]:
                rk = r.get("rank_first_match")
                cells.append(f"✅ #{rk}" if rk else "✅")
            else:
                cells.append("❌")
        lines.append("| " + " | ".join(cells) + " |")
    lines.append("")
    return "\n".join(lines) + "\n"
queries = load_queries(args.queries) print(f"loaded {len(queries)} queries") - # TODO Phase 7: instantiate the retrievers you implemented in - # eval/retrievers.py and run each one against each query. - # Aggregate MRR / Recall@K / nDCG@K per retriever. Emit a - # markdown table to args.output. Commit the file alongside the - # PR that changes retrieval. - raise NotImplementedError( - "Wire up the retrievers in eval/retrievers.py first, then " - "fill in this evaluation loop. See PLAN.md Phase 7." + # Connect to Chroma + BM25 + import chromadb + from chromadb.config import Settings + from rag.embeddings import embedding_function + from rag.bm25 import BM25Index + + repo_root = Path(__file__).resolve().parent.parent + client = chromadb.PersistentClient( + path=str(repo_root / "chroma"), + settings=Settings(anonymized_telemetry=False), ) + col = client.get_collection(f"{args.product_name}_docs", + embedding_function=embedding_function()) + bm25 = BM25Index(repo_root / "bm25" / f"{args.product_name}_docs.db") + print(f"chroma: {col.count()} chunks; bm25: {bm25.count()} chunks") + + retrievers = build_all_retrievers(col, bm25, args.rerank_url or None) + print(f"retrievers: {[r.name for r in retrievers]}") + + all_results: list[dict] = [] + for r in retrievers: + print(f"running {r.name}...") + for q in queries: + res = _evaluate_one(r, q, args.k, col) + all_results.append(res) + + summary = _aggregate(all_results) + md = _emit_markdown(queries, all_results, summary, args.k) + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(md, encoding="utf-8") + print(f"\nreport: {args.output}") + print() + # Print summary to stdout too + for line in md.split("\n"): + if line.startswith("|"): + print(line) + if line.startswith("## Per-query"): + break + return 0 if __name__ == "__main__":