Files
hvm-docs/rag/embeddings.py
T
justin 6b11993688 ci: use zerto-docs's load-balanced Ollama GPU pool on the Gitea host
Match the OLLAMA_URLS pattern from zerto-docs-rag so every docs MCP
build fans out across the same two GPU-pinned Ollama containers on
192.168.0.2 (:11435 Titan X, :11436 1080 Ti). The host's primary
Ollama on :11434 is left alone for OpenWebUI.

rag.embeddings now reads OLLAMA_URLS (plural CSV) preferentially with
fallback to OLLAMA_URL, defaulting to http://192.168.0.2:11434 — same
shape as zerto's embeddings.py. The OllamaEmbeddings class already
round-robins per batch, so both GPUs run in parallel during the
chroma rebuild.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 13:22:59 -04:00

90 lines
3.1 KiB
Python

"""Embedding function for Chroma — Ollama-hosted nomic-embed-text by default.

Swappable: implement the same `embedding_function()` interface returning
a Chroma `EmbeddingFunction` and the rest of the pipeline doesn't care.

Env-configurable (matches the zerto-docs-rag pattern so the same Gitea
runner + GPU-pinned Ollama containers can serve every docs MCP build):

  OLLAMA_URLS   comma-separated list, load-balanced round-robin per batch.
                Preferred — set in the CI workflow to fan out across two
                GPU-pinned Ollama containers on the Gitea host.
  OLLAMA_URL    single endpoint, used when OLLAMA_URLS is unset or empty.
                Default http://192.168.0.2:11434 (the host where the GPUs
                live in Justin's lab).
  EMBED_MODEL   model name; default 'nomic-embed-text'.
  EMBED_DIM     expected embedding dimension; default 768 (nomic-embed-text).
"""
from __future__ import annotations
import os
import logging
from typing import Any
import httpx
from chromadb import EmbeddingFunction, Documents, Embeddings
# Last-resort Ollama endpoint: used by _resolve_urls when OLLAMA_URLS is
# unset/empty and OLLAMA_URL is unset.
DEFAULT_OLLAMA_URL = "http://192.168.0.2:11434"

# Module-scoped logger.
log = logging.getLogger(__name__)
def _resolve_urls() -> list[str]:
    """Return the Ollama base URLs to embed against, trailing slashes removed.

    Precedence:
      1. OLLAMA_URLS (comma-separated), if it yields at least one non-blank URL;
      2. OLLAMA_URL, if set and non-blank;
      3. DEFAULT_OLLAMA_URL.

    The result is never empty, so callers can round-robin over it with
    ``index % len(urls)`` without risking a ZeroDivisionError.
    """
    raw = os.environ.get("OLLAMA_URLS", "")
    # Normalise first, then filter, so entries like " / " are dropped too.
    urls = [u for u in (part.strip().rstrip("/") for part in raw.split(",")) if u]
    if urls:
        return urls
    if raw.strip():
        # Set, but only separators/whitespace (e.g. "," from a templated CI
        # variable). Falling back beats returning an empty list.
        log.warning("OLLAMA_URLS=%r contains no URLs; falling back to OLLAMA_URL", raw)
    single = os.environ.get("OLLAMA_URL", "").strip().rstrip("/")
    # A blank OLLAMA_URL counts as unset rather than producing "/api/embed".
    return [single or DEFAULT_OLLAMA_URL]
# Resolved once at import time; later environment changes are not picked up.
OLLAMA_URLS = _resolve_urls()
# Blank values (e.g. an unset CI variable templated in as "") count as unset,
# so they fall back to the defaults instead of sending an empty model name to
# Ollama or crashing the import on int("").
EMBED_MODEL = os.environ.get("EMBED_MODEL", "").strip() or "nomic-embed-text"
_embed_dim_raw = os.environ.get("EMBED_DIM", "").strip() or "768"
try:
    EMBED_DIM = int(_embed_dim_raw)
except ValueError as err:
    raise ValueError(f"EMBED_DIM must be an integer, got {_embed_dim_raw!r}") from err
