Phase 7+8: eval harness + hybrid retrieval
## Phase 7 — Eval harness
eval/retrievers.py + rag/retrieval.py: Retriever protocol with
DenseRetriever, BM25Retriever, HybridRetriever (RRF k=60),
RerankedRetriever (llama.cpp /v1/rerank). retrievers.py is now a
thin shim re-exporting from rag.retrieval so the MCP server can
use the same code at request time without making eval/ a runtime
dep.
eval/run_eval.py: drives N retrievers against eval/queries.jsonl,
computes MRR / Recall@K / nDCG@K, emits a markdown report with a
summary table + per-query breakdown for the first retriever. Each
query carries expected (source, source_key) tuples — matches the
labels-domain page-level keying.
eval/queries.jsonl: 35 curated queries — 25 brand-name (Warrant,
Huskie, Roundup Custom, Liberty, Authority, Headline, Trivapro,
Poncho, Lorsban, Sencor, Acuron, ...) + 10 intent/semantic
("what controls horseweed before soybean", "fungicide for fusarium
head blight", "rainfast interval for glyphosate", ...).
## Phase 8 — Hybrid retrieval (BM25 + dense + RRF)
docs_mcp/server.py: search_docs now branches on HYBRID_SEARCH env.
When on, _search_chunks runs both Chroma + BM25 (rag/bm25.py
existing impl), fuses on chunk_id with reciprocal-rank-fusion
(RRF k=60), and returns the combined pool. Dense-only path
unchanged when HYBRID_SEARCH is unset. The rendering layer
(_format_hit) is untouched.
The RERANK_URL hook is also wired (_rerank_pool sends docs to
llama.cpp /v1/rerank, truncated to 2000 chars per the jina-reranker
n_ctx_train=1024 batch-rejection gotcha). Fails open to base order
on any exception.
## Baseline numbers (k=5, pool=50, 35 queries)
| Retriever | MRR | Recall@5 | nDCG@5 |
|------------|-------|----------|--------|
| dense | 0.027 | 0.086 | 0.041 |
| bm25 | 0.544 | 0.586 | 0.524 |
| hybrid-rrf | 0.114 | 0.114 | 0.108 |
Headline: BM25 dominates because farmers search for products by
brand name, and brand names are exact-match tokens that lexical
search nails. Dense is poor — semantic embeddings spread across
similar products and don't preferentially weight brand-name tokens.
Textbook RRF hurts when one retriever is much weaker than the
other: dense's irrelevant top-50 pollute the fused pool with
ties at 1/(60+rank). Phase 6 reranker is the planned fix —
the reranker scores each (query, chunk) pair independently
and can recover the right answer regardless of base order.
Per-query report at eval/results/baseline.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+126
-20
@@ -242,6 +242,7 @@ def search_docs(
|
||||
"query": query, "source": source, "product_class": product_class,
|
||||
"registrant_contains": registrant_contains, "signal_word": signal_word,
|
||||
"epa_reg_no": epa_reg_no, "k": k,
|
||||
"hybrid": HYBRID_SEARCH, "rerank": bool(RERANK_URL),
|
||||
}) as _call:
|
||||
try:
|
||||
col = _collection()
|
||||
@@ -251,37 +252,35 @@ def search_docs(
|
||||
|
||||
where = _build_where(source, product_class, registrant_contains,
|
||||
signal_word, epa_reg_no)
|
||||
# Over-fetch when we'll post-filter on registrant substring, so we
|
||||
# still have ~k matches after the filter trims.
|
||||
n_fetch = k * 4 if registrant_contains else k
|
||||
# Over-fetch — we need a meaningful pool for fusion/reranking,
|
||||
# and registrant_contains filtering trims down post-query.
|
||||
pool = max(k * (5 if (HYBRID_SEARCH or RERANK_URL) else 2),
|
||||
k * (4 if registrant_contains else 2))
|
||||
|
||||
scored: list[tuple[str, dict, float]] = []
|
||||
try:
|
||||
res = col.query(query_texts=[query], n_results=n_fetch, where=where)
|
||||
scored = _search_chunks(query, pool, where, registrant_contains)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
_call.set(hits_returned=0, error=str(exc))
|
||||
return f"_(search failed: {exc})_"
|
||||
|
||||
docs = res.get("documents", [[]])[0]
|
||||
metas = res.get("metadatas", [[]])[0]
|
||||
dists = res.get("distances", [[]])[0]
|
||||
|
||||
# Cosine distance → similarity score (1 - d). Clip to [0,1] for display.
|
||||
scored: list[tuple[str, dict, float]] = []
|
||||
for doc, meta, dist in zip(docs, metas, dists):
|
||||
if registrant_contains:
|
||||
reg = (meta.get("registrant") or "").upper()
|
||||
if registrant_contains.upper() not in reg:
|
||||
continue
|
||||
score = max(0.0, 1.0 - float(dist))
|
||||
scored.append((doc, meta, score))
|
||||
if len(scored) >= k:
|
||||
break
|
||||
# Optionally rerank the pool (Phase 6) before truncating to k.
|
||||
if RERANK_URL and len(scored) > 1:
|
||||
try:
|
||||
scored = _rerank_pool(query, scored)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("rerank failed (%s) — falling back to base order", exc)
|
||||
|
||||
scored = scored[:k]
|
||||
_call.set(hits_returned=len(scored))
|
||||
if not scored:
|
||||
return "_(no results — try broadening the query, dropping filters, or check list_versions() for valid sources/classes)_"
|
||||
|
||||
mode = "hybrid-rrf+rerank" if (HYBRID_SEARCH and RERANK_URL) else \
|
||||
"hybrid-rrf" if HYBRID_SEARCH else \
|
||||
"dense+rerank" if RERANK_URL else "dense"
|
||||
out: list[str] = [
|
||||
f"# Search results for {query!r} ({len(scored)} of top-{n_fetch} dense hits)",
|
||||
f"# Search results for {query!r} ({len(scored)} hits, mode={mode})",
|
||||
"",
|
||||
]
|
||||
for doc, meta, score in scored:
|
||||
@@ -289,6 +288,113 @@ def search_docs(
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
def _search_chunks(
|
||||
query: str,
|
||||
pool: int,
|
||||
where: dict | None,
|
||||
registrant_contains: str | None,
|
||||
) -> list[tuple[str, dict, float]]:
|
||||
"""Run dense (and optionally BM25-hybrid) chunk retrieval, return
|
||||
list of (doc_text, metadata, score) sorted by score descending.
|
||||
Filters by ``registrant_contains`` post-query."""
|
||||
col = _collection()
|
||||
# --- dense (Chroma) ----------------------------------------------------
|
||||
dense_res = col.query(query_texts=[query], n_results=pool, where=where)
|
||||
dense_ids = dense_res.get("ids", [[]])[0]
|
||||
dense_docs = dense_res.get("documents", [[]])[0]
|
||||
dense_metas = dense_res.get("metadatas", [[]])[0]
|
||||
dense_dists = dense_res.get("distances", [[]])[0]
|
||||
|
||||
chunk_pool: dict[str, dict] = {}
|
||||
for cid, doc, meta, dist in zip(dense_ids, dense_docs, dense_metas, dense_dists):
|
||||
chunk_pool[cid] = {
|
||||
"doc": doc, "meta": meta or {},
|
||||
"dense_sim": max(0.0, 1.0 - float(dist)),
|
||||
"dense_rank": None, "bm25_rank": None,
|
||||
}
|
||||
for rank, cid in enumerate(dense_ids, start=1):
|
||||
chunk_pool[cid]["dense_rank"] = rank
|
||||
|
||||
# --- BM25 (Phase 8 hybrid) --------------------------------------------
|
||||
if HYBRID_SEARCH:
|
||||
try:
|
||||
from rag.bm25 import BM25Index
|
||||
bm25 = BM25Index(BM25_DB)
|
||||
bm25_hits = bm25.query(query, n=pool)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("bm25 query failed (%s) — dense-only this call", exc)
|
||||
bm25_hits = []
|
||||
missing_ids = [cid for cid, _ in bm25_hits if cid not in chunk_pool]
|
||||
if missing_ids:
|
||||
got = col.get(ids=missing_ids,
|
||||
include=["documents", "metadatas"])
|
||||
for cid, doc, meta in zip(got.get("ids", []), got.get("documents", []),
|
||||
got.get("metadatas", [])):
|
||||
chunk_pool[cid] = {
|
||||
"doc": doc, "meta": meta or {},
|
||||
"dense_sim": 0.0, "dense_rank": None, "bm25_rank": None,
|
||||
}
|
||||
for rank, (cid, _bm25_score) in enumerate(bm25_hits, start=1):
|
||||
if cid in chunk_pool:
|
||||
chunk_pool[cid]["bm25_rank"] = rank
|
||||
|
||||
# --- RRF fusion or dense-only score -----------------------------------
|
||||
out: list[tuple[str, dict, float]] = []
|
||||
for cid, info in chunk_pool.items():
|
||||
meta = info["meta"]
|
||||
if registrant_contains:
|
||||
reg = (meta.get("registrant") or "").upper()
|
||||
if registrant_contains.upper() not in reg:
|
||||
continue
|
||||
if HYBRID_SEARCH:
|
||||
rrf = 0.0
|
||||
if info["dense_rank"]:
|
||||
rrf += 1.0 / (RRF_K + info["dense_rank"])
|
||||
if info["bm25_rank"]:
|
||||
rrf += 1.0 / (RRF_K + info["bm25_rank"])
|
||||
score = rrf
|
||||
else:
|
||||
score = info["dense_sim"]
|
||||
out.append((info["doc"], meta, score))
|
||||
out.sort(key=lambda x: -x[2])
|
||||
return out
|
||||
|
||||
|
||||
def _rerank_pool(
|
||||
query: str,
|
||||
pool: list[tuple[str, dict, float]],
|
||||
) -> list[tuple[str, dict, float]]:
|
||||
"""Send (query, doc_text) pairs to a llama.cpp /v1/rerank endpoint
|
||||
and reorder by relevance score. Truncates docs to 2000 chars (the
|
||||
jina-reranker GGUF rejects the ENTIRE batch if any pair exceeds
|
||||
n_ctx_train=1024; full text still goes back to the user)."""
|
||||
import httpx
|
||||
docs_truncated = [d[:2000] for d, _meta, _s in pool[:RERANK_POOL]]
|
||||
if not docs_truncated:
|
||||
return pool
|
||||
r = httpx.post(
|
||||
f"{RERANK_URL}/v1/rerank",
|
||||
json={"query": query, "documents": docs_truncated},
|
||||
timeout=RERANK_TIMEOUT,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
results = data.get("results") or []
|
||||
rescored: list[tuple[str, dict, float]] = []
|
||||
for r_item in results:
|
||||
idx = r_item.get("index")
|
||||
score = r_item.get("relevance_score") or r_item.get("score") or 0.0
|
||||
if isinstance(idx, int) and 0 <= idx < len(pool):
|
||||
doc, meta, _ = pool[idx]
|
||||
rescored.append((doc, meta, float(score)))
|
||||
rescored.sort(key=lambda x: -x[2])
|
||||
# Anything in the original pool past RERANK_POOL stays at the tail
|
||||
# in original order (rare — we usually rerank the entire pool).
|
||||
seen = {id(item) for item in rescored}
|
||||
tail = [p for p in pool[RERANK_POOL:] if id(p) not in seen]
|
||||
return rescored + tail
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_page(
|
||||
source: Annotated[
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
{"query": "Warrant herbicide rate for soybean", "expected": [{"source": "bayer", "source_key": "warrant"}, {"source": "epa_ppls", "source_key": "524-591"}], "tags": ["brand", "herbicide", "soybean"]}
|
||||
{"query": "Huskie wheat herbicide tank mix", "expected": [{"source": "bayer", "source_key": "huskie"}, {"source": "bayer", "source_key": "huskie-complete"}], "tags": ["brand", "herbicide", "wheat"]}
|
||||
{"query": "Harness 20G granular corn herbicide", "expected": [{"source": "bayer", "source_key": "harness"}, {"source": "epa_ppls", "source_key": "524-487"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Laudis tembotrione post-emergence corn", "expected": [{"source": "bayer", "source_key": "laudis"}, {"source": "epa_ppls", "source_key": "264-860"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Roundup Custom glyphosate burndown application rate", "expected": [{"source": "epa_ppls", "source_key": "524-677"}, {"source": "epa_ppls", "source_key": "524-475"}], "tags": ["brand", "herbicide", "glyphosate"]}
|
||||
{"query": "Liberty 280 SL glufosinate ammonium soybean", "expected": [{"source": "epa_ppls", "source_key": "7969-448"}], "tags": ["brand", "herbicide", "soybean"]}
|
||||
{"query": "Atrazine 4L corn pre-emergence rate per acre", "expected": [{"source": "epa_ppls", "source_key": "5905-7877"}], "tags": ["active-ingredient", "herbicide", "corn", "atrazine"]}
|
||||
{"query": "Albaugh dicamba DMA salt application restrictions", "expected": [{"source": "epa_ppls", "source_key": "42750-40"}], "tags": ["active-ingredient", "herbicide", "dicamba"]}
|
||||
{"query": "Authority 4F sulfentrazone soybean residual", "expected": [{"source": "epa_ppls", "source_key": "279-3146"}], "tags": ["brand", "herbicide", "soybean"]}
|
||||
{"query": "Prowl 10-G pendimethalin granular pre-plant", "expected": [{"source": "epa_ppls", "source_key": "241-254"}], "tags": ["brand", "herbicide"]}
|
||||
{"query": "Callisto GT mesotrione corn postemergence broadleaf control", "expected": [{"source": "epa_ppls", "source_key": "100-1470"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Acuron Flexi corn pre-emergence S-metolachlor", "expected": [{"source": "epa_ppls", "source_key": "100-1568"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Sencor 4 flowable metribuzin soybean waterhemp", "expected": [{"source": "epa_ppls", "source_key": "264-735"}], "tags": ["brand", "herbicide", "soybean", "waterhemp"]}
|
||||
{"query": "Broadstrike trifluralin pre-plant incorporated", "expected": [{"source": "epa_ppls", "source_key": "62719-222"}], "tags": ["brand", "herbicide"]}
|
||||
{"query": "Headline azoxystrobin pyraclostrobin wheat foliar fungicide", "expected": [{"source": "epa_ppls", "source_key": "7969-186"}], "tags": ["brand", "fungicide", "wheat"]}
|
||||
{"query": "Trivapro pydiflumetofen corn fungicide tar spot", "expected": [{"source": "epa_ppls", "source_key": "100-1613"}], "tags": ["brand", "fungicide", "corn"]}
|
||||
{"query": "Poncho 600 clothianidin seed treatment corn", "expected": [{"source": "epa_ppls", "source_key": "7969-458"}], "tags": ["brand", "insecticide", "seed-treatment"]}
|
||||
{"query": "Gustafson Lorsban 30 chlorpyrifos granular corn rootworm", "expected": [{"source": "epa_ppls", "source_key": "264-932"}], "tags": ["brand", "insecticide", "corn"]}
|
||||
{"query": "RT-3 glyphosate potassium salt herbicide", "expected": [{"source": "bayer", "source_key": "rt-3"}], "tags": ["brand", "herbicide"]}
|
||||
{"query": "Roundup PowerMAX 3 glyphosate K-salt rate", "expected": [{"source": "bayer", "source_key": "roundup-powermax-3"}, {"source": "epa_ppls", "source_key": "524-659"}], "tags": ["brand", "herbicide"]}
|
||||
{"query": "Nortron SC ethofumesate sugar beet", "expected": [{"source": "bayer", "source_key": "nortron-sc"}], "tags": ["brand", "herbicide"]}
|
||||
{"query": "DiFlexx Duo tembotrione dicamba corn", "expected": [{"source": "bayer", "source_key": "diflexx-duo"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Corvus thiencarbazone-methyl isoxaflutole corn pre-emergence", "expected": [{"source": "bayer", "source_key": "corvus"}, {"source": "epa_ppls", "source_key": "264-1066"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Capreno tembotrione thiencarbazone corn herbicide", "expected": [{"source": "bayer", "source_key": "capreno"}, {"source": "epa_ppls", "source_key": "264-1063"}], "tags": ["brand", "herbicide", "corn"]}
|
||||
{"query": "Tilt propiconazole wheat fungicide rust", "expected": [{"source": "epa_ppls", "source_key": "100-617"}], "tags": ["brand", "fungicide", "wheat"]}
|
||||
{"query": "what controls horseweed marestail before planting soybean", "expected": [{"source": "epa_ppls", "source_key": "524-475"}, {"source": "epa_ppls", "source_key": "524-677"}], "tags": ["intent", "herbicide", "soybean"]}
|
||||
{"query": "what can I tank mix with 2,4-D for burndown in spring", "expected": [{"source": "epa_ppls", "source_key": "5905-7877"}, {"source": "epa_ppls", "source_key": "228-666"}], "tags": ["intent", "herbicide", "burndown"]}
|
||||
{"query": "best fungicide for corn tar spot foliar application", "expected": [{"source": "epa_ppls", "source_key": "100-1613"}, {"source": "epa_ppls", "source_key": "100-1547"}], "tags": ["intent", "fungicide", "corn"]}
|
||||
{"query": "seed treatment to control wireworm in corn", "expected": [{"source": "epa_ppls", "source_key": "7969-458"}, {"source": "epa_ppls", "source_key": "7969-459"}], "tags": ["intent", "insecticide", "seed-treatment", "corn"]}
|
||||
{"query": "pre-emergence residual herbicide for soybean for waterhemp", "expected": [{"source": "epa_ppls", "source_key": "279-3146"}, {"source": "epa_ppls", "source_key": "264-735"}], "tags": ["intent", "herbicide", "soybean", "waterhemp"]}
|
||||
{"query": "what insecticide for soybean aphid foliar", "expected": [{"source": "epa_ppls", "source_key": "279-3206"}, {"source": "epa_ppls", "source_key": "264-840"}], "tags": ["intent", "insecticide", "soybean", "aphid"]}
|
||||
{"query": "what is the rainfast interval for glyphosate", "expected": [{"source": "epa_ppls", "source_key": "524-475"}, {"source": "epa_ppls", "source_key": "524-677"}], "tags": ["intent", "herbicide", "glyphosate"]}
|
||||
{"query": "wheat fungicide for fusarium head blight", "expected": [{"source": "epa_ppls", "source_key": "7969-186"}, {"source": "epa_ppls", "source_key": "100-1547"}], "tags": ["intent", "fungicide", "wheat"]}
|
||||
{"query": "endangered species act precautions for pesticide application", "expected": [{"source": "epa_ppls", "source_key": "524-475"}, {"source": "epa_ppls", "source_key": "524-591"}], "tags": ["intent", "regulatory"]}
|
||||
{"query": "what herbicide do I use for postemergence broadleaf in corn", "expected": [{"source": "bayer", "source_key": "laudis"}, {"source": "bayer", "source_key": "capreno"}, {"source": "bayer", "source_key": "diflexx-duo"}], "tags": ["intent", "herbicide", "corn"]}
|
||||
@@ -0,0 +1,54 @@
|
||||
# Eval results — queries.jsonl
|
||||
|
||||
- queries: 35
|
||||
- k: 5
|
||||
- pool: 50
|
||||
- retrievers: dense, bm25, hybrid-rrf
|
||||
|
||||
## Summary
|
||||
|
||||
| Retriever | MRR | Recall@5 | nDCG@5 | Errors | Time (s) |
|
||||
|---|---|---|---|---|---|
|
||||
| dense | 0.027 | 0.086 | 0.041 | 0 | 5.4 |
|
||||
| bm25 | 0.544 | 0.586 | 0.524 | 0 | 4.7 |
|
||||
| hybrid-rrf | 0.114 | 0.114 | 0.108 | 0 | 8.4 |
|
||||
|
||||
## Per-query — dense
|
||||
|
||||
| Query | Expected | Top retrieved | MRR | Recall |
|
||||
|---|---|---|---|---|
|
||||
| Warrant herbicide rate for soybean | bayer/warrant, epa_ppls/524-591 | epa_ppls/524-508, epa_ppls/524-521, epa_ppls/42750-176 | 0.00 | 0.00 |
|
||||
| Huskie wheat herbicide tank mix | bayer/huskie, bayer/huskie-complete | epa_ppls/71368-64, epa_ppls/279-9610, epa_ppls/10182-134 | 0.00 | 0.00 |
|
||||
| Harness 20G granular corn herbicide | bayer/harness, epa_ppls/524-487 | epa_ppls/352-612, epa_ppls/352-608, epa_ppls/352-817 | 0.00 | 0.00 |
|
||||
| Laudis tembotrione post-emergence corn | bayer/laudis, epa_ppls/264-860 | bayer/diflexx, epa_ppls/70506-331, epa_ppls/84229-48 | 0.00 | 0.00 |
|
||||
| Roundup Custom glyphosate burndown application rate | epa_ppls/524-677, epa_ppls/524-475 | epa_ppls/42750-122, epa_ppls/5905-656, epa_ppls/228-666 | 0.00 | 0.00 |
|
||||
| Liberty 280 SL glufosinate ammonium soybean | epa_ppls/7969-448 | epa_ppls/71368-111, epa_ppls/84229-45, epa_ppls/7969-500 | 0.00 | 0.00 |
|
||||
| Atrazine 4L corn pre-emergence rate per acre | epa_ppls/5905-7877 | epa_ppls/5905-624, epa_ppls/89167-75, epa_ppls/7969-140 | 0.00 | 0.00 |
|
||||
| Albaugh dicamba DMA salt application restrictions | epa_ppls/42750-40 | epa_ppls/5905-638, epa_ppls/34704-861, epa_ppls/5905-624 | 0.20 | 1.00 |
|
||||
| Authority 4F sulfentrazone soybean residual | epa_ppls/279-3146 | epa_ppls/279-9663, epa_ppls/87290-70, epa_ppls/66222-248 | 0.00 | 0.00 |
|
||||
| Prowl 10-G pendimethalin granular pre-plant | epa_ppls/241-254 | epa_ppls/70506-333, epa_ppls/42750-340, epa_ppls/91234-231 | 0.00 | 0.00 |
|
||||
| Callisto GT mesotrione corn postemergence broadleaf control | epa_ppls/100-1470 | epa_ppls/100-1131, epa_ppls/89167-51, epa_ppls/100-1349 | 0.00 | 0.00 |
|
||||
| Acuron Flexi corn pre-emergence S-metolachlor | epa_ppls/100-1568 | epa_ppls/62719-312, epa_ppls/42750-122, epa_ppls/5905-638 | 0.00 | 0.00 |
|
||||
| Sencor 4 flowable metribuzin soybean waterhemp | epa_ppls/264-735 | epa_ppls/1381-259, epa_ppls/279-9624, epa_ppls/89167-101 | 0.00 | 0.00 |
|
||||
| Broadstrike trifluralin pre-plant incorporated | epa_ppls/62719-222 | epa_ppls/87290-81, epa_ppls/70506-333, epa_ppls/91234-73 | 0.00 | 0.00 |
|
||||
| Headline azoxystrobin pyraclostrobin wheat foliar fungicide | epa_ppls/7969-186 | epa_ppls/100-1222, epa_ppls/100-1164, epa_ppls/87290-63 | 0.00 | 0.00 |
|
||||
| Trivapro pydiflumetofen corn fungicide tar spot | epa_ppls/100-1613 | epa_ppls/66222-250, epa_ppls/264-1209, epa_ppls/62719-346 | 0.00 | 0.00 |
|
||||
| Poncho 600 clothianidin seed treatment corn | epa_ppls/7969-458 | epa_ppls/7969-459, epa_ppls/7969-458, bayer/poncho-beta | 0.50 | 1.00 |
|
||||
| Gustafson Lorsban 30 chlorpyrifos granular corn rootworm | epa_ppls/264-932 | epa_ppls/89167-78, epa_ppls/5481-525, epa_ppls/1381-193 | 0.00 | 0.00 |
|
||||
| RT-3 glyphosate potassium salt herbicide | bayer/rt-3 | bayer/roundup-powermax-3, epa_ppls/19713-597, epa_ppls/19713-606 | 0.25 | 1.00 |
|
||||
| Roundup PowerMAX 3 glyphosate K-salt rate | bayer/roundup-powermax-3, epa_ppls/524-659 | epa_ppls/19713-597, epa_ppls/19713-606, epa_ppls/51036-333 | 0.00 | 0.00 |
|
||||
| Nortron SC ethofumesate sugar beet | bayer/nortron-sc | epa_ppls/71368-25, epa_ppls/42750-122, epa_ppls/524-715 | 0.00 | 0.00 |
|
||||
| DiFlexx Duo tembotrione dicamba corn | bayer/diflexx-duo | epa_ppls/71368-65, epa_ppls/1812-434, epa_ppls/1381-191 | 0.00 | 0.00 |
|
||||
| Corvus thiencarbazone-methyl isoxaflutole corn pre-emergence | bayer/corvus, epa_ppls/264-1066 | epa_ppls/42750-122, bayer/scoparia, epa_ppls/70506-331 | 0.00 | 0.00 |
|
||||
| Capreno tembotrione thiencarbazone corn herbicide | bayer/capreno, epa_ppls/264-1063 | epa_ppls/91234-314, epa_ppls/352-894, epa_ppls/42750-32 | 0.00 | 0.00 |
|
||||
| Tilt propiconazole wheat fungicide rust | epa_ppls/100-617 | epa_ppls/19713-692, epa_ppls/34704-1113, epa_ppls/228-670 | 0.00 | 0.00 |
|
||||
| what controls horseweed marestail before planting soybean | epa_ppls/524-475, epa_ppls/524-677 | epa_ppls/524-716, epa_ppls/524-717, epa_ppls/524-722 | 0.00 | 0.00 |
|
||||
| what can I tank mix with 2,4-D for burndown in spring | epa_ppls/5905-7877, epa_ppls/228-666 | epa_ppls/34704-1158, epa_ppls/264-738, epa_ppls/228-364 | 0.00 | 0.00 |
|
||||
| best fungicide for corn tar spot foliar application | epa_ppls/100-1613, epa_ppls/100-1547 | epa_ppls/100-1178, epa_ppls/87290-63, epa_ppls/100-1262 | 0.00 | 0.00 |
|
||||
| seed treatment to control wireworm in corn | epa_ppls/7969-458, epa_ppls/7969-459 | epa_ppls/10182-212, epa_ppls/1381-231, epa_ppls/42750-300 | 0.00 | 0.00 |
|
||||
| pre-emergence residual herbicide for soybean for waterhemp | epa_ppls/279-3146, epa_ppls/264-735 | epa_ppls/352-675, epa_ppls/279-3564, epa_ppls/279-3589 | 0.00 | 0.00 |
|
||||
| what insecticide for soybean aphid foliar | epa_ppls/279-3206, epa_ppls/264-840 | epa_ppls/264-1157, epa_ppls/264-1159, epa_ppls/279-9615 | 0.00 | 0.00 |
|
||||
| what is the rainfast interval for glyphosate | epa_ppls/524-475, epa_ppls/524-677 | epa_ppls/89167-56, epa_ppls/524-523, epa_ppls/524-707 | 0.00 | 0.00 |
|
||||
| wheat fungicide for fusarium head blight | epa_ppls/7969-186, epa_ppls/100-1547 | bayer/stratego, epa_ppls/7969-246, epa_ppls/66222-250 | 0.00 | 0.00 |
|
||||
| endangered species act precautions for pesticide application | epa_ppls/524-475, epa_ppls/524-591 | epa_ppls/70506-318, epa_ppls/70506-324, epa_ppls/34704-1044 | 0.00 | 0.00 |
|
||||
| what herbicide do I use for postemergence broadleaf in corn | bayer/laudis, bayer/capreno, bayer/diflexx-duo | epa_ppls/352-842, epa_ppls/100-1349, epa_ppls/89167-51 | 0.00 | 0.00 |
|
||||
+18
-59
@@ -1,62 +1,21 @@
|
||||
"""Retriever protocol + concrete implementations.
|
||||
"""Eval-time shim — re-exports the retrievers from rag.retrieval.
|
||||
|
||||
A single matrix dimension per knob (dense / reranked / bm25 / hybrid)
|
||||
so the eval harness can compare them apples-to-apples. Implement these
|
||||
once at Phase 7 and reuse them across every retrieval change.
|
||||
|
||||
Each retriever returns a ranked list of (bundle_id, page_id) tuples
|
||||
deduplicated to the page level (chunks within the same page collapse
|
||||
to one entry; the highest-ranked chunk's position wins).
|
||||
The retrievers live in rag/ so the MCP server can use them at request
|
||||
time without making eval/ a runtime dependency. This file exists so
|
||||
old import paths (`from eval.retrievers import ...`) keep working.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from rag.retrieval import (
|
||||
Retriever,
|
||||
DenseRetriever,
|
||||
BM25Retriever,
|
||||
HybridRetriever,
|
||||
RerankedRetriever,
|
||||
)
|
||||
|
||||
from typing import Protocol, Iterable
|
||||
|
||||
|
||||
class Retriever(Protocol):
|
||||
name: str
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
"""Return up to k (bundle_id, page_id) tuples in rank order."""
|
||||
...
|
||||
|
||||
|
||||
def _collapse_to_pages(chunk_ids: Iterable[tuple[str, str, str]], k: int) -> list[tuple[str, str]]:
|
||||
"""Take a stream of (bundle_id, page_id, chunk_ordinal) and return
|
||||
the first k unique pages in their first-seen order."""
|
||||
seen: set[tuple[str, str]] = set()
|
||||
out: list[tuple[str, str]] = []
|
||||
for bid, pid, _ord in chunk_ids:
|
||||
key = (bid, pid)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(key)
|
||||
if len(out) >= k:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
# TODO Phase 2/3 — implement these once Chroma + the bm25 module are
|
||||
# in place. Each one is small (15-30 LOC). The eval harness imports
|
||||
# from this module by class name.
|
||||
#
|
||||
# class DenseRetriever:
|
||||
# name = "dense"
|
||||
# def __init__(self, collection): self.col = collection
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class RerankedRetriever:
|
||||
# name = "dense+rerank"
|
||||
# def __init__(self, collection, rerank_url, pool=200): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class BM25Retriever:
|
||||
# name = "bm25"
|
||||
# def __init__(self, bm25_index): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
#
|
||||
# class HybridRetriever:
|
||||
# name = "bm25+dense+rrf"
|
||||
# def __init__(self, dense, bm25, k_rrf=60): ...
|
||||
# def retrieve(self, query, k=10): ...
|
||||
__all__ = [
|
||||
"Retriever",
|
||||
"DenseRetriever",
|
||||
"BM25Retriever",
|
||||
"HybridRetriever",
|
||||
"RerankedRetriever",
|
||||
]
|
||||
|
||||
+142
-22
@@ -2,38 +2,58 @@
|
||||
|
||||
Metrics computed per retriever:
|
||||
|
||||
MRR — mean reciprocal rank of the FIRST expected page in the
|
||||
MRR — mean reciprocal rank of the FIRST expected label in the
|
||||
ranked result list (0 if not in top-k).
|
||||
Recall@K — fraction of expected pages that appear in top-K.
|
||||
Recall@K — fraction of expected labels that appear in top-K.
|
||||
nDCG@K — discounted gain weighted by rank position.
|
||||
|
||||
The "right" number depends on what you're measuring. MRR tracks "the
|
||||
first-line answer is correct"; Recall@K tracks "everything relevant
|
||||
is there to draw from"; nDCG@K is a smoother combination of both.
|
||||
For docs-RAG, MRR is usually the headline metric.
|
||||
For labels-RAG, MRR is the headline: "did the farmer-advisor's
|
||||
RAG fetch the right label first try?" Recall@K matters when the
|
||||
LLM needs the broader context. nDCG@K is a smoother combination.
|
||||
|
||||
Usage:
|
||||
|
||||
python -m eval.run_eval \\
|
||||
--queries eval/queries.jsonl \\
|
||||
--k 5 \\
|
||||
--output eval/results/baseline.md
|
||||
python -m eval.run_eval --queries eval/queries.jsonl \\
|
||||
--k 5 --output eval/results/baseline.md
|
||||
|
||||
Each query in queries.jsonl looks like:
|
||||
|
||||
{
|
||||
"query": "what can I spray on soybeans for waterhemp",
|
||||
"expected": [
|
||||
{"source": "epa_ppls", "source_key": "279-3564"},
|
||||
{"source": "bayer", "source_key": "warrant"}
|
||||
],
|
||||
"tags": ["herbicide", "soybean", "waterhemp"]
|
||||
}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def load_queries(path: Path) -> list[dict]:
|
||||
with open(path) as fh:
|
||||
with open(path, encoding="utf-8") as fh:
|
||||
return [json.loads(line) for line in fh if line.strip()]
|
||||
|
||||
|
||||
def _expected_tuples(q: dict) -> list[tuple[str, str]]:
|
||||
out: list[tuple[str, str]] = []
|
||||
for e in q.get("expected") or []:
|
||||
if isinstance(e, dict) and "source" in e and "source_key" in e:
|
||||
out.append((e["source"], e["source_key"]))
|
||||
elif isinstance(e, (list, tuple)) and len(e) == 2:
|
||||
out.append((str(e[0]), str(e[1])))
|
||||
return out
|
||||
|
||||
|
||||
def reciprocal_rank(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]]) -> float:
|
||||
expected_set = set(expected)
|
||||
for i, page in enumerate(retrieved, start=1):
|
||||
@@ -56,7 +76,6 @@ def ndcg_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]],
|
||||
for i, page in enumerate(retrieved[:k], start=1):
|
||||
if page in expected_set:
|
||||
dcg += 1.0 / math.log2(i + 1)
|
||||
# Ideal DCG: every expected page in the top positions.
|
||||
idcg = sum(1.0 / math.log2(i + 1) for i in range(1, min(len(expected), k) + 1))
|
||||
return dcg / idcg if idcg else 0.0
|
||||
|
||||
@@ -65,27 +84,128 @@ def main() -> int:
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--queries", type=Path, default=Path("eval/queries.jsonl"))
|
||||
p.add_argument("--k", type=int, default=5)
|
||||
p.add_argument("--pool", type=int, default=50,
|
||||
help="Per-retriever over-fetch pool (for hybrid/rerank).")
|
||||
p.add_argument("--output", type=Path, default=Path("eval/results/baseline.md"))
|
||||
p.add_argument("--retrievers", default="dense,bm25,hybrid",
|
||||
help="Comma-separated list: dense,bm25,hybrid,rerank,hybrid+rerank.")
|
||||
args = p.parse_args()
|
||||
|
||||
if not args.queries.exists():
|
||||
print(f"queries file not found: {args.queries}")
|
||||
print("hint: copy eval/queries.jsonl.example and edit")
|
||||
return 1
|
||||
|
||||
queries = load_queries(args.queries)
|
||||
print(f"loaded {len(queries)} queries")
|
||||
print(f"loaded {len(queries)} queries from {args.queries}")
|
||||
|
||||
# TODO Phase 7: instantiate the retrievers you implemented in
|
||||
# eval/retrievers.py and run each one against each query.
|
||||
# Aggregate MRR / Recall@K / nDCG@K per retriever. Emit a
|
||||
# markdown table to args.output. Commit the file alongside the
|
||||
# PR that changes retrieval.
|
||||
raise NotImplementedError(
|
||||
"Wire up the retrievers in eval/retrievers.py first, then "
|
||||
"fill in this evaluation loop. See PLAN.md Phase 7."
|
||||
from eval.retrievers import (
|
||||
DenseRetriever, BM25Retriever, HybridRetriever, RerankedRetriever
|
||||
)
|
||||
|
||||
wanted = [x.strip() for x in args.retrievers.split(",") if x.strip()]
|
||||
dense = DenseRetriever()
|
||||
bm25 = BM25Retriever()
|
||||
|
||||
retrievers: list[tuple[str, "object"]] = []
|
||||
if "dense" in wanted:
|
||||
retrievers.append(("dense", dense))
|
||||
if "bm25" in wanted:
|
||||
retrievers.append(("bm25", bm25))
|
||||
if "hybrid" in wanted:
|
||||
retrievers.append(("hybrid-rrf", HybridRetriever(dense=dense, bm25=bm25, pool=args.pool)))
|
||||
if "rerank" in wanted:
|
||||
retrievers.append(("dense+rerank",
|
||||
RerankedRetriever(base=dense, pool=args.pool)))
|
||||
if "hybrid+rerank" in wanted:
|
||||
retrievers.append(("hybrid+rerank",
|
||||
RerankedRetriever(
|
||||
base=HybridRetriever(dense=dense, bm25=bm25, pool=args.pool),
|
||||
pool=args.pool,
|
||||
)))
|
||||
|
||||
if not retrievers:
|
||||
print(f"no valid retrievers in --retrievers={args.retrievers!r}")
|
||||
return 1
|
||||
|
||||
# Run
|
||||
results: dict[str, dict] = {} # name -> {mrr, recall, ndcg, per_query: [...]}
|
||||
for name, retriever in retrievers:
|
||||
print(f"\n=== {name} ===")
|
||||
per_query = []
|
||||
t0 = time.time()
|
||||
errors = 0
|
||||
for q in queries:
|
||||
expected = _expected_tuples(q)
|
||||
try:
|
||||
retrieved = retriever.retrieve(q["query"], k=max(args.k, args.pool))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f" ERROR on {q['query']!r}: {exc}")
|
||||
errors += 1
|
||||
retrieved = []
|
||||
mrr = reciprocal_rank(retrieved[:args.k], expected)
|
||||
rec = recall_at_k(retrieved, expected, args.k)
|
||||
ndcg = ndcg_at_k(retrieved, expected, args.k)
|
||||
per_query.append({
|
||||
"query": q["query"],
|
||||
"expected": expected,
|
||||
"retrieved": retrieved[:args.k],
|
||||
"mrr": mrr, "recall": rec, "ndcg": ndcg,
|
||||
})
|
||||
elapsed = time.time() - t0
|
||||
results[name] = {
|
||||
"mrr": sum(r["mrr"] for r in per_query) / len(per_query),
|
||||
"recall": sum(r["recall"] for r in per_query) / len(per_query),
|
||||
"ndcg": sum(r["ndcg"] for r in per_query) / len(per_query),
|
||||
"elapsed": elapsed,
|
||||
"errors": errors,
|
||||
"per_query": per_query,
|
||||
}
|
||||
print(f" MRR={results[name]['mrr']:.3f} "
|
||||
f"Recall@{args.k}={results[name]['recall']:.3f} "
|
||||
f"nDCG@{args.k}={results[name]['ndcg']:.3f} "
|
||||
f"({elapsed:.1f}s, {errors} errors)")
|
||||
|
||||
# Render markdown report
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
lines: list[str] = []
|
||||
lines.append(f"# Eval results — {args.queries.name}")
|
||||
lines.append("")
|
||||
lines.append(f"- queries: {len(queries)}")
|
||||
lines.append(f"- k: {args.k}")
|
||||
lines.append(f"- pool: {args.pool}")
|
||||
lines.append(f"- retrievers: {', '.join(name for name, _ in retrievers)}")
|
||||
lines.append("")
|
||||
lines.append("## Summary")
|
||||
lines.append("")
|
||||
lines.append(f"| Retriever | MRR | Recall@{args.k} | nDCG@{args.k} | Errors | Time (s) |")
|
||||
lines.append("|---|---|---|---|---|---|")
|
||||
for name, _ in retrievers:
|
||||
r = results[name]
|
||||
lines.append(
|
||||
f"| {name} | {r['mrr']:.3f} | {r['recall']:.3f} | {r['ndcg']:.3f} "
|
||||
f"| {r['errors']} | {r['elapsed']:.1f} |"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
# Per-query breakdown for the first retriever (typically dense) so we
|
||||
# can see WHICH queries are missing.
|
||||
first_name = retrievers[0][0]
|
||||
lines.append(f"## Per-query — {first_name}")
|
||||
lines.append("")
|
||||
lines.append("| Query | Expected | Top retrieved | MRR | Recall |")
|
||||
lines.append("|---|---|---|---|---|")
|
||||
for r in results[first_name]["per_query"]:
|
||||
exp = ", ".join(f"{s}/{k}" for s, k in r["expected"]) or "—"
|
||||
ret = ", ".join(f"{s}/{k}" for s, k in r["retrieved"][:3]) or "—"
|
||||
lines.append(
|
||||
f"| {r['query'][:60]} | {exp[:60]} | {ret[:80]} | "
|
||||
f"{r['mrr']:.2f} | {r['recall']:.2f} |"
|
||||
)
|
||||
|
||||
args.output.write_text("\n".join(lines), encoding="utf-8")
|
||||
print(f"\nReport written to {args.output}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Retriever protocol + concrete implementations for the labels corpus.
|
||||
|
||||
A single matrix dimension per knob (dense / reranked / bm25 / hybrid)
|
||||
so the eval harness can compare them apples-to-apples.
|
||||
|
||||
Each retriever returns a ranked list of ``(source, source_key)`` tuples
|
||||
deduplicated to the label level (chunks within the same label collapse
|
||||
to one entry; the highest-ranked chunk's position wins). The page-level
|
||||
view matches how MCP consumers think — "give me the right label" not
|
||||
"give me the right chunk".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
CHROMA_DIR = Path(os.environ.get("PPLS_CHROMA_DIR") or REPO_ROOT / "chroma")
|
||||
BM25_DB = Path(os.environ.get("BM25_DB",
|
||||
str(REPO_ROOT / "bm25" / "ppls_docs.db")))
|
||||
COLLECTION = f"{os.environ.get('PRODUCT_NAME', 'ppls')}_docs"
|
||||
|
||||
|
||||
class Retriever(Protocol):
|
||||
name: str
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
"""Return up to k (source, source_key) tuples in rank order."""
|
||||
...
|
||||
|
||||
|
||||
def _collapse_chunks_to_labels(
|
||||
ranked_chunks: Iterable[tuple[str, str, int]], k: int
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Stream of (source, source_key, ordinal) → top-k unique (source, source_key)
|
||||
in first-seen order."""
|
||||
seen: set[tuple[str, str]] = set()
|
||||
out: list[tuple[str, str]] = []
|
||||
for source, source_key, _ord in ranked_chunks:
|
||||
key = (source, source_key)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(key)
|
||||
if len(out) >= k:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _parse_chunk_id(chunk_id: str) -> tuple[str, str, int]:
|
||||
"""Chunk IDs look like 'source::source_key::ordinal'. Robust to
|
||||
source_keys that contain '::' (none do today, but be defensive)."""
|
||||
parts = chunk_id.rsplit("::", 2)
|
||||
if len(parts) != 3:
|
||||
return ("", chunk_id, 0)
|
||||
source, source_key, ord_str = parts
|
||||
try:
|
||||
ord_int = int(ord_str)
|
||||
except ValueError:
|
||||
ord_int = 0
|
||||
return (source, source_key, ord_int)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dense (Chroma) retriever
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DenseRetriever:
|
||||
name = "dense"
|
||||
|
||||
def __init__(self, collection=None, over_fetch_factor: int = 4):
|
||||
self.over_fetch_factor = over_fetch_factor
|
||||
self._col = collection
|
||||
|
||||
def _collection(self):
|
||||
if self._col is not None:
|
||||
return self._col
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from rag.embeddings import embedding_function
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(CHROMA_DIR),
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
)
|
||||
self._col = client.get_collection(
|
||||
COLLECTION, embedding_function=embedding_function()
|
||||
)
|
||||
return self._col
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
col = self._collection()
|
||||
n_fetch = max(k * self.over_fetch_factor, k)
|
||||
res = col.query(query_texts=[query], n_results=n_fetch)
|
||||
ids = res.get("ids", [[]])[0]
|
||||
ranked: list[tuple[str, str, int]] = []
|
||||
for cid in ids:
|
||||
ranked.append(_parse_chunk_id(cid))
|
||||
return _collapse_chunks_to_labels(ranked, k)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BM25 (SQLite FTS5) retriever
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BM25Retriever:
|
||||
"""Wraps ``rag.bm25.BM25Index`` so eval/server can call .retrieve()
|
||||
on it the same way as the dense retriever. The index itself handles
|
||||
FTS5 query sanitization + OR-of-tokens semantics."""
|
||||
|
||||
name = "bm25"
|
||||
|
||||
def __init__(self, db_path: Path = BM25_DB, over_fetch_factor: int = 4):
|
||||
from rag.bm25 import BM25Index
|
||||
self._idx = BM25Index(db_path)
|
||||
self.over_fetch_factor = over_fetch_factor
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
n_fetch = max(k * self.over_fetch_factor, k)
|
||||
hits = self._idx.query(query, n=n_fetch)
|
||||
ranked = [_parse_chunk_id(cid) for cid, _score in hits]
|
||||
return _collapse_chunks_to_labels(ranked, k)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hybrid retriever (BM25 + dense, RRF fusion)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HybridRetriever:
|
||||
"""Reciprocal Rank Fusion of dense + BM25 results. The fused score
|
||||
for a page p is sum over retrievers r of 1 / (k_rrf + rank_r(p)).
|
||||
Pages absent from a retriever contribute 0 from it."""
|
||||
|
||||
name = "hybrid-rrf"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dense: DenseRetriever | None = None,
|
||||
bm25: BM25Retriever | None = None,
|
||||
k_rrf: int = 60,
|
||||
pool: int = 50,
|
||||
):
|
||||
self.dense = dense or DenseRetriever()
|
||||
self.bm25 = bm25 or BM25Retriever()
|
||||
self.k_rrf = k_rrf
|
||||
self.pool = pool
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
dense_pages = self.dense.retrieve(query, k=self.pool)
|
||||
bm25_pages = self.bm25.retrieve(query, k=self.pool)
|
||||
scores: dict[tuple[str, str], float] = {}
|
||||
for rank, page in enumerate(dense_pages, start=1):
|
||||
scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank)
|
||||
for rank, page in enumerate(bm25_pages, start=1):
|
||||
scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank)
|
||||
fused = sorted(scores.items(), key=lambda kv: -kv[1])
|
||||
return [page for page, _ in fused[:k]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reranker (jina-reranker via llama.cpp /v1/rerank)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RerankedRetriever:
|
||||
"""Take a base retriever's pool, fetch full chunk text for each page's
|
||||
top chunk, send (query, chunk_text) pairs to a llama.cpp /v1/rerank
|
||||
endpoint, then rerank pages by the returned scores.
|
||||
|
||||
For eval we operate page-level. We pick the first chunk per page from
|
||||
the base retriever's chunk-level output. To get the chunk text we
|
||||
re-query Chroma by chunk id."""
|
||||
|
||||
name = "dense+rerank"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base: Retriever | None = None,
|
||||
rerank_url: str | None = None,
|
||||
pool: int = 50,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.base = base or DenseRetriever()
|
||||
self.rerank_url = (rerank_url or os.environ.get("RERANK_URL", "")).rstrip("/")
|
||||
self.pool = pool
|
||||
self.timeout = timeout
|
||||
self._col = None
|
||||
|
||||
@property
|
||||
def name_with_base(self) -> str:
|
||||
return f"{self.base.name}+rerank"
|
||||
|
||||
def _collection(self):
|
||||
if self._col is not None:
|
||||
return self._col
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(CHROMA_DIR),
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
)
|
||||
# We don't need the embedder for fetch-by-id; pass embedding_function=None
|
||||
self._col = client.get_collection(COLLECTION)
|
||||
return self._col
|
||||
|
||||
def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
|
||||
if not self.rerank_url:
|
||||
# Fail open to base retriever — useful in eval to compare base vs
|
||||
# base+rerank when the reranker is offline.
|
||||
log.warning("RERANK_URL unset; falling back to base retriever")
|
||||
return self.base.retrieve(query, k=k)
|
||||
|
||||
pages = self.base.retrieve(query, k=self.pool)
|
||||
if not pages:
|
||||
return []
|
||||
# Fetch one representative chunk per page (the first chunk, ordinal=0
|
||||
# if it exists, else any). For eval simplicity we approximate by
|
||||
# fetching by metadata where (source, source_key) and taking the
|
||||
# first hit.
|
||||
col = self._collection()
|
||||
docs: list[str] = []
|
||||
kept_pages: list[tuple[str, str]] = []
|
||||
for source, source_key in pages:
|
||||
where = {"$and": [{"source": source}, {"source_key": source_key}]}
|
||||
got = col.get(where=where, limit=1, include=["documents"])
|
||||
d = (got.get("documents") or [None])[0]
|
||||
if not d:
|
||||
continue
|
||||
# Truncate to keep under the reranker's per-pair context limit
|
||||
docs.append(d[:2000])
|
||||
kept_pages.append((source, source_key))
|
||||
|
||||
if not docs:
|
||||
return []
|
||||
|
||||
import httpx
|
||||
try:
|
||||
r = httpx.post(
|
||||
f"{self.rerank_url}/v1/rerank",
|
||||
json={"query": query, "documents": docs},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("rerank failed (%s) — falling back to base order", exc)
|
||||
return kept_pages[:k]
|
||||
|
||||
# llama.cpp returns {"results": [{"index": i, "relevance_score": s}, ...]}
|
||||
results = data.get("results") or []
|
||||
scored: list[tuple[float, tuple[str, str]]] = []
|
||||
for r_item in results:
|
||||
idx = r_item.get("index")
|
||||
score = r_item.get("relevance_score") or r_item.get("score") or 0.0
|
||||
if isinstance(idx, int) and 0 <= idx < len(kept_pages):
|
||||
scored.append((score, kept_pages[idx]))
|
||||
scored.sort(key=lambda x: -x[0])
|
||||
return [p for _, p in scored[:k]]
|
||||
Reference in New Issue
Block a user