"""Build Chroma (and optionally BM25) indexes from the labels corpus.

Reads `corpus/<source>/<source_key>.{md,json}` (one markdown file plus a JSON
sidecar per label), chunks each label, upserts into Chroma.

With --rebuild, drops + recreates the collection (clean state).
With --bm25-only, skips Chroma and rebuilds only the FTS5 index — useful for
fast iteration when chunking didn't change.

The corpus root honors CORPUS_ROOT (matching the scrapers). The Chroma + BM25
stores stay at the repo root because both rely on filesystem locking semantics
that vfat (typical USB drive) doesn't provide reliably.
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import time
from pathlib import Path
from typing import Iterator

import chromadb
from chromadb.config import Settings

from .chunk import chunks_from_label
from .embeddings import embedding_function

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")

# Two levels up from this file.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Corpus location: CORPUS_ROOT env var if set and non-empty, else <repo>/corpus.
CORPUS_ROOT = Path(os.environ.get("CORPUS_ROOT") or REPO_ROOT / "corpus")
# Chroma persistence dir: CHROMA_DIR_OVERRIDE env var if set, else <repo>/chroma.
CHROMA_DIR = Path(os.environ.get("CHROMA_DIR_OVERRIDE") or REPO_ROOT / "chroma")

# Collection name — convention: <PRODUCT_NAME>_docs. Override via env.
PRODUCT_NAME = os.environ.get("PRODUCT_NAME", "crop_chem")
COLLECTION = f"{PRODUCT_NAME}_docs"


def _extract_label_metadata(sidecar: dict, source: str, source_key: str) -> dict:
    """Flatten the canonical labels sidecar into a Chroma-friendly metadata dict.

    Chroma requires str/int/float/bool values (no nesting/lists), so the
    nested ``label`` object is flattened into ``label_*`` keys and the
    ``active_ingredients`` list is joined into one comma-separated string.

    Args:
        sidecar: Parsed JSON sidecar for one label.
        source: Fallback source name (the corpus sub-directory name), used
            when the sidecar has no truthy ``source``.
        source_key: Fallback label key (the markdown file stem), used when
            the sidecar has no truthy ``source_key``.

    Returns:
        A flat dict of metadata; missing or falsy fields become ``""``.
    """
    # A malformed sidecar may carry a non-object "label" or a non-list
    # "active_ingredients"; treat those as absent instead of crashing.
    label = sidecar.get("label")
    if not isinstance(label, dict):
        label = {}
    raw_actives = sidecar.get("active_ingredients")
    if not isinstance(raw_actives, (list, tuple)):
        raw_actives = []
    actives = ", ".join(
        str(a["name"])  # str() so a non-string name cannot break join()
        for a in raw_actives
        if isinstance(a, dict) and a.get("name")
    )
    return {
        "source": sidecar.get("source") or source,
        "source_key": sidecar.get("source_key") or source_key,
        "epa_reg_no": sidecar.get("epa_reg_no") or "",
        "product_name": sidecar.get("product_name") or "",
        "product_class": sidecar.get("product_class") or "",
        "registrant": sidecar.get("registrant") or "",
        "signal_word": sidecar.get("signal_word") or "",
        "active_ingredients": actives,
        "label_url": label.get("url") or "",
        "label_accepted_date": label.get("accepted_date") or "",
    }


def label_chunks() -> Iterator[dict]:
    """Walk the corpus and yield one chunk dict per chunk per label.

    Visits every non-hidden sub-directory of ``CORPUS_ROOT`` in sorted order,
    and within each one every ``*.md`` file (sorted) that has a sibling
    ``.json`` sidecar. Labels that cannot be read, decoded, or parsed are
    skipped with a warning so one bad file does not abort the whole build.
    Yields nothing (after logging an error) when the corpus root is missing.
    """
    if not CORPUS_ROOT.exists():
        log.error("corpus root %s doesn't exist; run a scraper first", CORPUS_ROOT)
        return
    sources_seen = 0
    labels_seen = 0
    for source_dir in sorted(CORPUS_ROOT.iterdir()):
        if not source_dir.is_dir() or source_dir.name.startswith("."):
            continue
        sources_seen += 1
        source = source_dir.name
        for md_path in sorted(source_dir.glob("*.md")):
            source_key = md_path.stem
            sidecar_path = md_path.with_suffix(".json")
            if not sidecar_path.exists():
                log.warning("skipping %s — no JSON sidecar", md_path)
                continue
            try:
                md = md_path.read_text(encoding="utf-8")
                sidecar = json.loads(sidecar_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers both json.JSONDecodeError and
                # UnicodeDecodeError (a file that is not valid UTF-8).
                log.warning("skipping %s — read error: %s", md_path, exc)
                continue
            if not isinstance(sidecar, dict):
                # Valid JSON but not an object (e.g. a list or a string).
                log.warning("skipping %s — sidecar is not a JSON object", md_path)
                continue
            base_meta = _extract_label_metadata(sidecar, source, source_key)
            labels_seen += 1
            yield from chunks_from_label(md, sidecar, base_meta)
    log.info("walked %d source(s), %d label(s)", sources_seen, labels_seen)


def upsert_to_chroma(records: list[dict], *, rebuild: bool) -> int:
    """Upsert chunk records into the persistent Chroma collection.

    Args:
        records: Chunk dicts; each must provide ``"id"``, ``"text"`` and
            ``"metadata"`` keys.
        rebuild: When true, try to drop the existing collection first so the
            upsert starts from a clean state.

    Returns:
        The number of records upserted.
    """
    client = chromadb.PersistentClient(
        path=str(CHROMA_DIR),
        settings=Settings(anonymized_telemetry=False),
    )
    if rebuild:
        try:
            client.delete_collection(COLLECTION)
            log.info("dropped existing collection %r", COLLECTION)
        except Exception as exc:
            # Best effort: the usual cause is that the collection does not
            # exist yet. Record the reason instead of hiding it.
            log.info("could not drop collection %r (%s); continuing", COLLECTION, exc)
    col = client.get_or_create_collection(
        COLLECTION, embedding_function=embedding_function()
    )
    # Match Chroma upsert batch size to the number of parallel Ollama
    # endpoints so each one gets a meaningful per-call shard (~64 docs).
    # Overridable via env for tuning.
    ollama_urls = os.environ.get("OLLAMA_URL", "http://localhost:11434").split(",")
    n_urls = max(1, sum(1 for u in ollama_urls if u.strip()))
    default_batch = 64 * n_urls
    raw_batch = os.environ.get("INDEX_BATCH")
    if raw_batch:
        try:
            BATCH = int(raw_batch)
        except ValueError:
            log.warning("ignoring non-integer INDEX_BATCH=%r", raw_batch)
            BATCH = default_batch
        if BATCH < 1:
            # A zero/negative step would make range() raise or never advance.
            log.warning("ignoring non-positive INDEX_BATCH=%r", raw_batch)
            BATCH = default_batch
    else:
        BATCH = default_batch
    log.info("upsert batch size: %d (%d URL(s) × 64)", BATCH, n_urls)

    report_every = 1024
    next_report = report_every
    total = 0
    for i in range(0, len(records), BATCH):
        batch = records[i:i + BATCH]
        col.upsert(
            ids=[r["id"] for r in batch],
            documents=[r["text"] for r in batch],
            metadatas=[r["metadata"] for r in batch],
        )
        total += len(batch)
        # Report each time a 1024-chunk boundary is crossed (not only when it
        # is hit exactly, which never happens for most batch sizes) and once
        # more at the very end.
        if total >= next_report or total == len(records):
            log.info("upserted %d / %d chunks", total, len(records))
            next_report = (total // report_every + 1) * report_every
    return total


def main() -> int:
    """CLI entry point: load chunks from the corpus and build the indexes.

    Modes:
        --bm25-only  build only the BM25 (FTS5) index, skip Chroma.
        --rebuild    drop + recreate the Chroma collection, then build BM25
                     as well when the bm25 module is importable.
        (neither)    load and count chunks, then exit without writing.

    Returns:
        Process exit code (always 0; failures propagate as exceptions).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--rebuild", action="store_true",
                   help="Drop and recreate the Chroma collection.")
    p.add_argument("--bm25-only", action="store_true",
                   help="Rebuild only the BM25 index, skip Chroma.")
    p.add_argument("--limit", type=int, default=None,
                   help="Limit to N labels (smoke testing).")
    p.add_argument("--source", action="append",
                   help="Restrict to one or more source dirs (repeatable).")
    p.add_argument("--bm25-db", type=Path,
                   default=REPO_ROOT / "bm25" / f"{PRODUCT_NAME}_docs.db",
                   help="Path to the BM25 sqlite db.")
    args = p.parse_args()

    log.info("corpus root: %s", CORPUS_ROOT)
    log.info("chroma dir: %s", CHROMA_DIR)
    log.info("collection: %s", COLLECTION)

    t0 = time.time()
    wanted_sources = set(args.source or ())
    limit = args.limit or None  # 0 keeps its historical meaning: no limit
    records = []
    label_count = 0
    last_label_key: tuple[str, str] | None = None
    for rec in label_chunks():
        meta = rec["metadata"]
        if wanted_sources and meta["source"] not in wanted_sources:
            continue
        # Chunks of one label arrive contiguously, so a change of
        # (source, source_key) marks the start of a new label.
        key = (meta["source"], meta["source_key"])
        if key != last_label_key:
            if limit and label_count >= limit:
                break
            label_count += 1
            last_label_key = key
        records.append(rec)
    # Labels are always counted, so the %d placeholder always gets an int.
    log.info("loaded %d chunks from %d label(s) in %.1fs",
             len(records), label_count, time.time() - t0)

    if args.bm25_only:
        from .bm25 import BM25Index
        log.info("--bm25-only: building FTS5 only")
        args.bm25_db.parent.mkdir(parents=True, exist_ok=True)
        BM25Index(args.bm25_db).build(records)
        return 0

    if not args.rebuild:
        log.info("no --rebuild; nothing to do. (Use --rebuild to upsert.)")
        return 0

    t_c = time.time()
    CHROMA_DIR.mkdir(parents=True, exist_ok=True)
    n = upsert_to_chroma(records, rebuild=True)
    log.info("chroma: %d chunks in %.1fs", n, time.time() - t_c)

    # Build BM25 too — see PLAN.md Phase 8.
    # Only the import is guarded: an ImportError raised inside build() is a
    # real failure and must not be reported as "module not available".
    try:
        from .bm25 import BM25Index
    except ImportError:
        log.info("rag.bm25 not available — skipping BM25 build")
        return 0
    t_b = time.time()
    args.bm25_db.parent.mkdir(parents=True, exist_ok=True)
    BM25Index(args.bm25_db).build(records)
    log.info("bm25 done in %.1fs", time.time() - t_b)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())