From 48393692dc2fd472cc52e074a9e5414e78271661 Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Mon, 15 Jun 2026 10:09:17 -0500 Subject: [PATCH] Shim for MCP Auth --- src/app/routes/mcp.py | 42 +++++++++++++++++++++++++++++++++++++++++- src/config/settings.py | 12 ++++++++++++ src/dependencies.py | 13 +++++++++++-- 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/app/routes/mcp.py b/src/app/routes/mcp.py index 3922829b8..e2f70e60e 100644 --- a/src/app/routes/mcp.py +++ b/src/app/routes/mcp.py @@ -13,12 +13,52 @@ logger = get_logger(__name__) +class _ForwardJWTToMCPMiddleware: + """ASGI shim: copy the inbound ``Authorization`` JWT into a non-excluded + header so it survives FastMCP's proxy to the underlying /v1 handler. + + FastMCP's ``get_http_headers()`` strips ``authorization`` from the headers + it forwards when an MCP tool re-invokes a /v1 route, so a gateway-forwarded + JWT would otherwise be lost (the client then falls back to lakehouse creds, + which OpenSearch rejects with 401 and which carry no RBAC roles -> 403). + We *copy* (not move) the value into ``OPENRAG_MCP_JWT_HEADER`` (default + ``X-OpenRAG-JWT``), which is not in FastMCP's exclude set; the original + header is left intact. ``get_api_key_user_async`` reads it as a fallback. + + Scoped to the /mcp sub-app only, so normal /v1 and UI traffic is untouched. + """ + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + from config.settings import get_mcp_forwarded_jwt_header + + target = get_mcp_forwarded_jwt_header().lower().encode("latin-1") + headers = scope.get("headers", []) + auth = None + has_target = False + for name, value in headers: + lname = name.lower() + if lname == b"authorization": + auth = value + elif lname == target: + has_target = True + if auth is not None and not has_target: + scope = dict(scope) + scope["headers"] = list(headers) + [(target, auth)] + await self.app(scope, receive, send) + + def mount_mcp(app: FastAPI): """Mount the FastMCP app at /mcp and return its lifespan context manager.""" logger.info("Creating MCP server") mcp_server = create_mcp_server(app) mcp_http_app = mcp_server.http_app(transport="streamable-http", path="/") - app.mount("/mcp", mcp_http_app) + app.mount("/mcp", _ForwardJWTToMCPMiddleware(mcp_http_app)) logger.info("MCP server mounted at /mcp (streamable-http)") # FastMCP requires its own lifespan to be run so that the diff --git a/src/config/settings.py b/src/config/settings.py index 9adfba9f2..4f421f88f 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -270,6 +270,18 @@ def get_jwt_auth_header() -> str: return os.getenv("OPENRAG_JWT_AUTH_HEADER", "Authorization") +def get_mcp_forwarded_jwt_header() -> str: + """Header into which the /mcp ASGI shim copies the inbound JWT. + + FastMCP's get_http_headers() strips 'authorization' from the headers it + forwards to the underlying /v1 handler when an MCP tool is invoked, so a + JWT arriving on /mcp in the Authorization header never reaches the /v1 + auth dependency. The shim copies it into this (non-excluded) header so it + survives the proxy; get_api_key_user_async reads it as a fallback. + Read per-call so tests can override via monkeypatch.setenv.""" + return os.getenv("OPENRAG_MCP_JWT_HEADER", "X-OpenRAG-JWT") + + def get_jwt_issuer_verify_tls() -> bool: """Whether to verify TLS when fetching JWT signing keys from the token's ``iss`` URL (``verify_jwt_from_issuer``). Defaults to false for internal diff --git a/src/dependencies.py b/src/dependencies.py index 77236195c..a8ab729d7 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -844,10 +844,19 @@ async def get_api_key_user_async( # the user's roles (synced via request.state.jwt_roles -> # _attach_db_user_id), with a 401 when no recognized role is present. from auth.jwt_roles import jwt_roles_enabled - from config.settings import IBM_AUTH_ENABLED, get_jwt_auth_header + from config.settings import ( + IBM_AUTH_ENABLED, + get_jwt_auth_header, + get_mcp_forwarded_jwt_header, + ) from config.utils import resolve_jwt_claims - raw_jwt = request.headers.get(get_jwt_auth_header(), "") + # Primary: the gateway-forwarded JWT header (default Authorization). Fallback: + # the header the /mcp shim copies the JWT into, because FastMCP strips + # Authorization before proxying an MCP tool call to this /v1 handler. + raw_jwt = request.headers.get(get_jwt_auth_header(), "") or request.headers.get( + get_mcp_forwarded_jwt_header(), "" + ) logger.debug( "[AUTH] API-key path JWT header lookup", header_name=get_jwt_auth_header(),