Phase 6/7: wire rerank + eval harness — 100% pass on 21 golden queries
Phase 6 — Reranker integration
- New _rerank(query, [(cid, doc), ...]) helper in server.py calls
llama.cpp's /v1/rerank endpoint, returns reranker-ordered ids
or None on failure (graceful fallback — search never blocks
on the sidecar).
- search_docs + search_trials both call _rerank() on the post-
hybrid pool BEFORE truncating to k. The variety-code prefilter
still pins exact matches on top.
- Per-doc truncation to 2000 chars to fit jina-reranker-v2-base's
per-pair token budget. Full chunk text still returned to the
caller — truncation is rerank-input-only.
- Telemetry adds `reranked: true|false` so usage logs distinguish
reranked calls.
Phase 7 — Eval harness
- eval/queries.jsonl: 21 golden queries spanning:
* variety-code lookups (DKC62-08RIB, AG29XF4, WB6430, E085Z5,
AP Iliad)
* semantic variety queries (drought-tolerant corn, SCN MG-3
soy, Rps3a, XtendFlex, HRS stripe rust, SWW PNW, Goss's Wilt)
* trial queries (IA/IN/MN regional, AP Iliad ID, NK1701 head-
to-head, silage Ton/Acre, product=DKC65-95)
* anti-hallucination (Pioneer P1142 fallback, DKC65-20 not-in-
corpus expected_empty)
- eval/retrievers.py: 4 named retrievers — dense, bm25, hybrid
(dense+bm25+RRF), hybrid+rerank — all sharing the same filter
shape as docs_mcp/server.py._build_where.
- eval/run_eval.py: runs each retriever against each query,
reports Recall / Precision@1 / MRR / avg latency. Markdown
output in eval/results/baseline.md.
Baseline results (k=5, 21 queries):
| Retriever | Pass | Recall | P@1 | MRR | Avg ms |
|-----------------|-------|--------|-------|-------|--------|
| hybrid+rerank | 21/21 | 100% | 90% | 0.905 | 2064 |
| bm25 | 20/21 | 95% | 81% | 0.833 | 5 |
| hybrid | 15/21 | 71% | 62% | 0.619 | 73 |
| dense | 14/21 | 67% | 38% | 0.440 | 79 |
Key findings:
1. hybrid+rerank wins on quality — 100% pass, 90% P@1.
2. BM25 alone is surprisingly competitive (95% pass) at 5 ms —
excellent fallback when rerank is down. The variety-code
prefilter in search_docs is doing a lot of work here.
3. Dense embedding alone is the WEAKEST configuration on this
corpus — variety identity tokens (DKC62-08RIB, AP Iliad,
Rps3a) have no semantic neighbors, so nomic-embed-text returns
noise. The hybrid (no rerank) layer actively hurts because
RRF dilutes the BM25 ranking with dense noise.
4. Anti-hallucination queries (Pioneer fallback, DKC65-20 not-
in-corpus) pass on ALL retrievers including dense-only —
the must_not_contain + expected_empty design holds.
Deploy decision: HYBRID_SEARCH=true + RERANK_URL set
(production env already has both — refresh.yml + image-only.yml
+ deploy/docker-compose.yml all configured).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -289,6 +289,75 @@ def _rrf_fuse(rankings: list[list[str]], k: int = RRF_K) -> list[str]:
|
||||
return sorted(scores, key=lambda d: scores[d], reverse=True)
|
||||
|
||||
|
||||
# Per-doc character cap when sending to the reranker. jina-reranker-v2-base
|
||||
# accepts up to ~1024 tokens PER QUERY+DOC PAIR (n_ctx_train) and rejects
|
||||
# the WHOLE BATCH if any one pair exceeds it. Truncating each doc to
|
||||
# ~2000 chars (≈ 500-700 tokens) leaves headroom for the query + chat
|
||||
# template overhead. The truncation is reranking-only — full chunk text
|
||||
# still goes back to the LLM caller.
|
||||
RERANK_DOC_MAX_CHARS = 2000
|
||||
|
||||
|
||||
def _rerank(query: str, candidates: list[tuple[str, str]]) -> list[str] | None:
|
||||
"""Call the llama.cpp /v1/rerank endpoint and return the candidate
|
||||
chunk ids in reranker-preferred order.
|
||||
|
||||
Args:
|
||||
query: the user's natural-language query
|
||||
candidates: list of ``(chunk_id, chunk_text)`` to rerank.
|
||||
|
||||
Returns:
|
||||
A list of chunk_ids ordered best-first by reranker score, OR
|
||||
``None`` if reranking is disabled, the endpoint is unreachable,
|
||||
or any other error. The caller treats ``None`` as "fall back to
|
||||
the input ranking" — rerank failures must NEVER block a search.
|
||||
|
||||
Anti-hallucination: rerank only reorders chunks the retrievers
|
||||
already surfaced. It cannot introduce content not in the corpus.
|
||||
"""
|
||||
if not RERANK_URL or not candidates:
|
||||
return None
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
# Truncate each doc to fit the per-pair token budget. jina-reranker
|
||||
# rejects the entire batch on any oversize doc.
|
||||
docs = [(text[:RERANK_DOC_MAX_CHARS] if text else "") for _cid, text in candidates]
|
||||
ids = [cid for cid, _ in candidates]
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=RERANK_TIMEOUT) as c:
|
||||
r = c.post(
|
||||
f"{RERANK_URL}/v1/rerank",
|
||||
json={
|
||||
"model": "rerank", # llama.cpp ignores this; jina passes through
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("rerank request failed (%s) — falling back to input order", exc)
|
||||
return None
|
||||
|
||||
results = payload.get("results") or []
|
||||
if not results:
|
||||
log.warning("rerank returned empty results — falling back to input order")
|
||||
return None
|
||||
|
||||
# llama.cpp returns results as [{"index": int, "relevance_score": float}, ...]
|
||||
# Higher relevance_score = better; sort descending.
|
||||
try:
|
||||
ordered = sorted(results, key=lambda r: -r.get("relevance_score", float("-inf")))
|
||||
return [ids[r["index"]] for r in ordered if 0 <= r.get("index", -1) < len(ids)]
|
||||
except (KeyError, IndexError, TypeError) as exc:
|
||||
log.warning("rerank response malformed (%s) — falling back to input order", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _structured_ratings_block(sidecar: dict) -> str:
|
||||
"""Render the sidecar's grouped characteristics + identity as a
|
||||
fact-checkable block, with the source URL pinned at top.
|
||||
@@ -534,6 +603,34 @@ def search_docs(
|
||||
else:
|
||||
fuzzy_ids = dense_ids
|
||||
|
||||
# Optional reranker pass over the fuzzy pool BEFORE truncating
|
||||
# to k. The cross-encoder is much more accurate at the query/
|
||||
# doc pairing than dense embedding alone, especially when the
|
||||
# query mentions specific ag terms that share-token-cosine
|
||||
# might miss. Skipped if RERANK_URL is unset or the call
|
||||
# fails — search is never blocked on the sidecar.
|
||||
used_rerank = False
|
||||
if RERANK_URL and fuzzy_ids:
|
||||
# Need docs to rerank — fetch any missing.
|
||||
need = [i for i in fuzzy_ids if i not in id_to_doc]
|
||||
if need:
|
||||
try:
|
||||
extra = col.get(ids=need[:RERANK_POOL], include=["documents", "metadatas"])
|
||||
for cid, doc, meta in zip(
|
||||
extra.get("ids") or [],
|
||||
extra.get("documents") or [],
|
||||
extra.get("metadatas") or [],
|
||||
):
|
||||
id_to_doc[cid] = doc
|
||||
id_to_meta[cid] = meta
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("pre-rerank get-by-id failed: %s", exc)
|
||||
pool = [(cid, id_to_doc.get(cid, "")) for cid in fuzzy_ids[:RERANK_POOL]]
|
||||
reranked = _rerank(query, pool)
|
||||
if reranked:
|
||||
fuzzy_ids = reranked + [c for c in fuzzy_ids if c not in set(reranked)]
|
||||
used_rerank = True
|
||||
|
||||
# Pin exact-code matches at top, then fill remainder from fuzzy
|
||||
# retrieval (deduped). Pinned matches are deterministic and
|
||||
# high-confidence; they should never lose to a fuzzy match.
|
||||
@@ -566,6 +663,7 @@ def search_docs(
|
||||
_call.set(
|
||||
hits_returned=len(final_ids),
|
||||
hybrid=used_hybrid,
|
||||
reranked=used_rerank,
|
||||
pool_size=pool_size,
|
||||
)
|
||||
|
||||
@@ -885,6 +983,30 @@ def search_trials(
|
||||
else:
|
||||
fuzzy_ids = dense_ids
|
||||
|
||||
# Optional reranker pass over the fuzzy pool — same shape as
|
||||
# in search_docs. Skipped silently if RERANK_URL is unset or
|
||||
# the rerank call fails.
|
||||
used_rerank = False
|
||||
if RERANK_URL and fuzzy_ids:
|
||||
need = [i for i in fuzzy_ids if i not in id_to_doc]
|
||||
if need:
|
||||
try:
|
||||
extra = col.get(ids=need[:RERANK_POOL], include=["documents", "metadatas"])
|
||||
for cid, doc, meta in zip(
|
||||
extra.get("ids") or [],
|
||||
extra.get("documents") or [],
|
||||
extra.get("metadatas") or [],
|
||||
):
|
||||
id_to_doc[cid] = doc
|
||||
id_to_meta[cid] = meta
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("pre-rerank get-by-id failed: %s", exc)
|
||||
pool = [(cid, id_to_doc.get(cid, "")) for cid in fuzzy_ids[:RERANK_POOL]]
|
||||
reranked = _rerank(full_query, pool)
|
||||
if reranked:
|
||||
fuzzy_ids = reranked + [c for c in fuzzy_ids if c not in set(reranked)]
|
||||
used_rerank = True
|
||||
|
||||
# Optional product-substring post-filter: if user supplied
|
||||
# ``product``, require the chunk to actually contain the
|
||||
# token. This re-checks the bytes since BM25 only sees stems.
|
||||
@@ -931,6 +1053,7 @@ def search_trials(
|
||||
_call.set(
|
||||
hits_returned=len(final_ids),
|
||||
hybrid=used_hybrid,
|
||||
reranked=used_rerank,
|
||||
pool_size=pool_size,
|
||||
data_type="trial",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
{"query": "DKC62-08RIB ratings", "tool": "search_docs", "expected_source_keys": ["dekalb-dkc62-08rib"], "tags": ["variety-code", "exact"]}
|
||||
{"query": "AG29XF4 disease ratings", "tool": "search_docs", "expected_source_keys": ["asgrow-ag29xf4"], "tags": ["variety-code", "exact"]}
|
||||
{"query": "WB6430 westbred wheat", "tool": "search_docs", "expected_source_keys": ["westbred-wb6430"], "tags": ["variety-code", "exact"]}
|
||||
{"query": "E085Z5 corn", "tool": "search_docs", "expected_source_keys": ["golden_harvest-e085z5"], "tags": ["variety-code", "exact", "golden-harvest"]}
|
||||
{"query": "AP Iliad wheat performance", "tool": "search_docs", "expected_source_keys": ["agripro-ap-iliad"], "tags": ["variety-code", "agripro"]}
|
||||
{"query": "drought tolerant corn for sandy soil short season Iowa", "tool": "search_docs", "expected_metadata": {"crop": "corn"}, "expected_substrings": ["drought"], "k": 5, "tags": ["semantic", "corn"]}
|
||||
{"query": "soybean cyst nematode SCN resistant variety", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["soybean cyst nematode", "SCN"], "k": 5, "tags": ["semantic", "soybean", "disease"]}
|
||||
{"query": "Phytophthora resistance Rps3a soybean", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["Rps3a"], "k": 5, "tags": ["semantic", "soybean", "gene"]}
|
||||
{"query": "XtendFlex soybean Northern Plains", "tool": "search_docs", "expected_metadata": {"crop": "soybeans"}, "expected_substrings": ["XF", "XtendFlex"], "k": 5, "tags": ["trait", "soybean"]}
|
||||
{"query": "Hard Red Spring wheat stripe rust resistance", "tool": "search_docs", "expected_metadata": {"crop": "wheat"}, "expected_substrings": ["stripe rust", "Hard Red Spring", "HRS"], "k": 5, "tags": ["wheat", "disease", "class"]}
|
||||
{"query": "Soft White Winter wheat Pacific Northwest", "tool": "search_docs", "expected_metadata": {"crop": "wheat"}, "expected_substrings": ["Soft White Winter", "SWW"], "k": 5, "tags": ["wheat", "class"]}
|
||||
{"query": "Goss's Wilt resistance corn", "tool": "search_docs", "expected_metadata": {"crop": "corn"}, "expected_substrings": ["Goss"], "k": 5, "tags": ["corn", "disease"]}
|
||||
{"query": "best corn 2024 Iowa", "tool": "search_trials", "filters": {"crop": "corn", "state": "IA", "year": 2024}, "expected_metadata": {"crop": "corn", "year": 2024, "state": "IA"}, "k": 5, "tags": ["trial", "regional"]}
|
||||
{"query": "Indiana corn yield comparison 2024", "tool": "search_trials", "filters": {"crop": "corn", "state": "IN", "year": 2024}, "expected_metadata": {"crop": "corn", "year": 2024, "state": "IN"}, "k": 5, "tags": ["trial", "regional"]}
|
||||
{"query": "AP Iliad Idaho wheat trial", "tool": "search_trials", "filters": {"crop": "wheat"}, "expected_substrings": ["AP Iliad", "Idaho"], "k": 3, "tags": ["trial", "wheat", "agripro"]}
|
||||
{"query": "DKC65-95 corn yield in trials", "tool": "search_trials", "filters": {"crop": "corn", "product": "DKC65-95"}, "expected_substrings": ["DKC65-95"], "k": 3, "tags": ["trial", "product-filter"]}
|
||||
{"query": "NK1701 corn trials head to head", "tool": "search_trials", "filters": {"crop": "corn", "product": "NK1701"}, "expected_substrings": ["NK1701"], "k": 3, "tags": ["trial", "cross-vendor", "product-filter"]}
|
||||
{"query": "silage corn high milk per acre dairy", "tool": "search_trials", "filters": {"crop": "silage"}, "expected_metadata": {"crop": "silage"}, "expected_substrings": ["Milk Per Acre", "Ton/Acre"], "k": 5, "tags": ["trial", "silage"]}
|
||||
{"query": "soybean 2025 Minnesota top performers", "tool": "search_trials", "filters": {"crop": "soybeans", "state": "MN", "year": 2025}, "expected_metadata": {"crop": "soybeans", "state": "MN", "year": 2025}, "k": 5, "tags": ["trial", "regional"]}
|
||||
{"query": "Pioneer P1142 hybrid recommendation", "tool": "search_docs", "must_not_contain_source_keys": ["pioneer", "p1142"], "expect_lessons_call": true, "tags": ["pioneer-fallback", "anti-hallucination"]}
|
||||
{"query": "DKC65-20 yield Alabama trial", "tool": "search_trials", "filters": {"crop": "corn", "product": "DKC65-20"}, "expected_empty": true, "tags": ["trial", "not-in-corpus", "anti-hallucination"]}
|
||||
@@ -0,0 +1,41 @@
|
||||
# seed-mcp retrieval eval — k=5
|
||||
|
||||
_21 golden queries × 4 retrievers_
|
||||
|
||||
## Summary
|
||||
|
||||
| Retriever | Passed | Recall | P@1 | MRR | Avg ms |
|
||||
|---|---|---|---|---|---|
|
||||
| **hybrid+rerank** | 21/21 | 100.00% | 90.48% | 0.905 | 2064 |
|
||||
| **bm25** | 20/21 | 95.24% | 80.95% | 0.833 | 5 |
|
||||
| **hybrid** | 15/21 | 71.43% | 61.90% | 0.619 | 73 |
|
||||
| **dense** | 14/21 | 66.67% | 38.10% | 0.440 | 79 |
|
||||
|
||||
**Recall** = % of queries where ≥1 top-k chunk satisfied the spec. **P@1** = % where the very first result satisfied it. **MRR** = mean of `1 / rank-of-first-satisfying-result` (0 if missed).
|
||||
|
||||
## Per-query results
|
||||
|
||||
| Query | bm25 | dense | hybrid | hybrid+rerank |
|
||||
|---|---|---|---|---|
|
||||
| `DKC62-08RIB ratings` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `AG29XF4 disease ratings` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `WB6430 westbred wheat` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `E085Z5 corn` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `AP Iliad wheat performance` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `drought tolerant corn for sandy soil short season Iowa` | ✅ #2 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `soybean cyst nematode SCN resistant variety` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `Phytophthora resistance Rps3a soybean` | ✅ #1 | ✅ #2 | ✅ #1 | ✅ #1 |
|
||||
| `XtendFlex soybean Northern Plains` | ❌ | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `Hard Red Spring wheat stripe rust resistance` | ✅ #1 | ✅ #3 | ✅ #1 | ✅ #1 |
|
||||
| `Soft White Winter wheat Pacific Northwest` | ✅ #1 | ✅ #5 | ✅ #1 | ✅ #1 |
|
||||
| `Goss's Wilt resistance corn` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `best corn 2024 Iowa` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `Indiana corn yield comparison 2024` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `AP Iliad Idaho wheat trial` | ✅ #1 | ✅ #5 | ✅ #1 | ✅ #1 |
|
||||
| `DKC65-95 corn yield in trials` | ✅ #1 | ❌ | ✅ #1 | ✅ #1 |
|
||||
| `NK1701 corn trials head to head` | ✅ #1 | ❌ | ❌ | ✅ #1 |
|
||||
| `silage corn high milk per acre dairy` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `soybean 2025 Minnesota top performers` | ✅ #1 | ✅ #1 | ✅ #1 | ✅ #1 |
|
||||
| `Pioneer P1142 hybrid recommendation` | ✅ | ✅ | ✅ | ✅ |
|
||||
| `DKC65-20 yield Alabama trial` | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
+186
-48
@@ -1,62 +1,200 @@
|
||||
"""Retriever protocol + concrete implementations.
|
||||
"""Retriever protocol + concrete implementations for seed-mcp eval.
|
||||
|
||||
A single matrix dimension per knob (dense / reranked / bm25 / hybrid)
|
||||
so the eval harness can compare them apples-to-apples. Implement these
|
||||
once at Phase 7 and reuse them across every retrieval change.
|
||||
Each retriever returns a ranked list of chunk_ids. The eval harness
|
||||
in ``run_eval.py`` measures each retriever against the golden
|
||||
``queries.jsonl`` set across MRR / Recall@K / nDCG@K.
|
||||
|
||||
Each retriever returns a ranked list of (bundle_id, page_id) tuples
|
||||
deduplicated to the page level (chunks within the same page collapse
|
||||
to one entry; the highest-ranked chunk's position wins).
|
||||
Four named configurations, matching the four switches in
|
||||
``docs_mcp/server.py``:
|
||||
|
||||
dense — Chroma dense retrieval alone
|
||||
bm25 — SQLite FTS5 BM25 alone
|
||||
hybrid — dense + bm25 fused via RRF
|
||||
hybrid_rerank — hybrid pool → cross-encoder rerank
|
||||
|
||||
Each retriever takes ``filters`` (the same dict shape
|
||||
``_build_where`` accepts in server.py) so trial-specific facets
|
||||
(data_type, state, year, crop) work consistently across the
|
||||
four configurations.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, Iterable
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
# Add repo root so we can import docs_mcp and rag from here.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
|
||||
class Retriever(Protocol):
|
||||
name: str
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
"""Return up to k (bundle_id, page_id) tuples in rank order."""
|
||||
def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]:
|
||||
"""Return up to k chunk_ids in rank order."""
|
||||
...
|
||||
|
||||
|
||||
def _collapse_to_pages(chunk_ids: Iterable[tuple[str, str, str]], k: int) -> list[tuple[str, str]]:
|
||||
"""Take a stream of (bundle_id, page_id, chunk_ordinal) and return
|
||||
the first k unique pages in their first-seen order."""
|
||||
seen: set[tuple[str, str]] = set()
|
||||
out: list[tuple[str, str]] = []
|
||||
for bid, pid, _ord in chunk_ids:
|
||||
key = (bid, pid)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(key)
|
||||
if len(out) >= k:
|
||||
break
|
||||
def _build_where(filters: dict | None) -> dict | None:
|
||||
"""Mirror of docs_mcp.server._build_where but accepts the eval's
|
||||
looser shape."""
|
||||
if not filters:
|
||||
return None
|
||||
conds: list[dict] = []
|
||||
if filters.get("data_type"):
|
||||
conds.append({"data_type": filters["data_type"]})
|
||||
if filters.get("crop"):
|
||||
conds.append({"crop": filters["crop"].lower()})
|
||||
if filters.get("brand"):
|
||||
conds.append({"brand": filters["brand"].upper()})
|
||||
if filters.get("state"):
|
||||
s = filters["state"]
|
||||
conds.append({"state": s.upper() if len(s) <= 3 else s})
|
||||
if filters.get("year"):
|
||||
conds.append({"year": int(filters["year"])})
|
||||
if not conds:
|
||||
return None
|
||||
if len(conds) == 1:
|
||||
return conds[0]
|
||||
return {"$and": conds}
|
||||
|
||||
|
||||
class DenseRetriever:
|
||||
name = "dense"
|
||||
|
||||
def __init__(self, collection, pool: int = 50):
|
||||
self.col = collection
|
||||
self.pool = pool
|
||||
|
||||
def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]:
|
||||
where = _build_where(filters)
|
||||
try:
|
||||
r = self.col.query(
|
||||
query_texts=[query], n_results=max(k, self.pool), where=where,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
return (r.get("ids") or [[]])[0][:k]
|
||||
|
||||
|
||||
class BM25Retriever:
|
||||
name = "bm25"
|
||||
|
||||
def __init__(self, bm25, pool: int = 50):
|
||||
self.bm25 = bm25
|
||||
self.pool = pool
|
||||
|
||||
def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]:
|
||||
where = _build_where(filters)
|
||||
hits = self.bm25.query(query, n=max(k, self.pool), where=where)
|
||||
return [cid for cid, _ in hits[:k]]
|
||||
|
||||
|
||||
class HybridRetriever:
|
||||
"""Dense + BM25 fused via RRF — same fusion the server uses."""
|
||||
name = "hybrid"
|
||||
|
||||
def __init__(self, collection, bm25, pool: int = 50, rrf_k: int = 60):
|
||||
self.col = collection
|
||||
self.bm25 = bm25
|
||||
self.pool = pool
|
||||
self.rrf_k = rrf_k
|
||||
|
||||
def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]:
|
||||
where = _build_where(filters)
|
||||
try:
|
||||
d = self.col.query(query_texts=[query], n_results=self.pool, where=where)
|
||||
dense_ids = (d.get("ids") or [[]])[0]
|
||||
except Exception:
|
||||
dense_ids = []
|
||||
bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)]
|
||||
scores: dict[str, float] = {}
|
||||
for ranking in (dense_ids, bm25_ids):
|
||||
for rank, cid in enumerate(ranking):
|
||||
scores[cid] = scores.get(cid, 0.0) + 1.0 / (self.rrf_k + rank + 1)
|
||||
fused = sorted(scores, key=lambda d: scores[d], reverse=True)
|
||||
return fused[:k]
|
||||
|
||||
|
||||
class HybridRerankRetriever:
|
||||
"""Hybrid pool → cross-encoder rerank via the llama.cpp endpoint."""
|
||||
name = "hybrid+rerank"
|
||||
|
||||
def __init__(self, collection, bm25, rerank_url: str,
|
||||
pool: int = 50, rerank_pool: int = 50,
|
||||
rrf_k: int = 60, doc_max_chars: int = 2000,
|
||||
timeout: float = 30.0):
|
||||
self.col = collection
|
||||
self.bm25 = bm25
|
||||
self.rerank_url = rerank_url.rstrip("/")
|
||||
self.pool = pool
|
||||
self.rerank_pool = rerank_pool
|
||||
self.rrf_k = rrf_k
|
||||
self.doc_max_chars = doc_max_chars
|
||||
self.timeout = timeout
|
||||
|
||||
def retrieve(self, query: str, k: int, filters: dict | None) -> list[str]:
|
||||
where = _build_where(filters)
|
||||
try:
|
||||
d = self.col.query(
|
||||
query_texts=[query], n_results=self.pool, where=where,
|
||||
include=["documents"],
|
||||
)
|
||||
dense_ids = (d.get("ids") or [[]])[0]
|
||||
dense_docs = (d.get("documents") or [[]])[0]
|
||||
id_to_doc = dict(zip(dense_ids, dense_docs))
|
||||
except Exception:
|
||||
dense_ids = []
|
||||
id_to_doc = {}
|
||||
bm25_ids = [c for c, _ in self.bm25.query(query, n=self.pool, where=where)]
|
||||
|
||||
# Fuse to a hybrid pool
|
||||
scores: dict[str, float] = {}
|
||||
for ranking in (dense_ids, bm25_ids):
|
||||
for rank, cid in enumerate(ranking):
|
||||
scores[cid] = scores.get(cid, 0.0) + 1.0 / (self.rrf_k + rank + 1)
|
||||
fused = sorted(scores, key=lambda d: scores[d], reverse=True)
|
||||
|
||||
# Fetch docs for any BM25-only ids in the rerank pool
|
||||
missing = [cid for cid in fused[: self.rerank_pool] if cid not in id_to_doc]
|
||||
if missing:
|
||||
try:
|
||||
extra = self.col.get(ids=missing, include=["documents"])
|
||||
for cid, doc in zip(extra.get("ids") or [], extra.get("documents") or []):
|
||||
id_to_doc[cid] = doc
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Rerank
|
||||
pool_ids = fused[: self.rerank_pool]
|
||||
docs = [(id_to_doc.get(cid, "") or "")[: self.doc_max_chars] for cid in pool_ids]
|
||||
try:
|
||||
import httpx
|
||||
with httpx.Client(timeout=self.timeout) as c:
|
||||
r = c.post(
|
||||
f"{self.rerank_url}/v1/rerank",
|
||||
json={"model": "rerank", "query": query, "documents": docs},
|
||||
)
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results") or []
|
||||
if not results:
|
||||
return fused[:k]
|
||||
ordered = sorted(results, key=lambda x: -x.get("relevance_score", float("-inf")))
|
||||
reranked = [pool_ids[x["index"]] for x in ordered if 0 <= x.get("index", -1) < len(pool_ids)]
|
||||
return reranked[:k]
|
||||
except Exception:
|
||||
return fused[:k]
|
||||
|
||||
|
||||
def build_all_retrievers(collection, bm25, rerank_url: str | None = None) -> list[Retriever]:
|
||||
"""Return the four named retrievers ready to evaluate."""
|
||||
out: list[Retriever] = [
|
||||
DenseRetriever(collection),
|
||||
BM25Retriever(bm25),
|
||||
HybridRetriever(collection, bm25),
|
||||
]
|
||||
if rerank_url:
|
||||
out.append(HybridRerankRetriever(collection, bm25, rerank_url))
|
||||
return out
|
||||
|
||||
|
||||
# TODO Phase 2/3 — implement these once Chroma + the bm25 module are
|
||||
# in place. Each one is small (15-30 LOC). The eval harness imports
|
||||
# from this module by class name.
|
||||
#
|
||||
# class DenseRetriever:
|
||||
# name = "dense"
|
||||
# def __init__(self, collection): self.col = collection
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class RerankedRetriever:
|
||||
# name = "dense+rerank"
|
||||
# def __init__(self, collection, rerank_url, pool=200): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class BM25Retriever:
|
||||
# name = "bm25"
|
||||
# def __init__(self, bm25_index): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class HybridRetriever:
|
||||
# name = "bm25+dense+rrf"
|
||||
# def __init__(self, dense, bm25, k_rrf=60): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
|
||||
+272
-41
@@ -1,32 +1,60 @@
|
||||
"""Run all retrievers against eval/queries.jsonl, emit a markdown report.
|
||||
|
||||
For seed-mcp, the "expected" answer for many queries isn't a single
|
||||
chunk — it's "a chunk satisfying these constraints." So per-query
|
||||
scoring is one of:
|
||||
|
||||
expected_source_keys — at least one of these source_keys appears
|
||||
in top-k (used for variety-code queries
|
||||
with a single canonical answer)
|
||||
expected_metadata — all top-k must match these key=value
|
||||
constraints (e.g. crop=corn, year=2024)
|
||||
expected_substrings — at least one top-k chunk's text/metadata
|
||||
contains each substring (e.g. "SCN" must
|
||||
appear when querying SCN resistance)
|
||||
must_not_contain_source_keys — anti-hallucination: NO top-k chunk's
|
||||
source_key may contain these tokens
|
||||
(Pioneer fallback queries)
|
||||
expected_empty — top-k MUST be empty (anti-hallucination)
|
||||
expect_lessons_call — the agent should call api_lessons; not
|
||||
measurable from retrieval alone, recorded
|
||||
as an advisory note
|
||||
|
||||
Metrics computed per retriever:
|
||||
|
||||
MRR — mean reciprocal rank of the FIRST expected page in the
|
||||
ranked result list (0 if not in top-k).
|
||||
Recall@K — fraction of expected pages that appear in top-K.
|
||||
nDCG@K — discounted gain weighted by rank position.
|
||||
recall_known — fraction of queries where the retriever returned
|
||||
a chunk satisfying the query's expectations
|
||||
precision_top1 — fraction of queries where the FIRST result
|
||||
satisfied expectations
|
||||
mrr — mean reciprocal rank of the FIRST satisfying chunk
|
||||
|
||||
The "right" number depends on what you're measuring. MRR tracks "the
|
||||
first-line answer is correct"; Recall@K tracks "everything relevant
|
||||
is there to draw from"; nDCG@K is a smoother combination of both.
|
||||
For docs-RAG, MRR is usually the headline metric.
|
||||
Plus a per-query breakdown table so you can see exactly where each
|
||||
retriever wins or loses.
|
||||
|
||||
Usage:
|
||||
|
||||
python -m eval.run_eval \\
|
||||
--queries eval/queries.jsonl \\
|
||||
--k 5 \\
|
||||
--rerank-url http://localhost:18080 \\
|
||||
--output eval/results/baseline.md
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
# Add repo root for imports
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from eval.retrievers import build_all_retrievers # noqa: E402
|
||||
|
||||
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def load_queries(path: Path) -> list[dict]:
|
||||
@@ -34,31 +62,203 @@ def load_queries(path: Path) -> list[dict]:
|
||||
return [json.loads(line) for line in fh if line.strip()]
|
||||
|
||||
|
||||
def reciprocal_rank(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]]) -> float:
|
||||
expected_set = set(expected)
|
||||
for i, page in enumerate(retrieved, start=1):
|
||||
if page in expected_set:
|
||||
return 1.0 / i
|
||||
return 0.0
|
||||
def _doc_satisfies(meta: dict, doc: str, query_spec: dict) -> bool:
|
||||
"""Does this single retrieved (metadata, doc) tuple satisfy the
|
||||
query spec? Used by the 'first satisfying' metric."""
|
||||
sk = meta.get("source_key") or ""
|
||||
# exact source_key match
|
||||
if "expected_source_keys" in query_spec:
|
||||
for want in query_spec["expected_source_keys"]:
|
||||
if want.lower() == sk.lower():
|
||||
return True
|
||||
return False
|
||||
# all metadata constraints match
|
||||
if "expected_metadata" in query_spec:
|
||||
for k, v in query_spec["expected_metadata"].items():
|
||||
mv = meta.get(k)
|
||||
if isinstance(v, int):
|
||||
if mv != v:
|
||||
return False
|
||||
else:
|
||||
if (mv or "").lower() != str(v).lower():
|
||||
return False
|
||||
# if no substring requirement, metadata match is enough
|
||||
if "expected_substrings" not in query_spec:
|
||||
return True
|
||||
# at least one substring present (in doc OR metadata values)
|
||||
if "expected_substrings" in query_spec:
|
||||
haystack = (doc + " " + " ".join(str(v) for v in meta.values())).lower()
|
||||
return any(s.lower() in haystack for s in query_spec["expected_substrings"])
|
||||
return False
|
||||
|
||||
|
||||
def recall_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float:
|
||||
if not expected:
|
||||
return 0.0
|
||||
retrieved_set = set(retrieved[:k])
|
||||
hits = sum(1 for e in expected if e in retrieved_set)
|
||||
return hits / len(expected)
|
||||
def _evaluate_one(retriever, query_spec: dict, k: int, col) -> dict:
|
||||
"""Return per-query metrics for one retriever."""
|
||||
query = query_spec["query"]
|
||||
filters = dict(query_spec.get("filters") or {})
|
||||
# search_trials queries imply data_type=trial; search_docs implies variety
|
||||
tool = query_spec.get("tool", "search_docs")
|
||||
if tool == "search_trials":
|
||||
filters.setdefault("data_type", "trial")
|
||||
elif tool == "search_docs":
|
||||
filters.setdefault("data_type", "variety")
|
||||
# 'product' is a server-side post-filter, not Chroma; strip
|
||||
product = filters.pop("product", None)
|
||||
|
||||
t0 = time.monotonic()
|
||||
ids = retriever.retrieve(query, k, filters)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
# Anti-hallucination queries: expected_empty should return nothing
|
||||
# (BUT we still allow the retriever to surface chunks if the
|
||||
# product filter would filter them out at the server level — so
|
||||
# we re-apply the product filter here).
|
||||
if product:
|
||||
try:
|
||||
extra = col.get(ids=ids, include=["documents"])
|
||||
id_to_doc = dict(zip(extra.get("ids") or [], extra.get("documents") or []))
|
||||
except Exception:
|
||||
id_to_doc = {}
|
||||
ids = [cid for cid in ids if product.lower() in id_to_doc.get(cid, "").lower()]
|
||||
|
||||
if query_spec.get("expected_empty"):
|
||||
passed = len(ids) == 0
|
||||
return {
|
||||
"query": query, "retriever": retriever.name,
|
||||
"k": k, "n_hits": len(ids), "rank_first_match": None,
|
||||
"passed": passed, "elapsed_ms": round(elapsed_ms, 1),
|
||||
"kind": "expected_empty",
|
||||
}
|
||||
|
||||
if "must_not_contain_source_keys" in query_spec:
|
||||
bad_tokens = [t.lower() for t in query_spec["must_not_contain_source_keys"]]
|
||||
try:
|
||||
extra = col.get(ids=ids, include=["metadatas"])
|
||||
metas = extra.get("metadatas") or []
|
||||
except Exception:
|
||||
metas = []
|
||||
# PASS = no top-k chunk's source_key contains a forbidden token
|
||||
for m in metas:
|
||||
sk = (m.get("source_key") or "").lower()
|
||||
if any(t in sk for t in bad_tokens):
|
||||
return {
|
||||
"query": query, "retriever": retriever.name,
|
||||
"k": k, "n_hits": len(ids), "rank_first_match": None,
|
||||
"passed": False, "elapsed_ms": round(elapsed_ms, 1),
|
||||
"kind": "must_not_contain",
|
||||
}
|
||||
return {
|
||||
"query": query, "retriever": retriever.name,
|
||||
"k": k, "n_hits": len(ids), "rank_first_match": None,
|
||||
"passed": True, "elapsed_ms": round(elapsed_ms, 1),
|
||||
"kind": "must_not_contain",
|
||||
}
|
||||
|
||||
# Positive-match query: pull docs+meta and check each
|
||||
try:
|
||||
extra = col.get(ids=ids, include=["documents", "metadatas"])
|
||||
docs = extra.get("documents") or []
|
||||
metas = extra.get("metadatas") or []
|
||||
ext_ids = extra.get("ids") or []
|
||||
order_idx = {cid: i for i, cid in enumerate(ext_ids)}
|
||||
except Exception:
|
||||
docs = []
|
||||
metas = []
|
||||
order_idx = {}
|
||||
|
||||
rank_first = None
|
||||
for rank, cid in enumerate(ids, start=1):
|
||||
i = order_idx.get(cid)
|
||||
if i is None:
|
||||
continue
|
||||
if _doc_satisfies(metas[i], docs[i], query_spec):
|
||||
rank_first = rank
|
||||
break
|
||||
|
||||
return {
|
||||
"query": query, "retriever": retriever.name,
|
||||
"k": k, "n_hits": len(ids),
|
||||
"rank_first_match": rank_first,
|
||||
"passed": rank_first is not None,
|
||||
"elapsed_ms": round(elapsed_ms, 1),
|
||||
"kind": "positive",
|
||||
}
|
||||
|
||||
|
||||
def ndcg_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float:
|
||||
expected_set = set(expected)
|
||||
dcg = 0.0
|
||||
for i, page in enumerate(retrieved[:k], start=1):
|
||||
if page in expected_set:
|
||||
dcg += 1.0 / math.log2(i + 1)
|
||||
# Ideal DCG: every expected page in the top positions.
|
||||
idcg = sum(1.0 / math.log2(i + 1) for i in range(1, min(len(expected), k) + 1))
|
||||
return dcg / idcg if idcg else 0.0
|
||||
def _aggregate(results: list[dict]) -> dict:
|
||||
"""Aggregate per-query results into MRR / recall / precision@1."""
|
||||
by_retriever: dict[str, list[dict]] = {}
|
||||
for r in results:
|
||||
by_retriever.setdefault(r["retriever"], []).append(r)
|
||||
out: dict[str, dict] = {}
|
||||
for name, rows in by_retriever.items():
|
||||
n = len(rows)
|
||||
passed = sum(1 for r in rows if r["passed"])
|
||||
ranks = [r["rank_first_match"] for r in rows
|
||||
if r["passed"] and r.get("rank_first_match")]
|
||||
mrr = sum(1.0 / r for r in ranks) / n if n else 0.0
|
||||
precision1 = sum(1 for r in rows if r["passed"] and r.get("rank_first_match") == 1) / n if n else 0.0
|
||||
avg_ms = sum(r["elapsed_ms"] for r in rows) / n if n else 0.0
|
||||
out[name] = {
|
||||
"n_queries": n,
|
||||
"passed": passed,
|
||||
"recall_known": passed / n if n else 0.0,
|
||||
"precision_top1": precision1,
|
||||
"mrr": mrr,
|
||||
"avg_latency_ms": round(avg_ms, 1),
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def _emit_markdown(queries: list[dict], results: list[dict],
|
||||
summary: dict, k: int) -> str:
|
||||
lines: list[str] = []
|
||||
lines.append(f"# seed-mcp retrieval eval — k={k}")
|
||||
lines.append("")
|
||||
lines.append(f"_{len(queries)} golden queries × {len(summary)} retrievers_")
|
||||
lines.append("")
|
||||
lines.append("## Summary")
|
||||
lines.append("")
|
||||
lines.append("| Retriever | Passed | Recall | P@1 | MRR | Avg ms |")
|
||||
lines.append("|---|---|---|---|---|---|")
|
||||
for name in sorted(summary, key=lambda n: -summary[n]["mrr"]):
|
||||
s = summary[name]
|
||||
lines.append(
|
||||
f"| **{name}** | {s['passed']}/{s['n_queries']} "
|
||||
f"| {s['recall_known']:.2%} | {s['precision_top1']:.2%} "
|
||||
f"| {s['mrr']:.3f} | {s['avg_latency_ms']:.0f} |"
|
||||
)
|
||||
lines.append("")
|
||||
lines.append("**Recall** = % of queries where ≥1 top-k chunk satisfied the spec. "
|
||||
"**P@1** = % where the very first result satisfied it. "
|
||||
"**MRR** = mean of `1 / rank-of-first-satisfying-result` (0 if missed).")
|
||||
lines.append("")
|
||||
|
||||
# Per-query breakdown
|
||||
lines.append("## Per-query results")
|
||||
lines.append("")
|
||||
by_query: dict[str, list[dict]] = {}
|
||||
for r in results:
|
||||
by_query.setdefault(r["query"], []).append(r)
|
||||
retriever_names = sorted({r["retriever"] for r in results})
|
||||
header = "| Query | " + " | ".join(retriever_names) + " |"
|
||||
sep = "|" + "---|" * (len(retriever_names) + 1)
|
||||
lines.append(header)
|
||||
lines.append(sep)
|
||||
for q in queries:
|
||||
cells = [f"`{q['query'][:60]}`"]
|
||||
for name in retriever_names:
|
||||
r = next((x for x in by_query.get(q["query"], []) if x["retriever"] == name), None)
|
||||
if r is None:
|
||||
cells.append("?")
|
||||
elif r["passed"]:
|
||||
rk = r.get("rank_first_match")
|
||||
cells.append(f"✅ #{rk}" if rk else "✅")
|
||||
else:
|
||||
cells.append("❌")
|
||||
lines.append("| " + " | ".join(cells) + " |")
|
||||
lines.append("")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
@@ -66,25 +266,56 @@ def main() -> int:
|
||||
p.add_argument("--queries", type=Path, default=Path("eval/queries.jsonl"))
|
||||
p.add_argument("--k", type=int, default=5)
|
||||
p.add_argument("--output", type=Path, default=Path("eval/results/baseline.md"))
|
||||
p.add_argument("--rerank-url", default=os.environ.get("RERANK_URL", ""))
|
||||
p.add_argument("--product-name", default=os.environ.get("PRODUCT_NAME", "crop_seed"))
|
||||
args = p.parse_args()
|
||||
|
||||
if not args.queries.exists():
|
||||
print(f"queries file not found: {args.queries}")
|
||||
print("hint: copy eval/queries.jsonl.example and edit")
|
||||
return 1
|
||||
|
||||
queries = load_queries(args.queries)
|
||||
print(f"loaded {len(queries)} queries")
|
||||
|
||||
# TODO Phase 7: instantiate the retrievers you implemented in
|
||||
# eval/retrievers.py and run each one against each query.
|
||||
# Aggregate MRR / Recall@K / nDCG@K per retriever. Emit a
|
||||
# markdown table to args.output. Commit the file alongside the
|
||||
# PR that changes retrieval.
|
||||
raise NotImplementedError(
|
||||
"Wire up the retrievers in eval/retrievers.py first, then "
|
||||
"fill in this evaluation loop. See PLAN.md Phase 7."
|
||||
# Connect to Chroma + BM25
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from rag.embeddings import embedding_function
|
||||
from rag.bm25 import BM25Index
|
||||
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(repo_root / "chroma"),
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
)
|
||||
col = client.get_collection(f"{args.product_name}_docs",
|
||||
embedding_function=embedding_function())
|
||||
bm25 = BM25Index(repo_root / "bm25" / f"{args.product_name}_docs.db")
|
||||
print(f"chroma: {col.count()} chunks; bm25: {bm25.count()} chunks")
|
||||
|
||||
retrievers = build_all_retrievers(col, bm25, args.rerank_url or None)
|
||||
print(f"retrievers: {[r.name for r in retrievers]}")
|
||||
|
||||
all_results: list[dict] = []
|
||||
for r in retrievers:
|
||||
print(f"running {r.name}...")
|
||||
for q in queries:
|
||||
res = _evaluate_one(r, q, args.k, col)
|
||||
all_results.append(res)
|
||||
|
||||
summary = _aggregate(all_results)
|
||||
md = _emit_markdown(queries, all_results, summary, args.k)
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
args.output.write_text(md, encoding="utf-8")
|
||||
print(f"\nreport: {args.output}")
|
||||
print()
|
||||
# Print summary to stdout too
|
||||
for line in md.split("\n"):
|
||||
if line.startswith("|"):
|
||||
print(line)
|
||||
if line.startswith("## Per-query"):
|
||||
break
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user