From f0c1108ae1b4781efc32cf552b0131d7d030feff Mon Sep 17 00:00:00 2001 From: rodageve Date: Tue, 16 Jun 2026 15:43:22 -0400 Subject: [PATCH 1/3] Microsoft and google connector OAuth JWT validation --- src/config/settings.py | 2 + src/connectors/google_drive_acl.py | 20 +- src/connectors/microsoft_graph_acl.py | 49 ++++- src/utils/jwt_verification.py | 295 ++++++++++++++++++++++++++ 4 files changed, 352 insertions(+), 14 deletions(-) create mode 100644 src/utils/jwt_verification.py diff --git a/src/config/settings.py b/src/config/settings.py index 9b0c5c49b..ed3672dfc 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -163,6 +163,8 @@ def get_ingest_callback_url() -> str: # os.environ directly. JWT_SIGNING_KEY = os.getenv("JWT_SIGNING_KEY") GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID") +MICROSOFT_GRAPH_OAUTH_CLIENT_ID = os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID") +MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET = os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET") GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET") # IBM AMS authentication (Watsonx Data embedded mode) diff --git a/src/connectors/google_drive_acl.py b/src/connectors/google_drive_acl.py index be2bb1cba..7402c0436 100644 --- a/src/connectors/google_drive_acl.py +++ b/src/connectors/google_drive_acl.py @@ -58,18 +58,28 @@ def google_drive_user_principal(user_email: str | None) -> str | None: def _email_from_id_token(id_token: str | None) -> str | None: + """Extract email from Google ID token with signature verification.""" if not id_token: return None try: - claims = jwt.decode( - id_token, - options={"verify_signature": False, "verify_aud": False}, - ) + from config.settings import GOOGLE_OAUTH_CLIENT_ID + from utils.jwt_verification import verify_google_id_token, JWTVerificationError + + if not GOOGLE_OAUTH_CLIENT_ID: + logger.error("GOOGLE_OAUTH_CLIENT_ID not configured - cannot verify ID token") + return None + + # Verify token with FULL validation + claims = verify_google_id_token(id_token, GOOGLE_OAUTH_CLIENT_ID) email = claims.get("email") if email: return str(email) + + except JWTVerificationError as e: + logger.warning("Google ID token verification failed", error=str(e)) except Exception as e: - logger.debug("Could not decode Google id_token email", error=str(e)) + logger.error("Unexpected error verifying Google ID token", error=str(e)) + return None diff --git a/src/connectors/microsoft_graph_acl.py b/src/connectors/microsoft_graph_acl.py index 1e2a17fb8..5626e5c50 100644 --- a/src/connectors/microsoft_graph_acl.py +++ b/src/connectors/microsoft_graph_acl.py @@ -23,19 +23,33 @@ def tenant_id_from_access_token(access_token: str | None, fallback: str | None = None) -> str: - """Read the tenant id from a Microsoft access token without validating it.""" + """Extract tenant ID from Microsoft access token with signature verification.""" if access_token: raw_token = access_token.removeprefix("Bearer ").strip() try: - claims = jwt.decode( - raw_token, - options={"verify_signature": False, "verify_aud": False}, + from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID + from utils.jwt_verification import ( + verify_microsoft_access_token, + JWTVerificationError, ) + + if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID: + logger.error( + "MICROSOFT_GRAPH_OAUTH_CLIENT_ID not configured - cannot verify access token" + ) + return fallback or "common" + + # Verify token with FULL validation + claims = verify_microsoft_access_token(raw_token, MICROSOFT_GRAPH_OAUTH_CLIENT_ID) token_tenant = claims.get("tid") if token_tenant: return token_tenant + + except JWTVerificationError as e: + logger.warning("Microsoft access token verification failed", error=str(e)) except Exception as e: - logger.debug("Could not decode Microsoft access token tenant", error=str(e)) + logger.error("Unexpected error verifying Microsoft access token", error=str(e)) + return fallback or "common" @@ -204,14 +218,31 @@ async def get_current_user_microsoft_group_roles( def _decode_microsoft_user_identifiers(access_token: str, tenant_id: str | None) -> list[str]: + """Extract user identifiers from Microsoft access token with signature verification.""" raw_token = access_token.removeprefix("Bearer ").strip() try: - claims = jwt.decode( - raw_token, - options={"verify_signature": False, "verify_aud": False}, + from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID + from utils.jwt_verification import ( + verify_microsoft_access_token, + JWTVerificationError, ) + + if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID: + logger.error( + "MICROSOFT_GRAPH_OAUTH_CLIENT_ID not configured - cannot verify access token" + ) + return [] + + # Verify token with FULL validation + claims = verify_microsoft_access_token( + raw_token, MICROSOFT_GRAPH_OAUTH_CLIENT_ID, tenant_id=tenant_id + ) + + except JWTVerificationError as e: + logger.warning("Microsoft access token verification failed", error=str(e)) + return [] except Exception as e: - logger.debug("Could not decode Microsoft access token user identifiers", error=str(e)) + logger.error("Unexpected error verifying Microsoft access token", error=str(e)) return [] identifiers: list[str] = [] diff --git a/src/utils/jwt_verification.py b/src/utils/jwt_verification.py new file mode 100644 index 000000000..3d579a441 --- /dev/null +++ b/src/utils/jwt_verification.py @@ -0,0 +1,295 @@ +"""JWT signature verification utilities for OAuth tokens.""" + +from __future__ import annotations + +from typing import Any + +import httpx +import jwt +from cachetools import TTLCache +from jwt.algorithms import RSAAlgorithm + +from utils.logging_config import get_logger + +logger = get_logger(__name__) + +# JWKS cache: 1 hour TTL, max 10 entries +_jwks_cache: TTLCache = TTLCache(maxsize=10, ttl=3600) + +# JWKS endpoints +GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs" +MICROSOFT_JWKS_URL_TEMPLATE = "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys" + + +class JWTVerificationError(Exception): + """Base exception for JWT verification errors.""" + + pass + + +class InvalidSignatureError(JWTVerificationError): + """JWT signature is invalid.""" + + pass + + +class ExpiredTokenError(JWTVerificationError): + """JWT token has expired.""" + + pass + + +class InvalidAudienceError(JWTVerificationError): + """JWT audience claim is invalid.""" + + pass + + +class InvalidIssuerError(JWTVerificationError): + """JWT issuer claim is invalid.""" + + pass + + +def _fetch_jwks(url: str) -> dict[str, Any]: + """ + Fetch JWKS from URL with caching. + + Args: + url: JWKS endpoint URL + + Returns: + JWKS dictionary + + Raises: + JWTVerificationError: If JWKS fetch fails + """ + # Check cache first + if url in _jwks_cache: + logger.debug(f"JWKS cache hit for {url}") + return _jwks_cache[url] + + # Fetch from endpoint + try: + logger.debug(f"Fetching JWKS from {url}") + response = httpx.get(url, timeout=5.0) + response.raise_for_status() + jwks = response.json() + + # Cache the result + _jwks_cache[url] = jwks + logger.debug(f"JWKS cached for {url}") + + return jwks + except Exception as e: + logger.error(f"Failed to fetch JWKS from {url}", error=str(e)) + raise JWTVerificationError(f"Failed to fetch JWKS: {e}") + + +def _get_signing_key(token: str, jwks: dict[str, Any]) -> Any: + """ + Extract signing key from JWKS based on token header. + + Args: + token: JWT token + jwks: JWKS dictionary + + Returns: + RSA public key object + + Raises: + JWTVerificationError: If key not found + """ + try: + # Decode header without verification to get kid + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + + if not kid: + raise JWTVerificationError("Token header missing 'kid' field") + + # Find matching key in JWKS + for key in jwks.get("keys", []): + if key.get("kid") == kid: + # Use PyJWT's built-in JWK to PEM conversion + public_key = RSAAlgorithm.from_jwk(key) + return public_key + + raise JWTVerificationError(f"Signing key with kid '{kid}' not found in JWKS") + + except jwt.DecodeError as e: + raise JWTVerificationError(f"Failed to decode token header: {e}") + + +def verify_google_id_token(token: str, client_id: str) -> dict[str, Any]: + """ + Verify Google ID token with FULL validation. + + Performs: + - Signature verification using Google's JWKS + - Issuer validation (accounts.google.com) + - Expiration validation + - Audience validation (requires client_id) + + Args: + token: Google ID token (JWT) + client_id: Expected audience (Google OAuth client ID) + + Returns: + Verified token claims + + Raises: + InvalidSignatureError: If signature is invalid + ExpiredTokenError: If token has expired + InvalidAudienceError: If audience doesn't match + InvalidIssuerError: If issuer is invalid + JWTVerificationError: For other verification failures + """ + if not client_id: + raise JWTVerificationError( + "client_id is required for Google ID token verification" + ) + + try: + # Fetch JWKS + jwks = _fetch_jwks(GOOGLE_JWKS_URL) + + # Get signing key + signing_key = _get_signing_key(token, jwks) + + # Verify token with FULL validation + claims = jwt.decode( + token, + signing_key, + algorithms=["RS256"], + audience=client_id, + issuer=["https://accounts.google.com", "accounts.google.com"], + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": True, + "verify_iss": True, + }, + ) + + logger.debug("Google ID token verified successfully") + return claims + + except jwt.InvalidSignatureError as e: + logger.warning("Google ID token has invalid signature", error=str(e)) + raise InvalidSignatureError(f"Invalid signature: {e}") + except jwt.ExpiredSignatureError as e: + logger.warning("Google ID token has expired", error=str(e)) + raise ExpiredTokenError(f"Token expired: {e}") + except jwt.InvalidAudienceError as e: + logger.warning("Google ID token has invalid audience", error=str(e)) + raise InvalidAudienceError(f"Invalid audience: {e}") + except jwt.InvalidIssuerError as e: + logger.warning("Google ID token has invalid issuer", error=str(e)) + raise InvalidIssuerError(f"Invalid issuer: {e}") + except JWTVerificationError: + raise + except Exception as e: + logger.error("Google ID token verification failed", error=str(e)) + raise JWTVerificationError(f"Verification failed: {e}") + + +def verify_microsoft_access_token( + token: str, client_id: str, tenant_id: str | None = None +) -> dict[str, Any]: + """ + Verify Microsoft access token with FULL validation. + + Performs: + - Signature verification using Microsoft's JWKS + - Issuer validation + - Expiration validation + - Audience validation (requires client_id) + + Args: + token: Microsoft access token (JWT) + client_id: Expected audience (Microsoft Graph OAuth client ID) + tenant_id: Optional tenant ID for JWKS endpoint (extracted from token if not provided) + + Returns: + Verified token claims + + Raises: + InvalidSignatureError: If signature is invalid + ExpiredTokenError: If token has expired + InvalidAudienceError: If audience doesn't match + InvalidIssuerError: If issuer is invalid + JWTVerificationError: For other verification failures + """ + if not client_id: + raise JWTVerificationError( + "client_id is required for Microsoft access token verification" + ) + + try: + # Extract tenant from token if not provided + if not tenant_id: + unverified_claims = jwt.decode(token, options={"verify_signature": False}) + tenant_id = unverified_claims.get("tid", "common") + logger.debug(f"Extracted tenant_id from token: {tenant_id}") + + # Fetch JWKS for this tenant + jwks_url = MICROSOFT_JWKS_URL_TEMPLATE.format(tenant=tenant_id) + jwks = _fetch_jwks(jwks_url) + + # Get signing key + signing_key = _get_signing_key(token, jwks) + + # Verify token with FULL validation + # Note: Microsoft tokens may have audience as client_id or resource URL + claims = jwt.decode( + token, + signing_key, + algorithms=["RS256"], + audience=client_id, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": True, + "verify_iss": True, + }, + ) + + # Additional issuer validation for Microsoft + issuer = claims.get("iss", "") + expected_issuer_patterns = [ + f"https://login.microsoftonline.com/{tenant_id}/v2.0", + f"https://sts.windows.net/{tenant_id}/", + ] + + if not any(issuer.startswith(pattern) for pattern in expected_issuer_patterns): + raise InvalidIssuerError(f"Unexpected issuer: {issuer}") + + logger.debug("Microsoft access token verified successfully") + return claims + + except jwt.InvalidSignatureError as e: + logger.warning("Microsoft access token has invalid signature", error=str(e)) + raise InvalidSignatureError(f"Invalid signature: {e}") + except jwt.ExpiredSignatureError as e: + logger.warning("Microsoft access token has expired", error=str(e)) + raise ExpiredTokenError(f"Token expired: {e}") + except jwt.InvalidAudienceError as e: + logger.warning("Microsoft access token has invalid audience", error=str(e)) + raise InvalidAudienceError(f"Invalid audience: {e}") + except (jwt.InvalidIssuerError, InvalidIssuerError) as e: + logger.warning("Microsoft access token has invalid issuer", error=str(e)) + raise InvalidIssuerError(f"Invalid issuer: {e}") + except JWTVerificationError: + raise + except Exception as e: + logger.error("Microsoft access token verification failed", error=str(e)) + raise JWTVerificationError(f"Verification failed: {e}") + + +def clear_jwks_cache(): + """Clear the JWKS cache. Useful for testing.""" + _jwks_cache.clear() + logger.debug("JWKS cache cleared") + +# Made with Bob From 87e7c96925c01e3082192ad10ba4eba94a53b396 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 22:31:54 +0000 Subject: [PATCH 2/3] style: ruff autofix (auto) --- src/connectors/google_drive_acl.py | 2 +- src/connectors/microsoft_graph_acl.py | 4 ++-- src/utils/jwt_verification.py | 9 +++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/connectors/google_drive_acl.py b/src/connectors/google_drive_acl.py index 7402c0436..6a567bc56 100644 --- a/src/connectors/google_drive_acl.py +++ b/src/connectors/google_drive_acl.py @@ -63,7 +63,7 @@ def _email_from_id_token(id_token: str | None) -> str | None: return None try: from config.settings import GOOGLE_OAUTH_CLIENT_ID - from utils.jwt_verification import verify_google_id_token, JWTVerificationError + from utils.jwt_verification import JWTVerificationError, verify_google_id_token if not GOOGLE_OAUTH_CLIENT_ID: logger.error("GOOGLE_OAUTH_CLIENT_ID not configured - cannot verify ID token") diff --git a/src/connectors/microsoft_graph_acl.py b/src/connectors/microsoft_graph_acl.py index 5626e5c50..a6acd8685 100644 --- a/src/connectors/microsoft_graph_acl.py +++ b/src/connectors/microsoft_graph_acl.py @@ -29,8 +29,8 @@ def tenant_id_from_access_token(access_token: str | None, fallback: str | None = try: from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID from utils.jwt_verification import ( - verify_microsoft_access_token, JWTVerificationError, + verify_microsoft_access_token, ) if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID: @@ -223,8 +223,8 @@ def _decode_microsoft_user_identifiers(access_token: str, tenant_id: str | None) try: from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID from utils.jwt_verification import ( - verify_microsoft_access_token, JWTVerificationError, + verify_microsoft_access_token, ) if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID: diff --git a/src/utils/jwt_verification.py b/src/utils/jwt_verification.py index 3d579a441..79b5b5f25 100644 --- a/src/utils/jwt_verification.py +++ b/src/utils/jwt_verification.py @@ -146,9 +146,7 @@ def verify_google_id_token(token: str, client_id: str) -> dict[str, Any]: JWTVerificationError: For other verification failures """ if not client_id: - raise JWTVerificationError( - "client_id is required for Google ID token verification" - ) + raise JWTVerificationError("client_id is required for Google ID token verification") try: # Fetch JWKS @@ -222,9 +220,7 @@ def verify_microsoft_access_token( JWTVerificationError: For other verification failures """ if not client_id: - raise JWTVerificationError( - "client_id is required for Microsoft access token verification" - ) + raise JWTVerificationError("client_id is required for Microsoft access token verification") try: # Extract tenant from token if not provided @@ -292,4 +288,5 @@ def clear_jwks_cache(): _jwks_cache.clear() logger.debug("JWKS cache cleared") + # Made with Bob From 10abb2e628aed95cff67e4719ab2f6a65ae3d257 Mon Sep 17 00:00:00 2001 From: rodageve Date: Tue, 16 Jun 2026 18:37:14 -0400 Subject: [PATCH 3/3] Fix linting errors --- src/connectors/google_drive_acl.py | 1 - src/connectors/microsoft_graph_acl.py | 1 - src/utils/jwt_verification.py | 24 ++++++++++++------------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/connectors/google_drive_acl.py b/src/connectors/google_drive_acl.py index 6a567bc56..ef8226530 100644 --- a/src/connectors/google_drive_acl.py +++ b/src/connectors/google_drive_acl.py @@ -5,7 +5,6 @@ import asyncio from typing import Any -import jwt from googleapiclient.discovery import build from googleapiclient.errors import HttpError diff --git a/src/connectors/microsoft_graph_acl.py b/src/connectors/microsoft_graph_acl.py index a6acd8685..0e745980c 100644 --- a/src/connectors/microsoft_graph_acl.py +++ b/src/connectors/microsoft_graph_acl.py @@ -6,7 +6,6 @@ from typing import Any import httpx -import jwt from utils.group_acl import ( acl_principal_label, diff --git a/src/utils/jwt_verification.py b/src/utils/jwt_verification.py index 79b5b5f25..e9caea0f8 100644 --- a/src/utils/jwt_verification.py +++ b/src/utils/jwt_verification.py @@ -83,7 +83,7 @@ def _fetch_jwks(url: str) -> dict[str, Any]: return jwks except Exception as e: logger.error(f"Failed to fetch JWKS from {url}", error=str(e)) - raise JWTVerificationError(f"Failed to fetch JWKS: {e}") + raise JWTVerificationError(f"Failed to fetch JWKS: {e}") from e def _get_signing_key(token: str, jwks: dict[str, Any]) -> Any: @@ -118,7 +118,7 @@ def _get_signing_key(token: str, jwks: dict[str, Any]) -> Any: raise JWTVerificationError(f"Signing key with kid '{kid}' not found in JWKS") except jwt.DecodeError as e: - raise JWTVerificationError(f"Failed to decode token header: {e}") + raise JWTVerificationError(f"Failed to decode token header: {e}") from e def verify_google_id_token(token: str, client_id: str) -> dict[str, Any]: @@ -175,21 +175,21 @@ def verify_google_id_token(token: str, client_id: str) -> dict[str, Any]: except jwt.InvalidSignatureError as e: logger.warning("Google ID token has invalid signature", error=str(e)) - raise InvalidSignatureError(f"Invalid signature: {e}") + raise InvalidSignatureError(f"Invalid signature: {e}") from e except jwt.ExpiredSignatureError as e: logger.warning("Google ID token has expired", error=str(e)) - raise ExpiredTokenError(f"Token expired: {e}") + raise ExpiredTokenError(f"Token expired: {e}") from e except jwt.InvalidAudienceError as e: logger.warning("Google ID token has invalid audience", error=str(e)) - raise InvalidAudienceError(f"Invalid audience: {e}") + raise InvalidAudienceError(f"Invalid audience: {e}") from e except jwt.InvalidIssuerError as e: logger.warning("Google ID token has invalid issuer", error=str(e)) - raise InvalidIssuerError(f"Invalid issuer: {e}") + raise InvalidIssuerError(f"Invalid issuer: {e}") from e except JWTVerificationError: raise except Exception as e: logger.error("Google ID token verification failed", error=str(e)) - raise JWTVerificationError(f"Verification failed: {e}") + raise JWTVerificationError(f"Verification failed: {e}") from e def verify_microsoft_access_token( @@ -266,21 +266,21 @@ def verify_microsoft_access_token( except jwt.InvalidSignatureError as e: logger.warning("Microsoft access token has invalid signature", error=str(e)) - raise InvalidSignatureError(f"Invalid signature: {e}") + raise InvalidSignatureError(f"Invalid signature: {e}") from e except jwt.ExpiredSignatureError as e: logger.warning("Microsoft access token has expired", error=str(e)) - raise ExpiredTokenError(f"Token expired: {e}") + raise ExpiredTokenError(f"Token expired: {e}") from e except jwt.InvalidAudienceError as e: logger.warning("Microsoft access token has invalid audience", error=str(e)) - raise InvalidAudienceError(f"Invalid audience: {e}") + raise InvalidAudienceError(f"Invalid audience: {e}") from e except (jwt.InvalidIssuerError, InvalidIssuerError) as e: logger.warning("Microsoft access token has invalid issuer", error=str(e)) - raise InvalidIssuerError(f"Invalid issuer: {e}") + raise InvalidIssuerError(f"Invalid issuer: {e}") from e except JWTVerificationError: raise except Exception as e: logger.error("Microsoft access token verification failed", error=str(e)) - raise JWTVerificationError(f"Verification failed: {e}") + raise JWTVerificationError(f"Verification failed: {e}") from e def clear_jwks_cache():