Files
justin 3c3178a6ad eval: GPU rerank baseline + CLI fix
GPU eval (hybrid+rerank, RERANK_URL=http://10.10.1.65:8082):
  MRR=0.672  Recall@5=0.638  nDCG@5=0.621  (35 queries, 1 transient
  500, otherwise clean)

Quality identical to the CPU rerank run as expected — only latency
changed (single rerank call dropped from ~23s to ~0.7-1.5s on the
Tesla P4). Per-query report at eval/results/with_rerank_gpu.md.

CLI parser fix: `--retrievers dense+rerank,hybrid+rerank` now
correctly wires the dense+rerank variant. Previously only literal
"rerank" (without prefix) matched the dense+rerank branch, so
combined-retriever runs silently dropped dense+rerank.

(Note: the eval's RerankedRetriever does 50 individual Chroma `get`
calls per query to fetch chunk text by (source, source_key); this
adds ~15s per query of pure SQLite lookup overhead. Not a production
concern — docs_mcp/server.py's _rerank_pool reranks docs already in
the dense pool, so it makes no extra Chroma round-trips. Worth
tightening the eval-side implementation on a later pass.)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-24 12:12:51 -04:00

213 lines
7.8 KiB
Python

"""Run all retrievers against eval/queries.jsonl, emit a markdown report.
Metrics computed per retriever:
MRR — mean reciprocal rank of the FIRST expected label in the
ranked result list (0 if not in top-k).
Recall@K — fraction of expected labels that appear in top-K.
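nDCG@K — normalized discounted cumulative gain over the top-K: each
hit is discounted by log2(rank + 1) and the sum is divided by the
best achievable (ideal) DCG.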
For labels-RAG, MRR is the headline: "did the farmer-advisor's
RAG fetch the right label first try?" Recall@K matters when the
LLM needs the broader context. nDCG@K is a smoother combination.
Usage:
python -m eval.run_eval --queries eval/queries.jsonl \\
--k 5 --output eval/results/baseline.md
Each query in queries.jsonl looks like:
{
"query": "what can I spray on soybeans for waterhemp",
"expected": [
{"source": "epa_ppls", "source_key": "279-3564"},
{"source": "bayer", "source_key": "warrant"}
],
"tags": ["herbicide", "soybean", "waterhemp"]
}
"""
from __future__ import annotations
import argparse
import json
import math
import os
import time
import traceback
from pathlib import Path
from typing import Iterable
def load_queries(path: Path) -> list[dict]:
    """Parse a JSONL file into a list, one decoded object per non-blank line.

    Lines that are empty or contain only whitespace are skipped. Every other
    line is handed to ``json.loads`` as-is, so a malformed line propagates
    the decoder's exception to the caller.
    """
    records: list[dict] = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records
def _expected_tuples(q: dict) -> list[tuple[str, str]]:
    """Extract the expected ``(source, source_key)`` labels from a query record.

    Each entry under ``q["expected"]`` may be either a mapping that carries
    both ``"source"`` and ``"source_key"``, or a two-element list/tuple.
    Entries in any other shape are skipped. A missing or null ``"expected"``
    yields an empty list.

    Both components are coerced to ``str`` for *both* entry shapes, so the
    result always matches the annotated return type; e.g. a ``source_key``
    written as an unquoted number in the JSON becomes its string form.
    """
    labels: list[tuple[str, str]] = []
    for entry in q.get("expected") or []:
        if isinstance(entry, dict) and "source" in entry and "source_key" in entry:
            source, key = entry["source"], entry["source_key"]
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            source, key = entry
        else:
            # Unrecognised shape: ignore rather than fail the whole run.
            continue
        labels.append((str(source), str(key)))
    return labels
def reciprocal_rank(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]]) -> float:
    """Return ``1 / rank`` of the first retrieved label that is expected.

    Ranks are 1-based positions in ``retrieved``. The result is ``0.0`` when
    none of the retrieved labels appears in ``expected``.
    """
    wanted = set(expected)
    hit_ranks = (
        rank for rank, label in enumerate(retrieved, start=1) if label in wanted
    )
    first_hit = next(hit_ranks, None)
    if first_hit is None:
        return 0.0
    return 1.0 / first_hit
def recall_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float:
    """Return the fraction of ``expected`` labels present in the first ``k`` results.

    An empty ``expected`` list gives ``0.0`` rather than dividing by zero.
    """
    if len(expected) == 0:
        return 0.0
    top_k = frozenset(retrieved[:k])
    found = [label for label in expected if label in top_k]
    return len(found) / len(expected)
def ndcg_at_k(retrieved: list[tuple[str, str]], expected: list[tuple[str, str]], k: int) -> float:
    """Return binary-relevance nDCG over the first ``k`` retrieved labels.

    A hit at 1-based rank ``i`` contributes ``1 / log2(i + 1)``. The sum is
    divided by the ideal DCG, i.e. the score obtained when every distinct
    expected label (up to ``k`` of them) occupies the top ranks.

    Each distinct expected label is credited at most once, at its first
    occurrence, and the ideal DCG is computed from the number of distinct
    expected labels. This keeps the result within ``[0.0, 1.0]`` even when
    ``retrieved`` or ``expected`` contains repeated labels.

    Returns ``0.0`` when there is nothing to normalise against (no expected
    labels, or ``k`` is not positive).
    """
    uncredited = set(expected)
    # Capture before the loop: the set shrinks as labels are credited.
    ideal_hits = min(len(uncredited), k)

    dcg = 0.0
    for rank, label in enumerate(retrieved[:k], start=1):
        if label in uncredited:
            # A repeated label must not earn gain a second time.
            uncredited.discard(label)
            dcg += 1.0 / math.log2(rank + 1)

    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0
def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


def _md_cell(text: str, width: int) -> str:
    """Clip *text* to *width* characters and neutralise markdown table syntax."""
    return text[:width].replace("\n", " ").replace("|", "\\|")


def _evaluate(retriever: "object", queries: list[dict], k: int, pool: int) -> dict:
    """Run one retriever over every query and aggregate its metrics.

    Returns a dict with the mean ``mrr``/``recall``/``ndcg``, the ``elapsed``
    seconds, the number of ``errors`` and the ``per_query`` rows.
    """
    per_query: list[dict] = []
    errors = 0
    started = time.perf_counter()  # monotonic: immune to wall-clock adjustments
    for q in queries:
        query_text = q["query"]
        expected = _expected_tuples(q)
        try:
            retrieved = retriever.retrieve(query_text, k=max(k, pool))
        except Exception as exc:  # noqa: BLE001
            # Deliberate best-effort: a failing query is reported and scored
            # as a miss so one bad call does not abort the whole run.
            print(f" ERROR on {query_text!r}: {exc}")
            errors += 1
            retrieved = []
        per_query.append({
            "query": query_text,
            "expected": expected,
            "retrieved": retrieved[:k],
            "mrr": reciprocal_rank(retrieved[:k], expected),
            "recall": recall_at_k(retrieved, expected, k),
            "ndcg": ndcg_at_k(retrieved, expected, k),
        })
    elapsed = time.perf_counter() - started
    return {
        "mrr": _mean([row["mrr"] for row in per_query]),
        "recall": _mean([row["recall"] for row in per_query]),
        "ndcg": _mean([row["ndcg"] for row in per_query]),
        "elapsed": elapsed,
        "errors": errors,
        "per_query": per_query,
    }


def _render_report(queries_name: str, n_queries: int, k: int, pool: int,
                   names: list[str], results: dict[str, dict]) -> str:
    """Build the markdown report: run settings, summary table, per-query table."""
    lines: list[str] = [
        f"# Eval results — {queries_name}",
        "",
        f"- queries: {n_queries}",
        f"- k: {k}",
        f"- pool: {pool}",
        f"- retrievers: {', '.join(names)}",
        "",
        "## Summary",
        "",
        f"| Retriever | MRR | Recall@{k} | nDCG@{k} | Errors | Time (s) |",
        "|---|---|---|---|---|---|",
    ]
    for name in names:
        r = results[name]
        lines.append(
            f"| {name} | {r['mrr']:.3f} | {r['recall']:.3f} | {r['ndcg']:.3f} "
            f"| {r['errors']} | {r['elapsed']:.1f} |"
        )
    lines.append("")

    # Per-query breakdown for the first retriever (typically dense) so we
    # can see WHICH queries are missing.
    first_name = names[0]
    lines.append(f"## Per-query — {first_name}")
    lines.append("")
    lines.append("| Query | Expected | Top retrieved | MRR | Recall |")
    lines.append("|---|---|---|---|---|")
    for row in results[first_name]["per_query"]:
        exp = ", ".join(f"{src}/{key}" for src, key in row["expected"]) or "—"
        ret = ", ".join(f"{src}/{key}" for src, key in row["retrieved"][:3]) or "—"
        lines.append(
            f"| {_md_cell(row['query'], 60)} | {_md_cell(exp, 60)} | {_md_cell(ret, 80)} | "
            f"{row['mrr']:.2f} | {row['recall']:.2f} |"
        )
    return "\n".join(lines) + "\n"


def main() -> int:
    """Evaluate the selected retrievers and write a markdown report.

    Parses the command line, loads the JSONL queries, runs every requested
    retriever over them, prints a one-line summary per retriever and writes
    the report to ``--output`` (parent directories are created as needed).

    Returns:
        0 on success; 1 when the queries file is missing or contains no
        queries, or when ``--retrievers`` selects no known retriever.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--queries", type=Path, default=Path("eval/queries.jsonl"))
    p.add_argument("--k", type=int, default=5)
    p.add_argument("--pool", type=int, default=50,
                   help="Per-retriever over-fetch pool (for hybrid/rerank).")
    p.add_argument("--output", type=Path, default=Path("eval/results/baseline.md"))
    p.add_argument("--retrievers", default="dense,bm25,hybrid",
                   help="Comma-separated list: dense,bm25,hybrid,"
                        "rerank (alias: dense+rerank),hybrid+rerank.")
    args = p.parse_args()

    if not args.queries.exists():
        print(f"queries file not found: {args.queries}")
        return 1
    queries = load_queries(args.queries)
    print(f"loaded {len(queries)} queries from {args.queries}")
    if not queries:
        # Averaging over zero queries is meaningless; fail with a clear
        # message instead of a ZeroDivisionError further down.
        print(f"no queries in {args.queries}")
        return 1

    # Imported here, after the cheap argument/file checks have passed.
    from eval.retrievers import (
        DenseRetriever, BM25Retriever, HybridRetriever, RerankedRetriever
    )

    wanted = {x.strip() for x in args.retrievers.split(",") if x.strip()}
    known = {"dense", "bm25", "hybrid", "rerank", "dense+rerank", "hybrid+rerank"}
    unknown = sorted(wanted - known)
    if unknown:
        # Surface typos instead of silently running fewer retrievers.
        print(f"warning: ignoring unknown retriever(s): {', '.join(unknown)}")

    dense = DenseRetriever()
    bm25 = BM25Retriever()
    retrievers: list[tuple[str, "object"]] = []
    if "dense" in wanted:
        retrievers.append(("dense", dense))
    if "bm25" in wanted:
        retrievers.append(("bm25", bm25))
    if "hybrid" in wanted:
        retrievers.append(("hybrid-rrf", HybridRetriever(dense=dense, bm25=bm25, pool=args.pool)))
    # Accept either "rerank" or "dense+rerank" for the dense-base reranker.
    if "rerank" in wanted or "dense+rerank" in wanted:
        retrievers.append(("dense+rerank",
                           RerankedRetriever(base=dense, pool=args.pool)))
    if "hybrid+rerank" in wanted:
        retrievers.append(("hybrid+rerank",
                           RerankedRetriever(
                               base=HybridRetriever(dense=dense, bm25=bm25, pool=args.pool),
                               pool=args.pool,
                           )))
    if not retrievers:
        print(f"no valid retrievers in --retrievers={args.retrievers!r}")
        return 1

    # Run
    results: dict[str, dict] = {}  # name -> {mrr, recall, ndcg, per_query: [...]}
    for name, retriever in retrievers:
        print(f"\n=== {name} ===")
        summary = _evaluate(retriever, queries, args.k, args.pool)
        results[name] = summary
        print(f" MRR={summary['mrr']:.3f} "
              f"Recall@{args.k}={summary['recall']:.3f} "
              f"nDCG@{args.k}={summary['ndcg']:.3f} "
              f"({summary['elapsed']:.1f}s, {summary['errors']} errors)")

    # Render markdown report
    args.output.parent.mkdir(parents=True, exist_ok=True)
    report = _render_report(
        args.queries.name, len(queries), args.k, args.pool,
        [name for name, _ in retrievers], results,
    )
    args.output.write_text(report, encoding="utf-8")
    print(f"\nReport written to {args.output}")
    return 0
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)