"""Retriever protocol + concrete implementations. A single matrix dimension per knob (dense / reranked / bm25 / hybrid) so the eval harness can compare them apples-to-apples. Implement these once at Phase 7 and reuse them across every retrieval change. Each retriever returns a ranked list of (bundle_id, page_id) tuples deduplicated to the page level (chunks within the same page collapse to one entry; the highest-ranked chunk's position wins). """ from __future__ import annotations from typing import Iterable, Protocol class Retriever(Protocol): name: str def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: """Return up to k (bundle_id, page_id) tuples in rank order.""" ... def _split_chunk_id(chunk_id: str) -> tuple[str, str, int]: """`bundle::page::ordinal` -> (bundle, page, int(ordinal)).""" bid, pid, ordinal = chunk_id.split("::") return bid, pid, int(ordinal) def _collapse_to_pages(chunk_ids: Iterable[str], k: int) -> list[tuple[str, str]]: seen: set[tuple[str, str]] = set() out: list[tuple[str, str]] = [] for cid in chunk_ids: bid, pid, _ord = _split_chunk_id(cid) key = (bid, pid) if key in seen: continue seen.add(key) out.append(key) if len(out) >= k: break return out class DenseRetriever: """Chroma cosine search via the live embedding function.""" name = "dense" def __init__(self, collection, pool: int = 50): self.col = collection self.pool = pool def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: res = self.col.query(query_texts=[query], n_results=self.pool) ids = (res.get("ids") or [[]])[0] return _collapse_to_pages(ids, k) class BM25Retriever: """SQLite FTS5 lexical search.""" name = "bm25" def __init__(self, bm25_index, pool: int = 200): self.bm = bm25_index self.pool = pool def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: hits = self.bm.query(query, n=self.pool) return _collapse_to_pages((cid for cid, _score in hits), k) class HybridRetriever: """Reciprocal Rank Fusion of dense + BM25 rankings.""" name = "hybrid_rrf" def __init__(self, dense: DenseRetriever, bm25: BM25Retriever, k_rrf: int = 60, pool: int = 100): self.dense = dense self.bm25 = bm25 self.k_rrf = k_rrf self.pool = pool def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: dense_pages = self.dense.retrieve(query, k=self.pool) bm25_pages = self.bm25.retrieve(query, k=self.pool) scores: dict[tuple[str, str], float] = {} for rank, page in enumerate(dense_pages, start=1): scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank) for rank, page in enumerate(bm25_pages, start=1): scores[page] = scores.get(page, 0.0) + 1.0 / (self.k_rrf + rank) ranked = sorted(scores.items(), key=lambda kv: -kv[1]) return [page for page, _s in ranked[:k]] def _rerank_pool(rerank_url: str, query: str, ids_and_texts: list[tuple[str, str]], timeout: float = 30.0) -> list[str] | None: """POST to /v1/rerank, return ids in reranked order. None on failure.""" if not ids_and_texts: return [] import httpx try: with httpx.Client(timeout=timeout) as c: r = c.post(f"{rerank_url}/v1/rerank", json={ "query": query, "documents": [(t or "")[:2000] for _i, t in ids_and_texts], "top_n": len(ids_and_texts), }) r.raise_for_status() results = r.json().get("results") or [] return [ids_and_texts[item["index"]][0] for item in results if isinstance(item.get("index"), int) and 0 <= item["index"] < len(ids_and_texts)] except Exception: return None class RerankedRetriever: """Pull a candidate pool via a base retriever, then cross-encoder re-rank.""" def __init__(self, base: Retriever, collection, rerank_url: str, name_suffix: str = "rerank", pool: int = 50, timeout: float = 30.0): self.base = base self.col = collection self.url = rerank_url self.name = f"{base.name}+{name_suffix}" self.pool = pool self.timeout = timeout def retrieve(self, query: str, k: int = 10) -> list[tuple[str, str]]: # Base returns deduplicated page-level tuples; rerank needs CHUNK-level # texts to be informative. Pull each page's chunk 0 text from Chroma. pages = self.base.retrieve(query, k=self.pool) if not pages: return [] chunk_ids = [f"{bid}::{pid}::0" for bid, pid in pages] g = self.col.get(ids=chunk_ids, include=["documents"]) by_id = dict(zip(g["ids"], g["documents"])) ids_and_texts = [(cid, by_id.get(cid, "")) for cid in chunk_ids] order = _rerank_pool(self.url, query, ids_and_texts, timeout=self.timeout) if order is None: return pages[:k] out: list[tuple[str, str]] = [] seen: set[tuple[str, str]] = set() for cid in order: bid, pid, _ = cid.split("::") key = (bid, pid) if key in seen: continue seen.add(key) out.append(key) if len(out) >= k: break return out