"""Retriever protocol + concrete implementations for the labels corpus. A single matrix dimension per knob (dense / reranked / bm25 / hybrid) so the eval harness can compare them apples-to-apples. Each retriever returns a ranked list of ``(source, source_key)`` tuples deduplicated to the label level (chunks within the same label collapse to one entry; the highest-ranked chunk's position wins). The page-level view matches how MCP consumers think — "give me the right label" not "give me the right chunk". """ from __future__ import annotations import logging import os import sqlite3 from pathlib import Path from typing import Iterable, Protocol log = logging.getLogger(__name__) REPO_ROOT = Path(__file__).resolve().parent.parent CHROMA_DIR = Path(os.environ.get("CHROMA_DIR_OVERRIDE") or REPO_ROOT / "chroma") BM25_DB = Path(os.environ.get("BM25_DB", str(REPO_ROOT / "bm25" / "crop_chem_docs.db"))) COLLECTION = f"{os.environ.get('PRODUCT_NAME', 'crop_chem')}_docs" class Retriever(Protocol): name: str def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: """Return up to k (source, source_key) tuples in rank order.""" ... def _collapse_chunks_to_labels( ranked_chunks: Iterable[tuple[str, str, int]], k: int ) -> list[tuple[str, str]]: """Stream of (source, source_key, ordinal) → top-k unique (source, source_key) in first-seen order.""" seen: set[tuple[str, str]] = set() out: list[tuple[str, str]] = [] for source, source_key, _ord in ranked_chunks: key = (source, source_key) if key in seen: continue seen.add(key) out.append(key) if len(out) >= k: break return out def _parse_chunk_id(chunk_id: str) -> tuple[str, str, int]: """Chunk IDs look like 'source::source_key::ordinal'. Robust to source_keys that contain '::' (none do today, but be defensive).""" parts = chunk_id.rsplit("::", 2) if len(parts) != 3: return ("", chunk_id, 0) source, source_key, ord_str = parts try: ord_int = int(ord_str) except ValueError: ord_int = 0 return (source, source_key, ord_int) # --------------------------------------------------------------------------- # Dense (Chroma) retriever # --------------------------------------------------------------------------- class DenseRetriever: name = "dense" def __init__(self, collection=None, over_fetch_factor: int = 4): self.over_fetch_factor = over_fetch_factor self._col = collection def _collection(self): if self._col is not None: return self._col import chromadb from chromadb.config import Settings from rag.embeddings import embedding_function client = chromadb.PersistentClient( path=str(CHROMA_DIR), settings=Settings(anonymized_telemetry=False), ) self._col = client.get_collection( COLLECTION, embedding_function=embedding_function() ) return self._col def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: col = self._collection() n_fetch = max(k * self.over_fetch_factor, k) res = col.query(query_texts=[query], n_results=n_fetch) ids = res.get("ids", [[]])[0] ranked: list[tuple[str, str, int]] = [] for cid in ids: ranked.append(_parse_chunk_id(cid)) return _collapse_chunks_to_labels(ranked, k) # --------------------------------------------------------------------------- # BM25 (SQLite FTS5) retriever # --------------------------------------------------------------------------- class BM25Retriever: """Wraps ``rag.bm25.BM25Index`` so eval/server can call .retrieve() on it the same way as the dense retriever. The index itself handles FTS5 query sanitization + OR-of-tokens semantics.""" name = "bm25" def __init__(self, db_path: Path = BM25_DB, over_fetch_factor: int = 4): from rag.bm25 import BM25Index self._idx = BM25Index(db_path) self.over_fetch_factor = over_fetch_factor def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: n_fetch = max(k * self.over_fetch_factor, k) hits = self._idx.query(query, n=n_fetch) ranked = [_parse_chunk_id(cid) for cid, _score in hits] return _collapse_chunks_to_labels(ranked, k) # --------------------------------------------------------------------------- # Hybrid retriever (BM25 + dense, RRF fusion) # --------------------------------------------------------------------------- class HybridRetriever: """Reciprocal Rank Fusion of dense + BM25 results. The fused score for a page p is sum over retrievers r of 1 / (k_rrf + rank_r(p)). Pages absent from a retriever contribute 0 from it.""" name = "hybrid-rrf" def __init__( self, dense: DenseRetriever | None = None, bm25: BM25Retriever | None = None, k_rrf: int = 60, pool: int = 50, ): self.dense = dense or DenseRetriever() self.bm25 = bm25 or BM25Retriever() self.k_rrf = k_rrf self.pool = pool def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: dense_pages = self.dense.retrieve(query, k=self.pool) bm25_pages = self.bm25.retrieve(query, k=self.pool) scores: dict[tuple[str, str], float] = {} for rank, page in enumerate(dense_pages, start=1): scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank) for rank, page in enumerate(bm25_pages, start=1): scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank) fused = sorted(scores.items(), key=lambda kv: -kv[1]) return [page for page, _ in fused[:k]] # --------------------------------------------------------------------------- # Reranker (jina-reranker via llama.cpp /v1/rerank) # --------------------------------------------------------------------------- class RerankedRetriever: """Take a base retriever's pool, fetch full chunk text for each page's top chunk, send (query, chunk_text) pairs to a llama.cpp /v1/rerank endpoint, then rerank pages by the returned scores. For eval we operate page-level. We pick the first chunk per page from the base retriever's chunk-level output. To get the chunk text we re-query Chroma by chunk id.""" name = "dense+rerank" def __init__( self, base: Retriever | None = None, rerank_url: str | None = None, pool: int = 50, timeout: float = 30.0, ): self.base = base or DenseRetriever() self.rerank_url = (rerank_url or os.environ.get("RERANK_URL", "")).rstrip("/") self.pool = pool self.timeout = timeout self._col = None @property def name_with_base(self) -> str: return f"{self.base.name}+rerank" def _collection(self): if self._col is not None: return self._col import chromadb from chromadb.config import Settings client = chromadb.PersistentClient( path=str(CHROMA_DIR), settings=Settings(anonymized_telemetry=False), ) # We don't need the embedder for fetch-by-id; pass embedding_function=None self._col = client.get_collection(COLLECTION) return self._col def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: if not self.rerank_url: # Fail open to base retriever — useful in eval to compare base vs # base+rerank when the reranker is offline. log.warning("RERANK_URL unset; falling back to base retriever") return self.base.retrieve(query, k=k) pages = self.base.retrieve(query, k=self.pool) if not pages: return [] # Fetch one representative chunk per page (the first chunk, ordinal=0 # if it exists, else any). For eval simplicity we approximate by # fetching by metadata where (source, source_key) and taking the # first hit. col = self._collection() docs: list[str] = [] kept_pages: list[tuple[str, str]] = [] for source, source_key in pages: where = {"$and": [{"source": source}, {"source_key": source_key}]} got = col.get(where=where, limit=1, include=["documents"]) d = (got.get("documents") or [None])[0] if not d: continue # Truncate to keep under the reranker's per-pair context limit docs.append(d[:2000]) kept_pages.append((source, source_key)) if not docs: return [] import httpx try: r = httpx.post( f"{self.rerank_url}/v1/rerank", json={"query": query, "documents": docs}, timeout=self.timeout, ) r.raise_for_status() data = r.json() except Exception as exc: # noqa: BLE001 log.warning("rerank failed (%s) — falling back to base order", exc) return kept_pages[:k] # llama.cpp returns {"results": [{"index": i, "relevance_score": s}, ...]} results = data.get("results") or [] scored: list[tuple[float, tuple[str, str]]] = [] for r_item in results: idx = r_item.get("index") score = r_item.get("relevance_score") or r_item.get("score") or 0.0 if isinstance(idx, int) and 0 <= idx < len(kept_pages): scored.append((score, kept_pages[idx])) scored.sort(key=lambda x: -x[0]) return [p for _, p in scored[:k]]