class OllamaEmbeddings(EmbeddingFunction):
    """Calls /api/embed across N Ollama endpoints, naive round-robin.

    For indexing throughput on multiple GPUs, run one Ollama container
    per GPU (pinned via NVIDIA_VISIBLE_DEVICES) and pass all their URLs
    in OLLAMA_URLS (comma-separated). The embedder picks the next
    endpoint per batch. If the picked endpoint is unreachable (connection
    error or timeout), the remaining endpoints are tried in order before
    giving up, so one dead container does not abort a whole rebuild.

    The round-robin counter is a plain attribute with no locking. Under
    concurrent calls two batches may land on the same endpoint; that only
    skews load balance, not results.
    """

    def __init__(self, urls: list[str] | str = OLLAMA_URLS, model: str = EMBED_MODEL):
        """
        Args:
            urls: Ollama base URLs. A single comma-separated string is also
                accepted. Whitespace and trailing slashes are trimmed and
                blank entries dropped.
            model: Ollama embedding model name sent with every request.

        Raises:
            ValueError: if no usable URL remains.
        """
        if isinstance(urls, str):
            # Iterating a str would otherwise treat each character as a URL.
            urls = urls.split(",")
        # Build a fresh list rather than aliasing: the default is the shared
        # module-level OLLAMA_URLS, and get_config() hands this list out.
        self.urls = [u for u in (p.strip().rstrip("/") for p in urls) if u]
        if not self.urls:
            # Fail at construction instead of ZeroDivisionError on first call.
            raise ValueError("OllamaEmbeddings needs at least one Ollama URL")
        self.model = model
        self._next = 0

    def __call__(self, input: Documents) -> Embeddings:
        """Embed one batch of documents with a single /api/embed request.

        Returns:
            One embedding per input document. The count is verified against
            the input.

        Raises:
            httpx.HTTPStatusError: an endpoint answered with an error status.
            httpx.TransportError: every endpoint was unreachable; the last
                endpoint's error is re-raised.
            RuntimeError: the response's embedding count does not match the
                number of inputs.
        """
        texts = list(input)
        if not texts:
            return []  # nothing to embed; skip the network round-trip
        start = self._next % len(self.urls)
        self._next += 1
        # Round-robin pick first, then the other endpoints in order as fallbacks.
        candidates = self.urls[start:] + self.urls[:start]
        last_error: httpx.TransportError | None = None
        with httpx.Client(timeout=300) as client:
            for url in candidates:
                try:
                    resp = client.post(
                        f"{url}/api/embed",
                        json={"model": self.model, "input": texts},
                    )
                except httpx.TransportError as err:
                    log.warning("Ollama endpoint %s failed (%s); trying next", url, err)
                    last_error = err
                    continue
                # Status errors (bad model name, etc.) are not endpoint-specific,
                # so they propagate immediately instead of failing over.
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings") or []
                if len(embeddings) != len(texts):
                    raise RuntimeError(
                        f"{url}/api/embed returned {len(embeddings)} embeddings "
                        f"for {len(texts)} inputs (model {self.model!r})"
                    )
                return embeddings
        # Only reached when every candidate raised TransportError, so
        # last_error is set (candidates is non-empty by construction).
        raise last_error

    def name(self) -> str:  # newer chromadb requires this
        """Identifier Chroma records for this embedding function."""
        return f"ollama:{self.model}"

    @staticmethod
    def build_from_config(config: dict) -> "OllamaEmbeddings":  # newer chromadb
        """Rebuild an instance from get_config() output; missing keys use module defaults."""
        return OllamaEmbeddings(
            urls=config.get("urls", OLLAMA_URLS),
            model=config.get("model", EMBED_MODEL),
        )

    def get_config(self) -> dict:  # newer chromadb
        """Serialisable config; the URL list is copied so callers can't mutate ours."""
        return {"urls": list(self.urls), "model": self.model}

    def default_space(self) -> str:
        """Default distance space for collections using this embedder."""
        return "cosine"

    def supported_spaces(self) -> list[str]:
        """Distance spaces this embedder may be used with."""
        return ["cosine", "l2", "ip"]
def embedding_function() -> EmbeddingFunction:
    """Build the embedder the indexing pipeline uses.

    Currently an OllamaEmbeddings built with its defaults (the module-level
    OLLAMA_URLS and EMBED_MODEL). Change the returned object here to swap
    backends.
    """
    embedder = OllamaEmbeddings()
    return embedder