Files
crop-chem-docs/rag/retrieval.py
T
justin 1a45280e45 rename: ppls-docs → crop-chem-docs
Repo/project rename to better reflect scope. PPLS is EPA's term for
their Pesticide Product Label System — accurate when the corpus was
EPA-only, narrow now that it also pulls from Bayer's own catalog
(and may expand to Syngenta/Corteva/BASF/FMC labels in the future).
crop-chem-docs covers that broader scope without an acronym that needs explaining.

Renames:
- directory:           ppls-docs            → crop-chem-docs
- PRODUCT_NAME:        ppls                 → crop_chem
- Chroma collection:   ppls_docs            → crop_chem_docs  (in-place via .modify(), no re-embed)
- BM25 db:             bm25/ppls_docs.db    → bm25/crop_chem_docs.db
- MCP tool name:       ppls_api_lessons     → crop_chem_api_lessons
- FastMCP server name: ppls-docs            → crop-chem-docs
- Env vars:            PPLS_CORPUS_ROOT     → CORPUS_ROOT
                       PPLS_CHROMA_DIR      → CHROMA_DIR_OVERRIDE
- User-Agent:          ppls-docs-scraper    → crop-chem-docs-scraper

Preserved (intentional, correct):
- epa_ppls (source id) — refers specifically to EPA's PPLS database
- "EPA PPLS" mentions in regulatory text (lessons.md, server docstrings)
- PPLS_API_BASE / PPLS_PDF_BASE / PPLS_INDEX_URL_TEMPLATE in
  scrape/sources/epa_ppls.py — these point at EPA's actual endpoints

Memory entries get updated in a follow-up commit so the rename is
isolated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-24 12:25:59 -04:00

262 lines
9.5 KiB
Python

"""Retriever protocol + concrete implementations for the labels corpus.
Each retrieval strategy (dense / reranked / bm25 / hybrid) is one value of a
single eval-matrix dimension, so the eval harness can compare them
apples-to-apples.
Each retriever returns a ranked list of ``(source, source_key)`` tuples
deduplicated to the label level (chunks within the same label collapse
to one entry; the highest-ranked chunk's position wins). The label-level
view matches how MCP consumers think — "give me the right label", not
"give me the right chunk".
"""
from __future__ import annotations
import logging
import os
import sqlite3
from pathlib import Path
from typing import Iterable, Protocol
log = logging.getLogger(__name__)

# Repository root: this file lives at <repo>/rag/retrieval.py.
REPO_ROOT = Path(__file__).resolve().parent.parent

# Environment overrides. Empty values are treated as unset (``or`` rather than a
# ``.get`` default), so e.g. ``BM25_DB=`` in a .env file falls back to the
# default instead of resolving to the current working directory.
CHROMA_DIR = Path(os.environ.get("CHROMA_DIR_OVERRIDE") or REPO_ROOT / "chroma")
BM25_DB = Path(
    os.environ.get("BM25_DB") or REPO_ROOT / "bm25" / "crop_chem_docs.db"
)
# Chroma collection name: "<PRODUCT_NAME>_docs", i.e. "crop_chem_docs" by default.
COLLECTION = f"{os.environ.get('PRODUCT_NAME') or 'crop_chem'}_docs"
class Retriever(Protocol):
    """Structural interface every retriever in this module satisfies.

    Implementations carry a ``name`` identifying the variant and a
    ``retrieve`` method producing label-level hits.
    """

    name: str

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Rank labels for *query*.

        Returns at most ``k`` ``(source, source_key)`` pairs, best first.
        """
        ...
def _collapse_chunks_to_labels(
ranked_chunks: Iterable[tuple[str, str, int]], k: int
) -> list[tuple[str, str]]:
"""Stream of (source, source_key, ordinal) → top-k unique (source, source_key)
in first-seen order."""
seen: set[tuple[str, str]] = set()
out: list[tuple[str, str]] = []
for source, source_key, _ord in ranked_chunks:
key = (source, source_key)
if key in seen:
continue
seen.add(key)
out.append(key)
if len(out) >= k:
break
return out
def _parse_chunk_id(chunk_id: str) -> tuple[str, str, int]:
"""Chunk IDs look like 'source::source_key::ordinal'. Robust to
source_keys that contain '::' (none do today, but be defensive)."""
parts = chunk_id.rsplit("::", 2)
if len(parts) != 3:
return ("", chunk_id, 0)
source, source_key, ord_str = parts
try:
ord_int = int(ord_str)
except ValueError:
ord_int = 0
return (source, source_key, ord_int)
# ---------------------------------------------------------------------------
# Dense (Chroma) retriever
# ---------------------------------------------------------------------------
class DenseRetriever:
    """Chroma vector-search retriever, collapsed to label level.

    Over-fetches ``k * over_fetch_factor`` chunks so that, after chunks of
    the same label are collapsed, up to ``k`` distinct labels remain.
    """

    name = "dense"

    def __init__(self, collection=None, over_fetch_factor: int = 4):
        # ``collection`` may be injected (e.g. a shared handle or a test
        # double); otherwise it is opened lazily by ``_collection()``.
        self.over_fetch_factor = over_fetch_factor
        self._col = collection

    def _collection(self):
        """Return the Chroma collection, opening it on first use."""
        if self._col is not None:
            return self._col
        import chromadb
        from chromadb.config import Settings

        from rag.embeddings import embedding_function

        client = chromadb.PersistentClient(
            path=str(CHROMA_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        self._col = client.get_collection(
            COLLECTION, embedding_function=embedding_function()
        )
        return self._col

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Return up to ``k`` ``(source, source_key)`` labels for *query*.

        ``k <= 0`` returns ``[]`` without touching Chroma.
        """
        if k <= 0:
            return []
        n_fetch = max(k * self.over_fetch_factor, k)
        res = self._collection().query(query_texts=[query], n_results=n_fetch)
        # One id list per query text; exactly one query text was sent.
        id_lists = res.get("ids") or [[]]
        ranked = [_parse_chunk_id(cid) for cid in id_lists[0]]
        return _collapse_chunks_to_labels(ranked, k)
