diff --git a/src/config/settings.py b/src/config/settings.py
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -163,6 +163,8 @@ def get_ingest_callback_url() -> str:
 # os.environ directly.
 JWT_SIGNING_KEY = os.getenv("JWT_SIGNING_KEY")
 GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
 GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
+MICROSOFT_GRAPH_OAUTH_CLIENT_ID = os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
+MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET = os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET")
 
 # IBM AMS authentication (Watsonx Data embedded mode)
diff --git a/src/connectors/google_drive_acl.py b/src/connectors/google_drive_acl.py
--- a/src/connectors/google_drive_acl.py
+++ b/src/connectors/google_drive_acl.py
@@ -5,7 +5,6 @@
 import asyncio
 from typing import Any
 
-import jwt
 from googleapiclient.discovery import build
 from googleapiclient.errors import HttpError
 
@@ -58,18 +57,29 @@ def google_drive_user_principal(user_email: str | None) -> str | None:
 
 
 def _email_from_id_token(id_token: str | None) -> str | None:
+    """Extract email from Google ID token with signature verification."""
     if not id_token:
         return None
+
+    # Imported before the try block: the except clause below needs
+    # JWTVerificationError to be bound no matter where the try body fails.
+    from config.settings import GOOGLE_OAUTH_CLIENT_ID
+    from utils.jwt_verification import JWTVerificationError, verify_google_id_token
+
+    if not GOOGLE_OAUTH_CLIENT_ID:
+        logger.error("GOOGLE_OAUTH_CLIENT_ID not configured - cannot verify ID token")
+        return None
+
     try:
-        claims = jwt.decode(
-            id_token,
-            options={"verify_signature": False, "verify_aud": False},
-        )
+        claims = verify_google_id_token(id_token, GOOGLE_OAUTH_CLIENT_ID)
         email = claims.get("email")
         if email:
             return str(email)
+    except JWTVerificationError as e:
+        logger.warning("Google ID token verification failed", error=str(e))
     except Exception as e:
-        logger.debug("Could not decode Google id_token email", error=str(e))
+        logger.error("Unexpected error verifying Google ID token", error=str(e))
+
     return None
 
 
diff --git a/src/connectors/microsoft_graph_acl.py b/src/connectors/microsoft_graph_acl.py
--- a/src/connectors/microsoft_graph_acl.py
+++ b/src/connectors/microsoft_graph_acl.py
@@ -6,7 +6,6 @@
 from typing import Any
 
 import httpx
