"""Embedding function for Chroma — Ollama-hosted nomic-embed-text by default. Supports parallel dispatch across multiple Ollama endpoints. Each call splits its input across the configured URLs and embeds them concurrently via a thread pool; results are recombined in original order. Swappable: implement the same `embedding_function()` interface returning a Chroma `EmbeddingFunction` and the rest of the pipeline doesn't care. Defaults (override via env): OLLAMA_URL one or more comma-separated URLs (parallel-dispatched) EMBED_MODEL model name; default 'nomic-embed-text' EMBED_DIM expected embedding dim; default 768 (nomic-embed-text) """ from __future__ import annotations import os import logging from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any import httpx from chromadb import EmbeddingFunction, Documents, Embeddings log = logging.getLogger(__name__) OLLAMA_URLS = [u.strip() for u in os.environ.get("OLLAMA_URL", "http://localhost:11434").split(",") if u.strip()] EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text") EMBED_DIM = int(os.environ.get("EMBED_DIM", "768")) HTTP_TIMEOUT = float(os.environ.get("EMBED_TIMEOUT", "300")) class OllamaEmbeddings(EmbeddingFunction): """Calls /api/embed across N Ollama endpoints **in parallel**. Each __call__ splits its input documents into len(urls) shards via stride slicing, fires len(urls) concurrent HTTP requests, then interleaves the results back to original order. With N GPU-backed Ollamas, throughput scales close to Nx (Chroma upsert overhead and slowest-shard barrier cap it shy of true linear). For best per-call efficiency, sized batches at ~64-per-shard (i.e., BATCH = 64 * N in the indexer) keep each Ollama doing real work each round. """ def __init__(self, urls: list[str] = OLLAMA_URLS, model: str = EMBED_MODEL): if not urls: raise ValueError("OllamaEmbeddings requires at least one URL") self.urls = urls self.model = model # One persistent thread per URL — embedding throughput is HTTP-bound, # threads are essentially free. self._pool = ThreadPoolExecutor( max_workers=len(urls), thread_name_prefix="ollama-embed", ) def __call__(self, input: Documents) -> Embeddings: docs = list(input) n = len(self.urls) if not docs: return [] if n == 1: return self._embed_one(self.urls[0], docs) # Stride-slice into n shards so docs are distributed evenly. # Reconstruction reverses the stride via index arithmetic. shards: list[tuple[int, str, list[str]]] = [] for shard_idx in range(n): shard_docs = docs[shard_idx::n] if shard_docs: shards.append((shard_idx, self.urls[shard_idx], shard_docs)) # Parallel dispatch + barrier-wait results: dict[int, list[list[float]]] = {} futures = { self._pool.submit(self._embed_one, url, shard_docs): shard_idx for shard_idx, url, shard_docs in shards } for fut in as_completed(futures): shard_idx = futures[fut] results[shard_idx] = fut.result() # Interleave back to original order out: list[list[float] | None] = [None] * len(docs) for shard_idx, shard_embeds in results.items(): for offset, embed in enumerate(shard_embeds): out[shard_idx + offset * n] = embed # Surface any missing slot loudly rather than silently returning Nones if any(v is None for v in out): missing = [i for i, v in enumerate(out) if v is None] raise RuntimeError( f"embedding gap: {len(missing)} missing slot(s) after parallel " f"join; first missing index={missing[0]}" ) return out # type: ignore[return-value] def _embed_one(self, url: str, docs: list[str]) -> list[list[float]]: """Single HTTP call to one Ollama. On a 400 (typically one doc in the batch exceeded the model's context), bisect the batch until the offending doc(s) are isolated, then emit a zero-vector for each bad doc and continue. Never raises for 400 — only for connection / 5xx errors after retries are exhausted upstream.""" if not docs: return [] try: with httpx.Client(timeout=HTTP_TIMEOUT) as c: r = c.post( f"{url}/api/embed", json={"model": self.model, "input": docs}, ) if r.status_code == 400: return self._bisect_400(url, docs, r.text) r.raise_for_status() data = r.json() return data.get("embeddings") or [] except httpx.HTTPStatusError: # Anything other than 400 propagates so retries / monitors fire. raise def _bisect_400(self, url: str, docs: list[str], err_text: str) -> list[list[float]]: """Recursive bisection: split docs in half, retry each half. If one doc alone still 400s, log it with size + a snippet and return a zero-vector placeholder for that slot (so order is preserved and Chroma upsert succeeds).""" if len(docs) == 1: log.warning( "embed: dropping single bad doc on %s (chars=%d, err=%s); " "snippet=%r", url, len(docs[0]), err_text[:120], docs[0][:80], ) return [[0.0] * EMBED_DIM] mid = len(docs) // 2 left = self._embed_one(url, docs[:mid]) right = self._embed_one(url, docs[mid:]) return left + right def name(self) -> str: # newer chromadb requires this return f"ollama:{self.model}" @staticmethod def build_from_config(config: dict) -> "OllamaEmbeddings": # newer chromadb return OllamaEmbeddings( urls=config.get("urls", OLLAMA_URLS), model=config.get("model", EMBED_MODEL), ) def get_config(self) -> dict: # newer chromadb return {"urls": self.urls, "model": self.model} def default_space(self) -> str: return "cosine" def supported_spaces(self) -> list[str]: return ["cosine", "l2", "ip"] def embedding_function() -> EmbeddingFunction: return OllamaEmbeddings()