# ---------------------------------------------------------------------------
# BM25 (SQLite FTS5) retriever
# ---------------------------------------------------------------------------
class BM25Retriever:
    """Wraps ``rag.bm25.BM25Index`` so eval/server can call .retrieve()
    on it the same way as the dense retriever. The index itself handles
    FTS5 query sanitization + OR-of-tokens semantics.

    Like the dense retriever, it over-fetches ``k * over_fetch_factor`` chunk
    hits and collapses them to at most ``k`` labels.
    """

    name = "bm25"

    def __init__(self, db_path: Path | None = None, over_fetch_factor: int = 4):
        from rag.bm25 import BM25Index

        # The default is resolved at call time, not at class-definition time,
        # so overriding the module-level BM25_DB (e.g. in tests) takes effect.
        self._idx = BM25Index(BM25_DB if db_path is None else db_path)
        self.over_fetch_factor = over_fetch_factor

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Return up to ``k`` ``(source, source_key)`` labels for *query*.

        ``k <= 0`` returns ``[]`` without querying the index.
        """
        if k <= 0:
            return []
        n_fetch = max(k * self.over_fetch_factor, k)
        hits = self._idx.query(query, n=n_fetch)
        # Hits are (chunk_id, score) pairs; only the rank order is used here.
        ranked = [_parse_chunk_id(cid) for cid, _score in hits]
        return _collapse_chunks_to_labels(ranked, k)
# ---------------------------------------------------------------------------
# Hybrid retriever (BM25 + dense, RRF fusion)
# ---------------------------------------------------------------------------
class HybridRetriever:
    """Reciprocal Rank Fusion (RRF) of dense + BM25 label rankings.

    The fused score of a label p is ``sum over retrievers r of
    1 / (k_rrf + rank_r(p))`` with 1-based ranks; a label absent from a
    retriever's list gets no contribution from it. Labels with equal fused
    scores keep first-seen order (dense list first, then BM25).
    """

    name = "hybrid-rrf"

    def __init__(
        self,
        dense: DenseRetriever | None = None,
        bm25: BM25Retriever | None = None,
        k_rrf: int = 60,
        pool: int = 50,
    ):
        self.dense = dense or DenseRetriever()
        self.bm25 = bm25 or BM25Retriever()
        self.k_rrf = k_rrf
        self.pool = pool

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Return up to ``k`` fused ``(source, source_key)`` labels.

        ``k <= 0`` returns ``[]`` (a negative k previously sliced ``fused[:k]``
        from the end and returned nearly every candidate).
        """
        if k <= 0:
            return []
        # Each base retriever supplies ``pool`` candidates, or ``k`` when the
        # caller asks for more than the pool, so large requests can be filled.
        depth = max(self.pool, k)
        rankings = (
            self.dense.retrieve(query, k=depth),
            self.bm25.retrieve(query, k=depth),
        )
        scores: dict[tuple[str, str], float] = {}
        for ranking in rankings:
            for rank, page in enumerate(ranking, start=1):
                scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank)
        fused = sorted(scores.items(), key=lambda kv: -kv[1])
        return [page for page, _ in fused[:k]]
