"""HTTP Basic middleware tests.""" from __future__ import annotations import base64 import importlib import os import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) def _reload_auth(monkeypatch, user="alice", password="hunter2"): monkeypatch.setenv("AG_BIDS_MCP_USER", user) monkeypatch.setenv("AG_BIDS_MCP_PASS", password) from ag_bids_mcp import auth importlib.reload(auth) return auth def _b64(creds: str) -> str: return base64.b64encode(creds.encode()).decode() def test_expected_credentials_requires_both(monkeypatch): monkeypatch.delenv("AG_BIDS_MCP_USER", raising=False) monkeypatch.delenv("AG_BIDS_MCP_PASS", raising=False) from ag_bids_mcp import auth importlib.reload(auth) import pytest with pytest.raises(RuntimeError): auth.expected_credentials() monkeypatch.setenv("AG_BIDS_MCP_USER", "alice") monkeypatch.delenv("AG_BIDS_MCP_PASS", raising=False) importlib.reload(auth) with pytest.raises(RuntimeError): auth.expected_credentials() def test_middleware_via_starlette_app(monkeypatch): """End-to-end: a Starlette app with the middleware returns 401 / 200 correctly.""" auth = _reload_auth(monkeypatch, "alice", "hunter2") from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route from starlette.testclient import TestClient async def hello(request): return PlainTextResponse("ok") app = Starlette( routes=[Route("/x", endpoint=hello)], middleware=[Middleware(BaseHTTPMiddleware, dispatch=auth.basic_auth_middleware)], ) c = TestClient(app) # No header -> 401 + WWW-Authenticate r = c.get("/x") assert r.status_code == 401 assert r.headers.get("www-authenticate", "").startswith("Basic") # Wrong creds -> 401 r = c.get("/x", headers={"Authorization": "Basic " + _b64("alice:wrong")}) assert r.status_code == 401 # Wrong username, right password -> 401 r = c.get("/x", headers={"Authorization": "Basic " + _b64("eve:hunter2")}) assert r.status_code == 401 # Right creds -> 200 r = c.get("/x", headers={"Authorization": "Basic " + _b64("alice:hunter2")}) assert r.status_code == 200 assert r.text == "ok" def test_malformed_authorization_header(monkeypatch): auth = _reload_auth(monkeypatch) from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route from starlette.testclient import TestClient async def hello(request): return PlainTextResponse("ok") app = Starlette( routes=[Route("/x", endpoint=hello)], middleware=[Middleware(BaseHTTPMiddleware, dispatch=auth.basic_auth_middleware)], ) c = TestClient(app) # Not even base64 r = c.get("/x", headers={"Authorization": "Basic !!!not_b64!!!"}) assert r.status_code == 401 # Bearer instead of Basic r = c.get("/x", headers={"Authorization": "Bearer abc"}) assert r.status_code == 401