-import jwt
 
 from utils.group_acl import (
     acl_principal_label,
@@ -23,19 +22,34 @@
 
 
 def tenant_id_from_access_token(access_token: str | None, fallback: str | None = None) -> str:
-    """Read the tenant id from a Microsoft access token without validating it."""
+    """Extract tenant ID from Microsoft access token with signature verification."""
     if access_token:
         raw_token = access_token.removeprefix("Bearer ").strip()
+
+        # Imported before the try block so the except clause can always
+        # resolve JWTVerificationError.
+        from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID
+        from utils.jwt_verification import (
+            JWTVerificationError,
+            verify_microsoft_access_token,
+        )
+
+        if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID:
+            logger.error(
+                "MICROSOFT_GRAPH_OAUTH_CLIENT_ID not configured - cannot verify access token"
+            )
+            return fallback or "common"
+
         try:
-            claims = jwt.decode(
-                raw_token,
-                options={"verify_signature": False, "verify_aud": False},
-            )
+            claims = verify_microsoft_access_token(raw_token, MICROSOFT_GRAPH_OAUTH_CLIENT_ID)
             token_tenant = claims.get("tid")
             if token_tenant:
                 return token_tenant
+        except JWTVerificationError as e:
+            logger.warning("Microsoft access token verification failed", error=str(e))
         except Exception as e:
-            logger.debug("Could not decode Microsoft access token tenant", error=str(e))
+            logger.error("Unexpected error verifying Microsoft access token", error=str(e))
+
     return fallback or "common"
 
 
@@ -204,14 +218,32 @@ async def get_current_user_microsoft_group_roles(
 
 
 def _decode_microsoft_user_identifiers(access_token: str, tenant_id: str | None) -> list[str]:
+    """Extract user identifiers from Microsoft access token with signature verification."""
     raw_token = access_token.removeprefix("Bearer ").strip()
+
+    # Imported before the try block so the except clause can always
+    # resolve JWTVerificationError.
+    from config.settings import MICROSOFT_GRAPH_OAUTH_CLIENT_ID
+    from utils.jwt_verification import (
+        JWTVerificationError,
+        verify_microsoft_access_token,
+    )
+
+    if not MICROSOFT_GRAPH_OAUTH_CLIENT_ID:
+        logger.error(
+            "MICROSOFT_GRAPH_OAUTH_CLIENT_ID not configured - cannot verify access token"
+        )
+        return []
+
     try:
-        claims = jwt.decode(
-            raw_token,
-            options={"verify_signature": False, "verify_aud": False},
+        claims = verify_microsoft_access_token(
+            raw_token, MICROSOFT_GRAPH_OAUTH_CLIENT_ID, tenant_id=tenant_id
         )
+    except JWTVerificationError as e:
+        logger.warning("Microsoft access token verification failed", error=str(e))
+        return []
     except Exception as e:
-        logger.debug("Could not decode Microsoft access token user identifiers", error=str(e))
+        logger.error("Unexpected error verifying Microsoft access token", error=str(e))
         return []
 
     identifiers: list[str] = []
diff --git a/src/utils/jwt_verification.py b/src/utils/jwt_verification.py
new file mode 100644
--- /dev/null
+++ b/src/utils/jwt_verification.py
@@ -0,0 +1,301 @@
+"""JWT signature verification utilities for OAuth tokens."""
+
+from __future__ import annotations
+
+import re
+import threading
+from typing import Any
+
+import httpx
+import jwt
+from cachetools import TTLCache
+from jwt.algorithms import RSAAlgorithm
+
+from utils.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+# JWKS cache: 1 hour TTL, max 10 entries. TTLCache is not thread-safe, so all
+# access goes through _jwks_cache_lock.
+_jwks_cache: TTLCache = TTLCache(maxsize=10, ttl=3600)
+_jwks_cache_lock = threading.Lock()
+
+# JWKS endpoints
+GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
+MICROSOFT_JWKS_URL_TEMPLATE = "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys"
+
+# Issuer values accepted for Google ID tokens
+GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
+
+# The tenant is interpolated into the JWKS URL and may come from an unverified
+# token, so only a single host-like path segment is accepted.
+_TENANT_ID_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9.-]{0,253}")
+
+
+class JWTVerificationError(Exception):
+    """Base exception for JWT verification errors."""
+
+
+class InvalidSignatureError(JWTVerificationError):
+    """JWT signature is invalid."""
+
+
+class ExpiredTokenError(JWTVerificationError):
+    """JWT token has expired."""
+
+
+class InvalidAudienceError(JWTVerificationError):
+    """JWT audience claim is invalid."""
+
+
+class InvalidIssuerError(JWTVerificationError):
+    """JWT issuer claim is invalid."""
+
+
+def _fetch_jwks(url: str) -> dict[str, Any]:
+    """
+    Fetch JWKS from URL with caching.
+
+    On a cache miss this performs a blocking HTTP request.
+
+    Args:
+        url: JWKS endpoint URL
+
+    Returns:
+        JWKS dictionary
+
+    Raises:
+        JWTVerificationError: If JWKS fetch fails
+    """
+    # One locked lookup; a separate membership test could race with expiry.
+    with _jwks_cache_lock:
+        cached = _jwks_cache.get(url)
+    if cached is not None:
+        logger.debug(f"JWKS cache hit for {url}")
+        return cached
+
+    try:
+        logger.debug(f"Fetching JWKS from {url}")
+        response = httpx.get(url, timeout=5.0)
+        response.raise_for_status()
+        jwks = response.json()
+        if not isinstance(jwks, dict):
+            raise ValueError("JWKS response is not a JSON object")
+    except Exception as e:
+        logger.error(f"Failed to fetch JWKS from {url}", error=str(e))
+        raise JWTVerificationError(f"Failed to fetch JWKS: {e}") from e
+
+    with _jwks_cache_lock:
+        _jwks_cache[url] = jwks
+    logger.debug(f"JWKS cached for {url}")
+    return jwks
+
+
+def _get_signing_key(token: str, jwks: dict[str, Any]) -> Any:
+    """
+    Extract signing key from JWKS based on token header.
+
+    Args:
+        token: JWT token
+        jwks: JWKS dictionary
+
+    Returns:
+        RSA public key object
+
+    Raises:
+        JWTVerificationError: If the header is malformed or no usable key is found
+    """
+    try:
+        # The header is unverified here; it is only used to select the key
+        kid = jwt.get_unverified_header(token).get("kid")
+    except jwt.DecodeError as e:
+        raise JWTVerificationError(f"Failed to decode token header: {e}") from e
+
+    if not kid:
+        raise JWTVerificationError("Token header missing 'kid' field")
+
+    for key in jwks.get("keys", []):
+        if key.get("kid") == kid:
+            try:
+                return RSAAlgorithm.from_jwk(key)
+            except jwt.InvalidKeyError as e:
+                raise JWTVerificationError(f"Signing key with kid '{kid}' is unusable: {e}") from e
+
+    raise JWTVerificationError(f"Signing key with kid '{kid}' not found in JWKS")
+
+
+def _verify_rs256_token(
+    token: str,
+    *,
+    jwks_url: str,
+    audience: str,
+    issuers: tuple[str, ...],
+    label: str,
+) -> dict[str, Any]:
+    """
+    Verify an RS256-signed JWT against a JWKS endpoint.
+
+    Args:
+        token: Encoded JWT
+        jwks_url: JWKS endpoint that publishes the signing keys
+        audience: Expected audience
+        issuers: Exact issuer values that are accepted
+        label: Description of the token kind, used in log messages
+
+    Returns:
+        Verified token claims
+
+    Raises:
+        InvalidSignatureError: If signature is invalid
+        ExpiredTokenError: If token has expired
+        InvalidAudienceError: If audience doesn't match
+        InvalidIssuerError: If issuer is invalid
+        JWTVerificationError: For other verification failures
+    """
+    try:
+        signing_key = _get_signing_key(token, _fetch_jwks(jwks_url))
+        claims = jwt.decode(
+            token,
+            signing_key,
+            algorithms=["RS256"],
+            audience=audience,
+            options={
+                "verify_signature": True,
+                "verify_exp": True,
+                "verify_aud": True,
+                # verify_exp only checks "exp" when present, so require it
+                "require": ["exp"],
+            },
+        )
+        # Checked here rather than via jwt.decode(issuer=...) so accepting
+        # several issuers does not rely on list support in the installed PyJWT.
+        issuer = claims.get("iss")
+        if issuer not in issuers:
+            raise InvalidIssuerError(f"Unexpected issuer: {issuer}")
+    except jwt.InvalidSignatureError as e:
+        logger.warning(f"{label} has invalid signature", error=str(e))
+        raise InvalidSignatureError(f"Invalid signature: {e}") from e
+    except jwt.ExpiredSignatureError as e:
+        logger.warning(f"{label} has expired", error=str(e))
+        raise ExpiredTokenError(f"Token expired: {e}") from e
+    except jwt.InvalidAudienceError as e:
+        logger.warning(f"{label} has invalid audience", error=str(e))
+        raise InvalidAudienceError(f"Invalid audience: {e}") from e
+    except InvalidIssuerError as e:
+        logger.warning(f"{label} has invalid issuer", error=str(e))
+        raise InvalidIssuerError(f"Invalid issuer: {e}") from e
+    except JWTVerificationError:
+        raise
+    except Exception as e:
+        logger.error(f"{label} verification failed", error=str(e))
+        raise JWTVerificationError(f"Verification failed: {e}") from e
+
+    logger.debug(f"{label} verified successfully")
+    return claims
+
+
+def verify_google_id_token(token: str, client_id: str) -> dict[str, Any]:
+    """
+    Verify Google ID token with FULL validation.
+
+    Performs:
+    - Signature verification using Google's JWKS
+    - Issuer validation (accounts.google.com)
+    - Expiration validation
+    - Audience validation (requires client_id)
+
+    Args:
+        token: Google ID token (JWT)
+        client_id: Expected audience (Google OAuth client ID)
+
+    Returns:
+        Verified token claims
+
+    Raises:
+        InvalidSignatureError: If signature is invalid
+        ExpiredTokenError: If token has expired
+        InvalidAudienceError: If audience doesn't match
+        InvalidIssuerError: If issuer is invalid
+        JWTVerificationError: For other verification failures
+    """
+    if not client_id:
+        raise JWTVerificationError("client_id is required for Google ID token verification")
+
+    return _verify_rs256_token(
+        token,
+        jwks_url=GOOGLE_JWKS_URL,
+        audience=client_id,
+        issuers=GOOGLE_ISSUERS,
+        label="Google ID token",
+    )
+
+
+def _unverified_microsoft_tenant(token: str) -> Any:
+    """Return the unverified ``tid`` claim of *token*, or "common" when absent."""
+    try:
+        unverified_claims = jwt.decode(token, options={"verify_signature": False})
+    except Exception as e:
+        logger.error("Microsoft access token verification failed", error=str(e))
+        raise JWTVerificationError(f"Verification failed: {e}") from e
+    return unverified_claims.get("tid", "common")
+
+
+def verify_microsoft_access_token(
+    token: str, client_id: str, tenant_id: str | None = None
+) -> dict[str, Any]:
+    """
+    Verify Microsoft access token with FULL validation.
+
+    Performs:
+    - Signature verification using Microsoft's JWKS
+    - Issuer validation (exact match against the tenant's v1/v2 issuer)
+    - Expiration validation
+    - Audience validation (requires client_id)
+
+    NOTE(review): the audience must equal client_id, so tokens issued for
+    another resource are rejected - confirm callers pass tokens issued for
+    this application.
+
+    Args:
+        token: Microsoft access token (JWT)
+        client_id: Expected audience (Microsoft Graph OAuth client ID)
+        tenant_id: Optional tenant ID for JWKS endpoint (extracted from token if not provided)
+
+    Returns:
+        Verified token claims
+
+    Raises:
+        InvalidSignatureError: If signature is invalid
+        ExpiredTokenError: If token has expired
+        InvalidAudienceError: If audience doesn't match
+        InvalidIssuerError: If issuer is invalid
+        JWTVerificationError: For other verification failures
+    """
+    if not client_id:
+        raise JWTVerificationError("client_id is required for Microsoft access token verification")
+
+    if not tenant_id:
+        tenant_id = _unverified_microsoft_tenant(token)
+        logger.debug(f"Extracted tenant_id from token: {tenant_id}")
+
+    if not isinstance(tenant_id, str) or not _TENANT_ID_RE.fullmatch(tenant_id):
+        logger.warning("Microsoft access token has invalid tenant id")
+        raise JWTVerificationError("Invalid tenant id")
+
+    return _verify_rs256_token(
+        token,
+        jwks_url=MICROSOFT_JWKS_URL_TEMPLATE.format(tenant=tenant_id),
+        audience=client_id,
+        issuers=(
+            f"https://login.microsoftonline.com/{tenant_id}/v2.0",
+            f"https://sts.windows.net/{tenant_id}/",
+        ),
+        label="Microsoft access token",
+    )
+
+
+def clear_jwks_cache() -> None:
+    """Clear the JWKS cache. Useful for testing."""
+    with _jwks_cache_lock:
+        _jwks_cache.clear()
+    logger.debug("JWKS cache cleared")