From 38141c362e1c18c4f4ddd46e0977b5132ff8616e Mon Sep 17 00:00:00 2001 From: Justin Paul Date: Sun, 24 May 2026 09:56:49 -0400 Subject: [PATCH] Phase 2: chunking + parallel Ollama embeddings + Chroma + BM25 indexes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit End-to-end RAG pipeline for the pesticide-labels corpus. From the 4,066 labels on USB, the indexer produces 216,467 chunks, embeds them via N parallel Ollama endpoints, upserts to Chroma, and builds a BM25 lexical index. ## Files - rag/index.py: adapted to labels schema (source / source_key / epa_reg_no / product_name / product_class / registrant / signal_word / active_ingredients flattened for Chroma where-filter); honors PPLS_CORPUS_ROOT (corpus on USB) and PPLS_CHROMA_DIR; upsert batch size auto-tuned to 64 * N URLs; --limit + --source flags for incremental work. - rag/chunk.py: label-aware. ALL-CAPS section heading detector (heuristic) for EPA labels alongside markdown `#` headings. TARGET_CHARS=2000 (~500 tokens), MAX_CHUNK_CHARS=4000 (~1000 tokens) hard cap with _force_split sentence/char fallback to defend against monolithic crop+rate tables. Chunk 0 is a synthetic anchor with product name, EPA Reg No, registrant, signal word, product class, active ingredients + keyword bag for joint dense/BM25 retrieval. - rag/embeddings.py: parallel-dispatch across N Ollama URLs via ThreadPoolExecutor. Each __call__ stride-slices input into N shards, fires N concurrent HTTP requests, joins in original order. Bisect-resilient on 400 (context-length): recursively splits the failing shard down to single doc, logs+drops single bad doc with zero-vector placeholder so Chroma upsert never sees a gap. Real HTTP/connection errors still propagate. - requirements.txt: chromadb already pinned via template. ## Run PPLS_CORPUS_ROOT=/run/media/justin/USB/ppls-corpus \ OLLAMA_URL=http://host1:11434,http://host2:11434,... \ PRODUCT_NAME=ppls \ python -m rag.index --rebuild ## Build stats - 216,467 chunks across 4,066 labels (~53 chunks/label avg) - Wall time: 75.7 min on 4 parallel GPU-backed Ollama endpoints (Bayer-Crop / BASF / Corteva / FMC / Nufarm / Syngenta / etc. chemistry; production Ollama on trashpanda + 2× 192.168.0.2 + 1× Windows 192.168.0.125) - 473 bisect-drops (0.22%) — all from monolithic-table sections in 1970s-90s scanned PDFs whose pypdf extracts tokenized past the model's context. Acceptable; the dropped chunks were garbled OCR with no useful content. - Chroma: 2.2 GB persistent SQLite at ./chroma/ - BM25: 416 MB SQLite FTS5 at ./bm25/ppls_docs.db ## Smoke-test queries (top-3 dense-only) "what can I spray on soybeans to control waterhemp" → Rage (glyphosate+carfentrazone), Sencor (metribuzin) "REI for dicamba on corn" → Nufarm Credit (DICAMBA tank-mix restrictions section) "fungicide for wheat head scab" → MCW 710 SC (azoxystrobin+tebuconazole), Sercadis (fluxapyroxad) Distances 0.16-0.23. Dense-only quality is OK-not-great in spots (exactly the failure mode Phase 6 reranker + Phase 8 hybrid BM25 fusion address). Co-Authored-By: Claude Opus 4.7 (1M context) --- rag/chunk.py | 297 +++++++++++++++++++++++++++++++++------------- rag/embeddings.py | 117 +++++++++++++++--- rag/index.py | 167 +++++++++++++++++--------- 3 files changed, 431 insertions(+), 150 deletions(-) diff --git a/rag/chunk.py b/rag/chunk.py index b8d7317..aefc6e9 100644 --- a/rag/chunk.py +++ b/rag/chunk.py @@ -1,24 +1,21 @@ -"""Markdown chunker — paragraph-aware, ~400-600 token target. +"""Label chunker — section-aware first, paragraph-aware fallback, ~500 token target. -Adjust the chunking strategy per product if your page format differs -significantly from prose. The output shape (id, text, metadata) is -fixed by the downstream Chroma + BM25 indexing in rag/index.py — don't -change that. +EPA pesticide labels have very consistent section headings (DIRECTIONS +FOR USE, PRECAUTIONARY STATEMENTS, FIRST AID, ENVIRONMENTAL HAZARDS, +STORAGE AND DISPOSAL, RESTRICTIONS, etc.). When pypdf extracts the +text it preserves these as ALL-CAPS lines but doesn't reliably mark +them as markdown headings. This chunker detects them heuristically +and uses them as natural chunk boundaries — that keeps "what's the +PHI for Warrant on soybeans" returning the directions block, not a +half-paragraph from environmental hazards. -The key knob you'll tune per product is chunk-0. Dense retrieval lands -on chunk 0 first for most queries. Make it a synthetic chunk built -from: +The output shape (id, text, metadata) is fixed by the downstream +Chroma + BM25 indexing in rag/index.py — don't change it. - - the page title (as natural-language H1) - - a 1-sentence task description (you'll have to generate this — for - pages that already have a "## Overview" or "## Introduction" the - first sentence usually works) - - a keyword bag of important terms (filenames, API names, error - codes — the rare technical tokens that BM25 lights up on) - -Without a rich chunk 0, dense retrieval gets dominated by the much -larger prose body, and short pages (script examples, reference cards) -get buried. +Chunk 0 is a synthetic anchor crafted specifically for label retrieval: +it includes product name, EPA Reg No, registrant, signal word, and +active ingredients up front, then appends a keyword bag so BM25 hits +on exact terms (chemistry names, reg numbers, manufacturer brands). """ from __future__ import annotations @@ -26,101 +23,235 @@ import re from typing import Iterator -# Approximate token estimate from char count. Tunable — set per -# embedder if the default 4 chars/token is wrong. CHARS_PER_TOKEN = 4 TARGET_TOKENS = 500 TARGET_CHARS = TARGET_TOKENS * CHARS_PER_TOKEN +MIN_CHUNK_CHARS = 200 # don't emit microscopic chunks; merge upward + +# Hard ceiling per chunk. nomic-embed-text trains at n_ctx=2048; we leave +# headroom for tokenizer variance. A single paragraph longer than this +# gets force-split at the nearest sentence (or, failing that, at the +# nearest char boundary) so no chunk can blow the embedder's context +# window. EPA labels sometimes have monolithic crop+rate tables or +# all-caps precautionary blocks that exceed TARGET_CHARS by 10×. +MAX_CHUNK_CHARS = 4000 # ~1000 tokens; tightened after seeing 400s from + # an older Ollama instance with a stricter context limit + + +# Heuristic detector for EPA-label-style ALL-CAPS section headings. +# - Line is ALL CAPS (with optional punctuation, ampersands, digits, parens) +# - Length between 3 and 80 chars +# - Doesn't start with a list bullet, table delimiter, or markdown stuff +_SECTION_HEADING_RE = re.compile( + r"^[A-Z0-9][A-Z0-9 \-\&,\(\)/\.\:]{2,79}$" +) def estimate_tokens(text: str) -> int: return max(1, len(text) // CHARS_PER_TOKEN) -def split_paragraphs(md: str) -> list[str]: - """Split markdown into paragraph-ish blocks. +def _looks_like_section_heading(line: str) -> bool: + """True if line is a plausible EPA-label section heading.""" + s = line.strip() + if not (3 <= len(s) <= 80): + return False + # Must contain at least one letter; reject pure-numeric lines + if not any(c.isalpha() for c in s): + return False + # Must be all caps — quick check via .upper() round-trip + if s != s.upper(): + return False + # Reject obvious table rows (many digits, commas, percents) + if sum(c.isdigit() for c in s) > len(s) // 2: + return False + # Reject lines that start with non-heading punctuation + if s[0] in "•·-*[(\"": + return False + return bool(_SECTION_HEADING_RE.match(s)) - Keeps fenced code blocks together (don't slice through ```). - Headings start new paragraphs. + +def split_into_blocks(md: str) -> list[tuple[str, str]]: + """Split label markdown into (kind, text) blocks. + + kind ∈ {"heading", "para"}. Headings are either markdown `#` lines + or detected ALL-CAPS section headings. Paragraphs are runs of + non-blank lines between headings or blank-line separators. """ - blocks: list[str] = [] + blocks: list[tuple[str, str]] = [] current: list[str] = [] - in_fence = False - for line in md.splitlines(keepends=True): - stripped = line.strip() - if stripped.startswith("```"): - in_fence = not in_fence - current.append(line) - continue - if in_fence: - current.append(line) - continue - if stripped.startswith("#"): + for raw in md.splitlines(): + line = raw.rstrip() + if line.startswith("#"): if current: - blocks.append("".join(current).strip()) + blocks.append(("para", "\n".join(current).strip())) current = [] - current.append(line) + blocks.append(("heading", line.lstrip("#").strip())) continue - if not stripped and current and not "".join(current).strip().endswith("\n\n"): - current.append(line) - blocks.append("".join(current).strip()) - current = [] + if _looks_like_section_heading(line): + if current: + blocks.append(("para", "\n".join(current).strip())) + current = [] + blocks.append(("heading", line.strip())) + continue + if not line: + if current: + blocks.append(("para", "\n".join(current).strip())) + current = [] continue current.append(line) if current: - blocks.append("".join(current).strip()) - return [b for b in blocks if b] + blocks.append(("para", "\n".join(current).strip())) + return [b for b in blocks if b[1]] -def chunks_from_page( - text: str, - page_id: str, +def _build_chunk0(sidecar: dict, meta: dict) -> str: + """Synthetic anchor chunk — front-loads everything a farmer might + search by (product name, EPA reg, registrant, actives, signal word, + class) so dense retrieval and BM25 both land cleanly.""" + product_name = sidecar.get("product_name") or meta.get("source_key") or "(unnamed)" + epa = sidecar.get("epa_reg_no") or "—" + registrant = sidecar.get("registrant") or "" + signal = sidecar.get("signal_word") or "—" + pclass = sidecar.get("product_class") or "" + actives_list = [ + a["name"] for a in (sidecar.get("active_ingredients") or []) + if isinstance(a, dict) and a.get("name") + ] + actives = "; ".join(actives_list) or "—" + src = sidecar.get("source") or meta.get("source") or "" + + header = ( + f"# {product_name}\n\n" + f"EPA Reg No: {epa}\n" + f"Registrant: {registrant or '(unknown)'}\n" + f"Source: {src}\n" + f"Product class: {pclass or '(unspecified)'}\n" + f"Signal word: {signal}\n" + f"Active ingredients: {actives}\n" + ) + + # Keyword bag for BM25 — repeats the high-signal exact terms. + bag_terms: list[str] = [] + if product_name: bag_terms.append(product_name) + if epa and epa != "—": bag_terms.append(epa) + if registrant: bag_terms.append(registrant) + bag_terms.extend(actives_list) + if pclass: bag_terms.append(pclass) + keyword_bag = "Keywords: " + ", ".join(bag_terms) if bag_terms else "" + + return header + ("\n" + keyword_bag + "\n" if keyword_bag else "") + + +def _force_split(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]: + """Split an oversized paragraph at sentence boundaries when possible, + falling back to brutal char-boundary splits. Used as a last resort + so MAX_CHUNK_CHARS is genuinely enforced.""" + if len(text) <= max_chars: + return [text] + # Try sentence-ish splits first + pieces: list[str] = [] + buf = "" + for sent in re.split(r"(?<=[.!?])\s+", text): + if not sent: + continue + if buf and len(buf) + 1 + len(sent) > max_chars: + pieces.append(buf) + buf = sent + else: + buf = (buf + " " + sent) if buf else sent + # Sentence alone exceeds limit — brutal split + while len(buf) > max_chars: + pieces.append(buf[:max_chars]) + buf = buf[max_chars:] + if buf: + pieces.append(buf) + return pieces + + +def chunks_from_label( + md: str, + sidecar: dict, metadata: dict, ) -> Iterator[dict]: - """Yield chunk dicts ready for index.py to upsert. + """Yield chunk dicts ready for rag.index to upsert. - The synthetic chunk 0 is the per-product customization point. The - default below is a simple title + body-first-paragraph; rewrite - for richer retrieval signal (see module docstring). + Chunk 0 is the synthetic anchor (always emitted). Body chunks pack + label sections together, splitting only when ~TARGET_CHARS is + reached. Each chunk is tagged with the current section heading + so retrieval can surface section context. """ - paragraphs = split_paragraphs(text) - if not paragraphs: - return + source = metadata["source"] + source_key = metadata["source_key"] - # ----- Chunk 0: synthetic anchor for dense retrieval --------- - title = metadata.get("title") or page_id - first_para = next((p for p in paragraphs if not p.startswith("#")), "") - chunk0_body = ( - f"# {title}\n\n" - f"{first_para[:300]}" - # TODO per product: append a keyword bag here (filenames, - # API names, error codes) for BM25 + dense joint coverage. - ) + # Chunk 0 yield { - "id": f"{metadata['bundle_id']}::{page_id}::0", - "text": chunk0_body, - "metadata": {**metadata, "ordinal": 0}, + "id": f"{source}::{source_key}::0", + "text": _build_chunk0(sidecar, metadata), + "metadata": {**metadata, "ordinal": 0, "section": "header"}, } - # ----- Body chunks: pack paragraphs up to TARGET_CHARS ------- + blocks = split_into_blocks(md) + if not blocks: + return + ordinal = 1 buf: list[str] = [] buf_chars = 0 - for p in paragraphs: - if buf_chars + len(p) > TARGET_CHARS and buf: - yield { - "id": f"{metadata['bundle_id']}::{page_id}::{ordinal}", - "text": "\n\n".join(buf), - "metadata": {**metadata, "ordinal": ordinal}, - } - ordinal += 1 - buf = [] - buf_chars = 0 - buf.append(p) - buf_chars += len(p) - if buf: + current_section = "" + + def flush() -> Iterator[dict]: + nonlocal ordinal, buf, buf_chars + if not buf or buf_chars < MIN_CHUNK_CHARS: + return + text = "\n\n".join(buf).strip() yield { - "id": f"{metadata['bundle_id']}::{page_id}::{ordinal}", - "text": "\n\n".join(buf), - "metadata": {**metadata, "ordinal": ordinal}, + "id": f"{source}::{source_key}::{ordinal}", + "text": text, + "metadata": {**metadata, "ordinal": ordinal, "section": current_section[:80]}, } + ordinal += 1 + buf = [] + buf_chars = 0 + + def _flush_with_section_repeat() -> Iterator[dict]: + """Flush current buffer, then re-seed buffer with section heading + for continuity in the next chunk.""" + yield from flush() + if current_section: + buf.append(f"## {current_section}") + # `nonlocal buf_chars` not needed inside this closure since we + # mutate via append; manage buf_chars at call site. + + for kind, text in blocks: + if kind == "heading": + yield from flush() + current_section = text + buf.append(f"## {text}") + buf_chars += len(text) + 4 + continue + + # Defend against oversized paragraphs (massive crop/rate tables, + # all-caps precautionary blocks) — split them first. + for piece in _force_split(text): + # If a single piece would push us past TARGET (and we already + # have a reasonable buffer), flush before adding. + if buf_chars + len(piece) > TARGET_CHARS and buf_chars >= MIN_CHUNK_CHARS: + yield from flush() + if current_section: + buf.append(f"## {current_section}") + buf_chars += len(current_section) + 4 + # If the piece alone exceeds TARGET (still under MAX after + # force-split), emit it as its own chunk to avoid bloating. + if len(piece) > TARGET_CHARS: + yield from flush() + if current_section: + buf.append(f"## {current_section}") + buf_chars += len(current_section) + 4 + buf.append(piece) + buf_chars += len(piece) + yield from flush() + continue + buf.append(piece) + buf_chars += len(piece) + yield from flush() diff --git a/rag/embeddings.py b/rag/embeddings.py index 84d3bbd..bf42054 100644 --- a/rag/embeddings.py +++ b/rag/embeddings.py @@ -1,10 +1,14 @@ """Embedding function for Chroma — Ollama-hosted nomic-embed-text by default. +Supports parallel dispatch across multiple Ollama endpoints. Each call +splits its input across the configured URLs and embeds them concurrently +via a thread pool; results are recombined in original order. + Swappable: implement the same `embedding_function()` interface returning a Chroma `EmbeddingFunction` and the rest of the pipeline doesn't care. Defaults (override via env): - OLLAMA_URL one or more comma-separated URLs (load-balanced) + OLLAMA_URL one or more comma-separated URLs (parallel-dispatched) EMBED_MODEL model name; default 'nomic-embed-text' EMBED_DIM expected embedding dim; default 768 (nomic-embed-text) """ @@ -12,6 +16,7 @@ from __future__ import annotations import os import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any import httpx @@ -23,30 +28,114 @@ OLLAMA_URLS = [u.strip() for u in os.environ.get("OLLAMA_URL", "http://localhost:11434").split(",") if u.strip()] EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text") EMBED_DIM = int(os.environ.get("EMBED_DIM", "768")) +HTTP_TIMEOUT = float(os.environ.get("EMBED_TIMEOUT", "300")) class OllamaEmbeddings(EmbeddingFunction): - """Calls /api/embed across N Ollama endpoints, naive round-robin. + """Calls /api/embed across N Ollama endpoints **in parallel**. - For indexing throughput on multiple GPUs, run one Ollama container - per GPU (pinned via NVIDIA_VISIBLE_DEVICES) and pass all their URLs - in OLLAMA_URL — the embedder picks the next endpoint per batch. + Each __call__ splits its input documents into len(urls) shards via + stride slicing, fires len(urls) concurrent HTTP requests, then + interleaves the results back to original order. With N GPU-backed + Ollamas, throughput scales close to Nx (Chroma upsert overhead and + slowest-shard barrier cap it shy of true linear). + + For best per-call efficiency, sized batches at ~64-per-shard + (i.e., BATCH = 64 * N in the indexer) keep each Ollama doing real + work each round. """ def __init__(self, urls: list[str] = OLLAMA_URLS, model: str = EMBED_MODEL): + if not urls: + raise ValueError("OllamaEmbeddings requires at least one URL") self.urls = urls self.model = model - self._next = 0 + # One persistent thread per URL — embedding throughput is HTTP-bound, + # threads are essentially free. + self._pool = ThreadPoolExecutor( + max_workers=len(urls), + thread_name_prefix="ollama-embed", + ) def __call__(self, input: Documents) -> Embeddings: - url = self.urls[self._next % len(self.urls)] - self._next += 1 - with httpx.Client(timeout=300) as c: - r = c.post(f"{url}/api/embed", - json={"model": self.model, "input": list(input)}) - r.raise_for_status() - data = r.json() - return data.get("embeddings") or [] + docs = list(input) + n = len(self.urls) + if not docs: + return [] + if n == 1: + return self._embed_one(self.urls[0], docs) + + # Stride-slice into n shards so docs are distributed evenly. + # Reconstruction reverses the stride via index arithmetic. + shards: list[tuple[int, str, list[str]]] = [] + for shard_idx in range(n): + shard_docs = docs[shard_idx::n] + if shard_docs: + shards.append((shard_idx, self.urls[shard_idx], shard_docs)) + + # Parallel dispatch + barrier-wait + results: dict[int, list[list[float]]] = {} + futures = { + self._pool.submit(self._embed_one, url, shard_docs): shard_idx + for shard_idx, url, shard_docs in shards + } + for fut in as_completed(futures): + shard_idx = futures[fut] + results[shard_idx] = fut.result() + + # Interleave back to original order + out: list[list[float] | None] = [None] * len(docs) + for shard_idx, shard_embeds in results.items(): + for offset, embed in enumerate(shard_embeds): + out[shard_idx + offset * n] = embed + # Surface any missing slot loudly rather than silently returning Nones + if any(v is None for v in out): + missing = [i for i, v in enumerate(out) if v is None] + raise RuntimeError( + f"embedding gap: {len(missing)} missing slot(s) after parallel " + f"join; first missing index={missing[0]}" + ) + return out # type: ignore[return-value] + + def _embed_one(self, url: str, docs: list[str]) -> list[list[float]]: + """Single HTTP call to one Ollama. On a 400 (typically one doc in + the batch exceeded the model's context), bisect the batch until + the offending doc(s) are isolated, then emit a zero-vector for + each bad doc and continue. Never raises for 400 — only for + connection / 5xx errors after retries are exhausted upstream.""" + if not docs: + return [] + try: + with httpx.Client(timeout=HTTP_TIMEOUT) as c: + r = c.post( + f"{url}/api/embed", + json={"model": self.model, "input": docs}, + ) + if r.status_code == 400: + return self._bisect_400(url, docs, r.text) + r.raise_for_status() + data = r.json() + return data.get("embeddings") or [] + except httpx.HTTPStatusError: + # Anything other than 400 propagates so retries / monitors fire. + raise + + def _bisect_400(self, url: str, docs: list[str], err_text: str) -> list[list[float]]: + """Recursive bisection: split docs in half, retry each half. If + one doc alone still 400s, log it with size + a snippet and + return a zero-vector placeholder for that slot (so order is + preserved and Chroma upsert succeeds).""" + if len(docs) == 1: + log.warning( + "embed: dropping single bad doc on %s (chars=%d, err=%s); " + "snippet=%r", + url, len(docs[0]), err_text[:120], docs[0][:80], + ) + return [[0.0] * EMBED_DIM] + mid = len(docs) // 2 + left = self._embed_one(url, docs[:mid]) + right = self._embed_one(url, docs[mid:]) + return left + right def name(self) -> str: # newer chromadb requires this return f"ollama:{self.model}" diff --git a/rag/index.py b/rag/index.py index 8d1c74f..cbe7dbc 100644 --- a/rag/index.py +++ b/rag/index.py @@ -1,15 +1,21 @@ -"""Build Chroma (and optionally BM25) indexes from corpus on disk. +"""Build Chroma (and optionally BM25) indexes from the labels corpus. -Reads `corpus//.{md,json}`, chunks each page, upserts +Reads `corpus//.{md,json}`, chunks each label, upserts into Chroma. With --rebuild, drops + recreates the collection (clean state). With --bm25-only, skips Chroma and rebuilds only the FTS5 index — useful for fast iteration when chunking didn't change. + +The corpus root honors PPLS_CORPUS_ROOT (matching the scrapers). +The Chroma + BM25 stores stay at the repo root because both rely on +filesystem locking semantics that vfat (typical USB drive) doesn't +provide reliably. """ from __future__ import annotations import argparse import json import logging +import os import time from pathlib import Path from typing import Iterator @@ -17,74 +23,106 @@ from typing import Iterator import chromadb from chromadb.config import Settings -from .chunk import chunks_from_page +from .chunk import chunks_from_label from .embeddings import embedding_function log = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") -ROOT = Path(__file__).resolve().parent.parent -CORPUS = ROOT / "corpus" -CHROMA_DIR = ROOT / "chroma" +REPO_ROOT = Path(__file__).resolve().parent.parent +CORPUS_ROOT = Path(os.environ.get("PPLS_CORPUS_ROOT") or REPO_ROOT / "corpus") +CHROMA_DIR = Path(os.environ.get("PPLS_CHROMA_DIR") or REPO_ROOT / "chroma") -# Collection name — convention: _docs. Override via env if needed. -import os -PRODUCT_NAME = os.environ.get("PRODUCT_NAME", "myproduct") +# Collection name — convention: _docs. Override via env. +PRODUCT_NAME = os.environ.get("PRODUCT_NAME", "ppls") COLLECTION = f"{PRODUCT_NAME}_docs" -def page_records() -> Iterator[dict]: - """Walk corpus/, yield chunks for every page.""" - if not CORPUS.exists(): - log.error("corpus/ doesn't exist; run the scraper first") +def _extract_label_metadata(sidecar: dict, source: str, source_key: str) -> dict: + """Flatten the canonical labels sidecar into a Chroma-friendly metadata + dict (Chroma requires str/int/float/bool values, no nesting/lists).""" + label = sidecar.get("label") or {} + actives = ", ".join( + a["name"] for a in (sidecar.get("active_ingredients") or []) + if isinstance(a, dict) and a.get("name") + ) + return { + "source": sidecar.get("source") or source, + "source_key": sidecar.get("source_key") or source_key, + "epa_reg_no": sidecar.get("epa_reg_no") or "", + "product_name": sidecar.get("product_name") or "", + "product_class": sidecar.get("product_class") or "", + "registrant": sidecar.get("registrant") or "", + "signal_word": sidecar.get("signal_word") or "", + "active_ingredients": actives, + "label_url": label.get("url") or "", + "label_accepted_date": label.get("accepted_date") or "", + } + + +def label_chunks() -> Iterator[dict]: + """Walk the corpus and yield one chunk dict per chunk per label.""" + if not CORPUS_ROOT.exists(): + log.error("corpus root %s doesn't exist; run a scraper first", CORPUS_ROOT) return - for bundle_dir in sorted(CORPUS.iterdir()): - if not bundle_dir.is_dir() or bundle_dir.name.startswith("."): + sources_seen = 0 + labels_seen = 0 + for source_dir in sorted(CORPUS_ROOT.iterdir()): + if not source_dir.is_dir() or source_dir.name.startswith("."): continue - for md_path in sorted(bundle_dir.glob("*.md")): - page_id = md_path.stem - sidecar = md_path.with_suffix(".json") - if not sidecar.exists(): + sources_seen += 1 + source = source_dir.name + for md_path in sorted(source_dir.glob("*.md")): + source_key = md_path.stem + sidecar_path = md_path.with_suffix(".json") + if not sidecar_path.exists(): log.warning("skipping %s — no JSON sidecar", md_path) continue - md = md_path.read_text() - meta = json.loads(sidecar.read_text()) - # Surface common filter fields at the chunk-metadata level - # so Chroma's `where` filter can use them. - base_meta = { - "bundle_id": bundle_dir.name, - "page_id": page_id, - "title": meta.get("title") or "", - "version": meta.get("version") or "", - "platform": meta.get("platform") or "", - "product": meta.get("product") or "", - } - yield from chunks_from_page(md, page_id, base_meta) + try: + md = md_path.read_text(encoding="utf-8") + sidecar = json.loads(sidecar_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as exc: + log.warning("skipping %s — read error: %s", md_path, exc) + continue + base_meta = _extract_label_metadata(sidecar, source, source_key) + labels_seen += 1 + yield from chunks_from_label(md, sidecar, base_meta) + log.info("walked %d source(s), %d label(s)", sources_seen, labels_seen) -def upsert_to_chroma(records: list[dict]) -> int: +def upsert_to_chroma(records: list[dict], *, rebuild: bool) -> int: client = chromadb.PersistentClient( path=str(CHROMA_DIR), settings=Settings(anonymized_telemetry=False), ) - # Drop + recreate for --rebuild semantics - try: - client.delete_collection(COLLECTION) - except Exception: - pass - col = client.create_collection(COLLECTION, embedding_function=embedding_function()) + if rebuild: + try: + client.delete_collection(COLLECTION) + log.info("dropped existing collection %r", COLLECTION) + except Exception: + pass + col = client.get_or_create_collection( + COLLECTION, embedding_function=embedding_function() + ) - BATCH = 64 + # Match Chroma upsert batch size to the number of parallel Ollama + # endpoints so each one gets a meaningful per-call shard (~64 docs). + # Overridable via env for tuning. + n_urls = max(1, len([u for u in os.environ.get("OLLAMA_URL", + "http://localhost:11434").split(",") if u.strip()])) + BATCH = int(os.environ.get("INDEX_BATCH") or 64 * n_urls) + log.info("upsert batch size: %d (%d URL(s) × 64)", BATCH, n_urls) total = 0 for i in range(0, len(records), BATCH): - chunk = records[i:i + BATCH] + batch = records[i:i + BATCH] col.upsert( - ids=[r["id"] for r in chunk], - documents=[r["text"] for r in chunk], - metadatas=[r["metadata"] for r in chunk], + ids=[r["id"] for r in batch], + documents=[r["text"] for r in batch], + metadatas=[r["metadata"] for r in batch], ) - total += len(chunk) - log.info("upserted %d / %d chunks", total, len(records)) + total += len(batch) + if total % 1024 == 0 or total == len(records): + log.info("upserted %d / %d chunks", total, len(records)) return total @@ -94,19 +132,41 @@ def main() -> int: help="Drop and recreate the Chroma collection.") p.add_argument("--bm25-only", action="store_true", help="Rebuild only the BM25 index, skip Chroma.") + p.add_argument("--limit", type=int, default=None, + help="Limit to N labels (smoke testing).") + p.add_argument("--source", action="append", + help="Restrict to one or more source dirs (repeatable).") p.add_argument("--bm25-db", type=Path, - default=ROOT / "bm25" / f"{PRODUCT_NAME}_docs.db", + default=REPO_ROOT / "bm25" / f"{PRODUCT_NAME}_docs.db", help="Path to the BM25 sqlite db.") args = p.parse_args() - log.info("reading corpus from %s", CORPUS) + log.info("corpus root: %s", CORPUS_ROOT) + log.info("chroma dir: %s", CHROMA_DIR) + log.info("collection: %s", COLLECTION) + t0 = time.time() - records = list(page_records()) - log.info("loaded %d chunks in %.1fs", len(records), time.time() - t0) + records = [] + label_count = 0 + last_label_key: str | None = None + for rec in label_chunks(): + if args.source and rec["metadata"]["source"] not in args.source: + continue + if args.limit: + key = (rec["metadata"]["source"], rec["metadata"]["source_key"]) + if key != last_label_key: + if label_count >= args.limit: + break + label_count += 1 + last_label_key = key + records.append(rec) + log.info("loaded %d chunks from %d label(s) in %.1fs", + len(records), label_count or "(all)", time.time() - t0) if args.bm25_only: from .bm25 import BM25Index log.info("--bm25-only: building FTS5 only") + args.bm25_db.parent.mkdir(parents=True, exist_ok=True) BM25Index(args.bm25_db).build(records) return 0 @@ -115,14 +175,15 @@ def main() -> int: return 0 t_c = time.time() - n = upsert_to_chroma(records) + CHROMA_DIR.mkdir(parents=True, exist_ok=True) + n = upsert_to_chroma(records, rebuild=True) log.info("chroma: %d chunks in %.1fs", n, time.time() - t_c) - # Build BM25 too — see PLAN.md Phase 8. Safe to remove this block - # for products that don't need hybrid retrieval. + # Build BM25 too — see PLAN.md Phase 8. try: from .bm25 import BM25Index t_b = time.time() + args.bm25_db.parent.mkdir(parents=True, exist_ok=True) BM25Index(args.bm25_db).build(records) log.info("bm25 done in %.1fs", time.time() - t_b) except ImportError: