d1fba83a87
Port of zerto-docs PR #45. OllamaEmbeddings previously made a single attempt per batch — any transient connection drop or HTTP error from one endpoint failed the entire index rebuild. - _embed() now rotates to the next endpoint and retries with backoff (5 attempts) on transport errors, and additionally halves the input (floor 16) on HTTP status errors: the .0.125 Windows Ollama (4090) 400s when its model runner dies on an oversized input array. Error response bodies are logged instead of swallowed. - CI workflows: OLLAMA_URLS extended from the two ripper instances to the full 4-endpoint GPU pool (+ .0.125 4090, + .0.126). At the 64-chunk batches this indexer already uses, .0.125 is the fastest embedder in the fleet (242 embeds/s measured on seed-mcp). Verified against the live pool: 64-text happy path, dead-endpoint rotation, and a forced 512-text 400 on .0.125 that split and completed. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
121 lines
4.8 KiB
Python
121 lines
4.8 KiB
Python
"""Embedding function for Chroma — Ollama-hosted nomic-embed-text by default.
|
|
|
|
Swappable: implement the same `embedding_function()` interface returning
a Chroma `EmbeddingFunction` and the rest of the pipeline doesn't care.

Env-configurable (matches the zerto-docs-rag pattern so the same Gitea
runner + GPU-pinned Ollama containers can serve every docs MCP build):

  OLLAMA_URLS   comma-separated list, load-balanced round-robin per batch.
                Preferred — set in the CI workflow to fan out across two
                GPU-pinned Ollama containers on the Gitea host.
  OLLAMA_URL    single endpoint, fallback when OLLAMA_URLS is unset.
                Default http://192.168.0.2:11434 (the host where the GPUs
                live in Justin's lab).
  EMBED_MODEL   model name; default 'nomic-embed-text'
  EMBED_DIM     expected embedding dim; default 768 (nomic-embed-text)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import logging
|
|
import time
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from chromadb import EmbeddingFunction, Documents, Embeddings
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Fallback Ollama base URL (no trailing slash) for when the environment
# does not name an endpoint.
DEFAULT_OLLAMA_URL = "http://192.168.0.2:11434"
|
|
|
|
|
|
def _resolve_urls() -> list[str]:
    """Return the Ollama base URLs to embed against, without trailing slashes.

    Resolution order:

    1. ``OLLAMA_URLS`` — comma-separated; each entry is stripped of
       surrounding whitespace and trailing ``/``, and entries that end up
       empty are dropped.
    2. ``OLLAMA_URL`` — single endpoint, normalised the same way. Used when
       ``OLLAMA_URLS`` is unset or names no usable endpoint (e.g. ``","``).
    3. ``DEFAULT_OLLAMA_URL`` — used when ``OLLAMA_URL`` is unset or blank.

    The result always contains at least one entry, so round-robin selection
    over it can never divide by zero.
    """
    raw = os.environ.get("OLLAMA_URLS", "")
    # Filter *after* normalising so entries like " " or "/" don't survive
    # as empty base URLs.
    normalised = (part.strip().rstrip("/") for part in raw.split(","))
    urls = [u for u in normalised if u]
    if urls:
        return urls

    single = os.environ.get("OLLAMA_URL", "").strip().rstrip("/")
    return [single or DEFAULT_OLLAMA_URL]
|
|
|
|
|
|
# All three settings are read from the environment once, at import time.
OLLAMA_URLS = _resolve_urls()
EMBED_MODEL = os.getenv("EMBED_MODEL", "nomic-embed-text")
# Parsed eagerly: a non-integer EMBED_DIM raises ValueError on import.
EMBED_DIM = int(os.getenv("EMBED_DIM", "768"))
|
|
|
|
|
|
class OllamaEmbeddings(EmbeddingFunction):
    """Calls /api/embed across N Ollama endpoints, round-robin per batch.

    For indexing throughput on multiple GPUs, run one Ollama container
    per GPU (pinned via NVIDIA_VISIBLE_DEVICES) and pass all their URLs
    in OLLAMA_URLS — the embedder picks the next endpoint per batch.

    Resilient (ported from zerto-docs PR #45): a failed call rotates to
    the next endpoint and retries with backoff instead of failing the
    whole rebuild. HTTP status errors additionally halve the input —
    the .0.125 Windows Ollama (4090) 400s when its model runner dies on
    an oversized input array, and one endpoint rejecting a batch the
    others accept shouldn't kill a multi-hour index build.
    """

    def __init__(self, urls: list[str] = OLLAMA_URLS, model: str = EMBED_MODEL):
        """Store the endpoint list and model name.

        Args:
            urls: Ollama base URLs (no trailing slash); must be non-empty.
            model: Model name sent in every /api/embed request.

        Raises:
            ValueError: if ``urls`` is empty. Failing here gives a clear
                message instead of a ZeroDivisionError on the first embed.
        """
        # Copy so the instance never aliases the shared module-level default
        # list (or a list the caller keeps mutating).
        self.urls = list(urls)
        if not self.urls:
            raise ValueError("OllamaEmbeddings requires at least one Ollama URL")
        self.model = model
        # Monotonic counter; the endpoint for a request is urls[_next % len(urls)].
        self._next = 0

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return the embeddings from Ollama."""
        return self._embed(list(input), attempt=1)

    def _embed(self, texts: list, attempt: int) -> Embeddings:
        """POST ``texts`` to the next endpoint, retrying or splitting on failure.

        Every call, including each retry and each half of a split, advances
        the round-robin counter, so a retry lands on a different endpoint
        whenever more than one is configured.

        On ``httpx.TransportError`` or ``httpx.HTTPStatusError``:

        * at ``attempt`` 5 or more the error is logged and re-raised;
        * an HTTP status error with more than 16 texts splits the batch in
          half and embeds each half at ``attempt + 1`` (no sleep);
        * otherwise it sleeps 0.5 * 2**(attempt - 1) seconds and retries
          the same texts at ``attempt + 1``.

        The attempt counter is shared between splits and plain retries, so
        the recursion depth is bounded at five either way.

        Args:
            texts: Documents to embed. An empty list returns ``[]`` without
                contacting any endpoint.
            attempt: 1-based attempt number for this batch.

        Returns:
            The ``embeddings`` list from the response JSON, or ``[]`` when
            that key is missing or empty.

        Raises:
            httpx.TransportError, httpx.HTTPStatusError: the last failure,
                once the attempt budget is spent.
        """
        if not texts:
            # Nothing to embed — skip the network round-trip entirely.
            return []

        url = self.urls[self._next % len(self.urls)]
        self._next += 1
        try:
            with httpx.Client(timeout=300) as c:
                r = c.post(f"{url}/api/embed",
                           json={"model": self.model, "input": texts})
                r.raise_for_status()
                return r.json().get("embeddings") or []
        except (httpx.TransportError, httpx.HTTPStatusError) as e:
            if isinstance(e, httpx.HTTPStatusError):
                # Include the start of the error body so the cause is logged.
                desc = f"HTTP {e.response.status_code} ({e.response.text[:200]})"
            else:
                desc = f"transport error {type(e).__name__}"
            if attempt >= 5:
                log.error("%s from %s (%d texts) — giving up after %d attempts",
                          desc, url, len(texts), attempt)
                raise
            if isinstance(e, httpx.HTTPStatusError) and len(texts) > 16:
                # The endpoint may be rejecting the batch size: halve it.
                # Batches of 16 or fewer are retried whole instead.
                mid = len(texts) // 2
                log.warning("%s from %s — splitting %d texts into %d+%d (attempt %d)",
                            desc, url, len(texts), mid, len(texts) - mid, attempt)
                return (self._embed(texts[:mid], attempt + 1)
                        + self._embed(texts[mid:], attempt + 1))
            backoff = 0.5 * (2 ** (attempt - 1))  # 0.5, 1, 2, 4
            log.warning("%s (attempt %d, %s) — retrying in %.1fs",
                        desc, attempt, url, backoff)
            time.sleep(backoff)
            return self._embed(texts, attempt + 1)

    def name(self) -> str:  # newer chromadb requires this
        """Return the identifier ``ollama:<model>``."""
        return f"ollama:{self.model}"

    @staticmethod
    def build_from_config(config: dict) -> "OllamaEmbeddings":  # newer chromadb
        """Rebuild an embedder from a ``get_config()``-style dict.

        Missing ``urls`` / ``model`` keys fall back to the module defaults.
        """
        return OllamaEmbeddings(
            urls=config.get("urls", OLLAMA_URLS),
            model=config.get("model", EMBED_MODEL),
        )

    def get_config(self) -> dict:  # newer chromadb
        """Return the settings accepted by ``build_from_config``."""
        # Hand out a copy so callers can't mutate the live endpoint list.
        return {"urls": list(self.urls), "model": self.model}

    def default_space(self) -> str:
        """Return the default distance space, ``"cosine"``."""
        return "cosine"

    def supported_spaces(self) -> list[str]:
        """Return the distance spaces this embedder declares support for."""
        return ["cosine", "l2", "ip"]
|
|
|
|
|
|
def embedding_function() -> EmbeddingFunction:
    """Return the pipeline's embedder: an `OllamaEmbeddings` with default arguments."""
    embedder = OllamaEmbeddings()
    return embedder
|