# ---------------------------------------------------------------------------
# Reranker (jina-reranker via llama.cpp /v1/rerank)
# ---------------------------------------------------------------------------
class RerankedRetriever:
    """Rerank a base retriever's label pool via a llama.cpp ``/v1/rerank`` endpoint.

    Flow:
      1. take up to ``max(pool, k)`` labels from ``base``;
      2. fetch one stored chunk's text per label from Chroma by
         ``(source, source_key)`` metadata, truncated to ``max_doc_chars``;
      3. POST ``{"query", "documents"}`` to ``{rerank_url}/v1/rerank``;
      4. order labels by the returned scores.

    Fail-open: with no rerank URL the base ranking is returned; if the HTTP
    call fails, or the response carries no usable scores, labels come back in
    base order. Labels the reranker did not score (e.g. a server-side top-n)
    are appended after the scored ones in base order.
    """

    name = "dense+rerank"

    def __init__(
        self,
        base: Retriever | None = None,
        rerank_url: str | None = None,
        pool: int = 50,
        timeout: float = 30.0,
        max_doc_chars: int = 2000,
    ):
        self.base = base or DenseRetriever()
        self.rerank_url = (rerank_url or os.environ.get("RERANK_URL", "")).rstrip("/")
        self.pool = pool
        self.timeout = timeout
        # Per-document truncation to keep under the reranker's per-pair
        # context limit.
        self.max_doc_chars = max_doc_chars
        self._col = None

    @property
    def name_with_base(self) -> str:
        """Variant label reflecting the actual base, e.g. ``"bm25+rerank"``."""
        return f"{self.base.name}+rerank"

    def _collection(self):
        """Open (once) a Chroma handle used only for metadata-filtered ``.get()``."""
        if self._col is not None:
            return self._col
        import chromadb
        from chromadb.config import Settings

        client = chromadb.PersistentClient(
            path=str(CHROMA_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        # The project embedding function is not passed: this handle never
        # embeds (no query_texts); it only fetches stored documents.
        self._col = client.get_collection(COLLECTION)
        return self._col

    def _fetch_representative_docs(
        self, pages: list[tuple[str, str]]
    ) -> tuple[list[str], list[tuple[str, str]]]:
        """Return ``(docs, kept_pages)`` with one truncated chunk text per label.

        ``limit=1`` returns whichever matching chunk Chroma yields first — not
        necessarily the label's first or best-matching chunk. Labels with no
        stored document are dropped; ``docs[i]`` belongs to ``kept_pages[i]``.
        """
        col = self._collection()
        docs: list[str] = []
        kept_pages: list[tuple[str, str]] = []
        for source, source_key in pages:
            where = {"$and": [{"source": source}, {"source_key": source_key}]}
            got = col.get(where=where, limit=1, include=["documents"])
            doc = (got.get("documents") or [None])[0]
            if not doc:
                continue
            docs.append(doc[: self.max_doc_chars])
            kept_pages.append((source, source_key))
        return docs, kept_pages

    def _post_rerank(self, query: str, docs: list[str]):
        """POST to the reranker; return the decoded JSON, or None on failure."""
        import httpx

        try:
            resp = httpx.post(
                f"{self.rerank_url}/v1/rerank",
                json={"query": query, "documents": docs},
                timeout=self.timeout,
            )
            resp.raise_for_status()
            return resp.json()
        except Exception as exc:  # noqa: BLE001 — deliberate fail-open
            log.warning("rerank failed (%s) — falling back to base order", exc)
            return None

    @staticmethod
    def _rank_pages_by_scores(
        data, kept_pages: list[tuple[str, str]]
    ) -> list[tuple[str, str]]:
        """Return the labels the reranker scored, highest score first.

        Expects llama.cpp's shape
        ``{"results": [{"index": i, "relevance_score": s}, ...]}``; a
        ``score`` key is accepted when ``relevance_score`` is absent.
        Non-numeric scores count as 0.0. Entries with a missing or
        out-of-range index, and repeated indexes, are ignored. Equal scores
        keep response order.
        """
        results = data.get("results") if isinstance(data, dict) else None
        scored: list[tuple[float, tuple[str, str]]] = []
        seen_idx: set[int] = set()
        for item in results or []:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            if not isinstance(idx, int) or not 0 <= idx < len(kept_pages):
                continue
            if idx in seen_idx:
                continue
            seen_idx.add(idx)
            score = item.get("relevance_score")
            if score is None:
                score = item.get("score")
            if not isinstance(score, (int, float)):
                score = 0.0
            scored.append((score, kept_pages[idx]))
        scored.sort(key=lambda pair: -pair[0])
        return [page for _, page in scored]

    def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]:
        """Return up to ``k`` reranked ``(source, source_key)`` labels.

        ``k <= 0`` returns ``[]``.
        """
        if k <= 0:
            return []
        if not self.rerank_url:
            # Fail open to base retriever — useful in eval to compare base vs
            # base+rerank when the reranker is offline.
            log.warning("RERANK_URL unset; falling back to base retriever")
            return self.base.retrieve(query, k=k)
        # Rerank at least ``pool`` candidates, more if the caller asks for more.
        pages = self.base.retrieve(query, k=max(self.pool, k))
        if not pages:
            return []
        docs, kept_pages = self._fetch_representative_docs(pages)
        if not docs:
            return []
        data = self._post_rerank(query, docs)
        if data is None:
            return kept_pages[:k]
        scored_pages = self._rank_pages_by_scores(data, kept_pages)
        if not scored_pages:
            log.warning(
                "rerank returned no usable scores — falling back to base order"
            )
            return kept_pages[:k]
        scored_set = set(scored_pages)
        unscored = [p for p in kept_pages if p not in scored_set]
        return (scored_pages + unscored)[:k]