"""Embedding function for Chroma — Ollama-hosted nomic-embed-text by default.

Swappable: implement the same `embedding_function()` interface returning a
Chroma `EmbeddingFunction` and the rest of the pipeline doesn't care.

Defaults (override via env):
  OLLAMA_URL    one or more comma-separated URLs (load-balanced);
                default 'http://localhost:11434'
  EMBED_MODEL   model name; default 'nomic-embed-text'
  EMBED_DIM     expected embedding dim; default 768 (nomic-embed-text)
"""
from __future__ import annotations

import logging
import os
from typing import Any

import httpx
from chromadb import EmbeddingFunction, Documents, Embeddings

log = logging.getLogger(__name__)

_DEFAULT_OLLAMA_URL = "http://localhost:11434"

# Blank entries (e.g. from a trailing comma) are dropped. If nothing usable is
# left — OLLAMA_URL is set but empty, or only commas — fall back to the default
# so the round-robin below never has to pick from an empty list.
OLLAMA_URLS = [
    u.strip()
    for u in os.environ.get("OLLAMA_URL", _DEFAULT_OLLAMA_URL).split(",")
    if u.strip()
] or [_DEFAULT_OLLAMA_URL]
EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
# Read from the environment here but not checked anywhere in this module.
EMBED_DIM = int(os.environ.get("EMBED_DIM", "768"))


class OllamaEmbeddings(EmbeddingFunction):
    """Calls /api/embed across N Ollama endpoints, naive round-robin.

    For indexing throughput on multiple GPUs, run one Ollama container per GPU
    (pinned via NVIDIA_VISIBLE_DEVICES) and pass all their URLs in OLLAMA_URL —
    the embedder picks the next endpoint per batch.
    """

    def __init__(self, urls: list[str] = OLLAMA_URLS, model: str = EMBED_MODEL):
        """Store the endpoint list and model name.

        Args:
            urls: Base URLs of the Ollama servers to rotate through.
            model: Model name sent with every embed request.

        Raises:
            ValueError: If ``urls`` is empty. Without this check the failure
                would only surface on the first call, as a ZeroDivisionError
                from the round-robin modulo.
        """
        # Copy rather than alias: the default argument is the module-level
        # OLLAMA_URLS list, and a caller-supplied list may be mutated later.
        self.urls = list(urls)
        if not self.urls:
            raise ValueError("OllamaEmbeddings needs at least one Ollama URL")
        self.model = model
        # Count of batches dispatched so far; selects the endpoint modulo
        # len(self.urls).
        self._next = 0

    # The parameter is named `input` (shadowing the builtin) to keep the
    # signature compatible with Chroma's EmbeddingFunction interface.
    def __call__(self, input: Documents) -> Embeddings:
        """Embed one batch of documents on the next endpoint in rotation.

        Args:
            input: The documents to embed; sent as a single request.

        Returns:
            The ``embeddings`` list from the JSON response, or ``[]`` when the
            response has no (or an empty) ``embeddings`` field.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error
                status (via ``raise_for_status``).
        """
        # The cursor is advanced before the request is made, so if this call
        # raises, the next batch goes to the following endpoint.
        url = self.urls[self._next % len(self.urls)]
        self._next += 1

        docs = list(input)
        # timeout is in seconds; presumably generous to allow for large
        # batches or a cold model load.
        with httpx.Client(timeout=300) as c:
            r = c.post(
                # Tolerate a configured base URL that ends in "/" so the path
                # does not become "//api/embed".
                f"{url.rstrip('/')}/api/embed",
                json={"model": self.model, "input": docs},
            )
            r.raise_for_status()
            data = r.json()

        embeddings = data.get("embeddings") or []
        if len(embeddings) != len(docs):
            # The return value is left as-is; this only makes a short or
            # missing result visible in the logs.
            log.warning(
                "Ollama at %s returned %d embeddings for %d inputs (model %s)",
                url,
                len(embeddings),
                len(docs),
                self.model,
            )
        return embeddings

    def name(self) -> str:  # newer chromadb requires this
        """Return the identifier ``ollama:<model>``."""
        return f"ollama:{self.model}"

    @staticmethod
    def build_from_config(config: dict) -> "OllamaEmbeddings":  # newer chromadb
        """Recreate an instance from a dict shaped like ``get_config()`` output.

        Keys that are missing, ``None`` or empty fall back to the module
        defaults (OLLAMA_URLS / EMBED_MODEL).
        """
        return OllamaEmbeddings(
            urls=config.get("urls") or OLLAMA_URLS,
            model=config.get("model") or EMBED_MODEL,
        )

    def get_config(self) -> dict:  # newer chromadb
        """Return the constructor arguments as a plain dict."""
        # Hand out a copy so callers cannot mutate this instance's URL list
        # through the returned config.
        return {"urls": list(self.urls), "model": self.model}

    def default_space(self) -> str:
        """Return the default distance space, ``"cosine"``."""
        return "cosine"

    def supported_spaces(self) -> list[str]:
        """Return the distance spaces this function declares support for."""
        return ["cosine", "l2", "ip"]


def embedding_function() -> EmbeddingFunction:
    """Return the pipeline's embedding function: an OllamaEmbeddings built
    with the module defaults."""
    return OllamaEmbeddings()