"""HTTP Basic auth in front of the FastMCP Streamable-HTTP transport. MetaMCP can inject ``Authorization: Basic `` on every upstream call, so this is the simplest robust gate. Two env vars (``AG_BIDS_MCP_USER`` and ``AG_BIDS_MCP_PASS``) are required at process startup; the server fails closed if either is missing. Stdio transport (local dev with Claude Desktop) skips this entirely — no HTTP layer exists in stdio mode. """ from __future__ import annotations import base64 import logging import os import secrets from starlette.requests import Request from starlette.responses import PlainTextResponse, Response log = logging.getLogger(__name__) REALM = "ag-bids MCP" def expected_credentials() -> tuple[str, str]: """Return the (user, pass) the server enforces. Raises if missing.""" u = os.environ.get("AG_BIDS_MCP_USER", "") p = os.environ.get("AG_BIDS_MCP_PASS", "") if not u or not p: raise RuntimeError( "AG_BIDS_MCP_USER and AG_BIDS_MCP_PASS must both be set for HTTP " "Basic auth on the ag-bids MCP server." ) return u, p def _decode_basic(header: str) -> tuple[str, str] | None: if not header or not header.lower().startswith("basic "): return None try: decoded = base64.b64decode(header.split(" ", 1)[1]).decode("utf-8") except (ValueError, UnicodeDecodeError): return None user, _, pw = decoded.partition(":") return user, pw def _check(presented_user: str, presented_pass: str) -> bool: expected_user, expected_pass = expected_credentials() return ( secrets.compare_digest(presented_user, expected_user) and secrets.compare_digest(presented_pass, expected_pass) ) def _unauthorized() -> Response: return PlainTextResponse( "Unauthorized", status_code=401, headers={"WWW-Authenticate": f'Basic realm="{REALM}"'}, ) async def basic_auth_middleware(request: Request, call_next): """Starlette middleware that 401s anything missing valid Basic creds.""" creds = _decode_basic(request.headers.get("authorization", "")) if creds is None: log.info("auth: missing/malformed Authorization header (path=%s)", request.url.path) return _unauthorized() user, pw = creds if not _check(user, pw): log.info("auth: bad credentials (user=%r path=%s)", user, request.url.path) return _unauthorized() return await call_next(request)