diff --git a/pyproject.toml b/pyproject.toml index f9badac7..e26c280d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ data-relational = [ testing = [ "jsonpath-ng>=1.8.0", ] +argon2 = [ + "argon2-cffi>=23.1.0", # Argon2PasswordEncoder (OWASP-preferred password hashing) +] testcontainers = [ "testcontainers>=4.0.0", "pika>=1.3.0", # testcontainers' RabbitMqContainer imports pika for its readiness probe @@ -199,6 +202,9 @@ server = "pyfly.server.auto_configuration:ServerAutoConfiguration" event-loop = "pyfly.server.auto_configuration:EventLoopAutoConfiguration" security-jwt = "pyfly.security.auto_configuration:JwtAutoConfiguration" security-password = "pyfly.security.auto_configuration:PasswordEncoderAutoConfiguration" +security-http-basic = "pyfly.security.auto_configuration:HttpBasicAutoConfiguration" +security-form-login = "pyfly.security.auto_configuration:FormLoginAutoConfiguration" +security-logout = "pyfly.security.auto_configuration:LogoutAutoConfiguration" oauth2-resource-server = "pyfly.security.auto_configuration:OAuth2ResourceServerAutoConfiguration" oauth2-authorization-server = "pyfly.security.auto_configuration:OAuth2AuthorizationServerAutoConfiguration" oauth2-client = "pyfly.security.auto_configuration:OAuth2ClientAutoConfiguration" diff --git a/src/pyfly/idp/adapters/__init__.py b/src/pyfly/idp/adapters/__init__.py index 264175c7..7a07bf26 100644 --- a/src/pyfly/idp/adapters/__init__.py +++ b/src/pyfly/idp/adapters/__init__.py @@ -1,3 +1,28 @@ # Copyright 2026 Firefly Software Foundation. # Licensed under the Apache License, Version 2.0. """Concrete IDP adapters.""" + +from __future__ import annotations + +from pyfly.kernel.exceptions import SecurityException + + +def _require_password_grant_optin(allowed: bool, provider: str) -> None: + """Refuse the Resource Owner Password Credentials (ROPC) grant unless opted in. + + The ``grant_type=password`` flow (forwarding raw user credentials to an external + IdP) is removed by OAuth 2.1 and discouraged by RFC 9700 §2.4 — it cannot carry + MFA/step-up, defeats federation, and trains users to enter credentials into the + client. It is disabled by default; enable per-adapter with + ``allow_password_grant=True`` (config: ``pyfly.idp.allow-password-grant=true``) + only for a legacy integration with no migration path. Prefer the + authorization_code + PKCE login flow instead. + """ + if not allowed: + raise SecurityException( + f"The '{provider}' resource-owner-password (ROPC) login flow is disabled. " + "It is removed by OAuth 2.1 / discouraged by RFC 9700 §2.4. Use the " + "authorization_code + PKCE flow, or, only for a legacy integration, set " + "'pyfly.idp.allow-password-grant=true' (or allow_password_grant=True).", + code="ROPC_DISABLED", + ) diff --git a/src/pyfly/idp/adapters/aws_cognito.py b/src/pyfly/idp/adapters/aws_cognito.py index 87ba5ad5..25a078fa 100644 --- a/src/pyfly/idp/adapters/aws_cognito.py +++ b/src/pyfly/idp/adapters/aws_cognito.py @@ -8,6 +8,7 @@ import logging from typing import Any +from pyfly.idp.adapters import _require_password_grant_optin from pyfly.idp.models import ( AuthResult, IdpRole, @@ -40,12 +41,14 @@ def __init__( region: str, client_secret: str | None = None, client: Any | None = None, + allow_password_grant: bool = False, ) -> None: self._user_pool_id = user_pool_id self._client_id = client_id self._region = region self._client_secret = client_secret self._client = client + self._allow_password_grant = allow_password_grant def _secret_hash(self, username: str) -> str | None: """Cognito SECRET_HASH = Base64(HMAC-SHA256(secret, username + client_id)). @@ -148,6 +151,7 @@ async def list_users(self, *, limit: int = 100) -> list[IdpUser]: return [_from_cognito(u) for u in data.get("Users", [])] async def login(self, request: LoginRequest) -> AuthResult: + _require_password_grant_optin(self._allow_password_grant, "aws-cognito") client = self._ensure_client() auth_params = {"USERNAME": request.username, "PASSWORD": request.password} secret_hash = self._secret_hash(request.username) diff --git a/src/pyfly/idp/adapters/azure_ad.py b/src/pyfly/idp/adapters/azure_ad.py index 81bf5a77..ea9ef5b4 100644 --- a/src/pyfly/idp/adapters/azure_ad.py +++ b/src/pyfly/idp/adapters/azure_ad.py @@ -7,6 +7,7 @@ import logging from typing import Any +from pyfly.idp.adapters import _require_password_grant_optin from pyfly.idp.models import ( AuthResult, IdpRole, @@ -39,11 +40,13 @@ def __init__( client_id: str, client_secret: str, scope: str = "https://graph.microsoft.com/.default", + allow_password_grant: bool = False, ) -> None: self._tenant_id = tenant_id self._client_id = client_id self._client_secret = client_secret self._scope = scope + self._allow_password_grant = allow_password_grant self._app_token: str | None = None @property @@ -139,6 +142,7 @@ async def list_users(self, *, limit: int = 100) -> list[IdpUser]: return [_from_aad(u) for u in resp.json().get("value", [])] async def login(self, request: LoginRequest) -> AuthResult: + _require_password_grant_optin(self._allow_password_grant, "azure-ad") async with await self._client() as client: resp = await client.post( self._token_url, diff --git a/src/pyfly/idp/adapters/keycloak.py b/src/pyfly/idp/adapters/keycloak.py index 2325baea..288acdf5 100644 --- a/src/pyfly/idp/adapters/keycloak.py +++ b/src/pyfly/idp/adapters/keycloak.py @@ -15,6 +15,7 @@ import logging from typing import Any +from pyfly.idp.adapters import _require_password_grant_optin from pyfly.idp.models import ( AuthResult, IdpRole, @@ -49,12 +50,14 @@ def __init__( client_id: str, client_secret: str, verify_ssl: bool = True, + allow_password_grant: bool = False, ) -> None: self._base_url = base_url.rstrip("/") self._realm = realm self._client_id = client_id self._client_secret = client_secret self._verify = verify_ssl + self._allow_password_grant = allow_password_grant self._admin_token: str | None = None self._admin_token_expiry: float = 0.0 # monotonic deadline @@ -171,6 +174,7 @@ async def list_users(self, *, limit: int = 100) -> list[IdpUser]: return [_from_kc(u) for u in resp.json()] async def login(self, request: LoginRequest) -> AuthResult: + _require_password_grant_optin(self._allow_password_grant, "keycloak") async with await self._client() as client: resp = await client.post( self._token_url, diff --git a/src/pyfly/idp/auto_configuration.py b/src/pyfly/idp/auto_configuration.py index 642c1c76..670b9d32 100644 --- a/src/pyfly/idp/auto_configuration.py +++ b/src/pyfly/idp/auto_configuration.py @@ -27,6 +27,14 @@ class IdpAutoConfiguration: @bean def idp_adapter(self, config: Config) -> IdpAdapter: provider = str(config.get("pyfly.idp.provider", "internal-db")).lower() + # ROPC (grant_type=password) against an external IdP is removed by OAuth 2.1 + # and discouraged by RFC 9700 §2.4; it is off unless explicitly opted in. + allow_ropc = str(config.get("pyfly.idp.allow-password-grant", False)).strip().lower() in ( + "1", + "true", + "yes", + "on", + ) if provider == "keycloak": from pyfly.idp.adapters.keycloak import KeycloakIdpAdapter @@ -36,6 +44,7 @@ def idp_adapter(self, config: Config) -> IdpAdapter: realm=str(config.get("pyfly.idp.keycloak.realm", "")), client_id=str(config.get("pyfly.idp.keycloak.client-id", "")), client_secret=str(config.get("pyfly.idp.keycloak.client-secret", "")), + allow_password_grant=allow_ropc, ) if provider in ("cognito", "aws-cognito"): from pyfly.idp.adapters.aws_cognito import AwsCognitoIdpAdapter @@ -45,6 +54,7 @@ def idp_adapter(self, config: Config) -> IdpAdapter: client_id=str(config.get("pyfly.idp.cognito.client-id", "")), region=str(config.get("pyfly.idp.cognito.region", "")), client_secret=str(config.get("pyfly.idp.cognito.client-secret", "")) or None, + allow_password_grant=allow_ropc, ) if provider in ("azure-ad", "azuread", "entra"): from pyfly.idp.adapters.azure_ad import AzureAdIdpAdapter @@ -53,6 +63,7 @@ def idp_adapter(self, config: Config) -> IdpAdapter: tenant_id=str(config.get("pyfly.idp.azure.tenant-id", "")), client_id=str(config.get("pyfly.idp.azure.client-id", "")), client_secret=str(config.get("pyfly.idp.azure.client-secret", "")), + allow_password_grant=allow_ropc, ) from pyfly.idp.adapters.internal_db import InternalDbIdpAdapter diff --git a/src/pyfly/security/__init__.py b/src/pyfly/security/__init__.py index 1597f4ed..7c9b8af1 100644 --- a/src/pyfly/security/__init__.py +++ b/src/pyfly/security/__init__.py @@ -24,22 +24,33 @@ from pyfly.security.context import SecurityContext from pyfly.security.decorators import secure -from pyfly.security.expression import get_role_hierarchy, set_role_hierarchy +from pyfly.security.expression import ( + get_permission_evaluator, + get_role_hierarchy, + set_permission_evaluator, + set_role_hierarchy, +) from pyfly.security.http_security import AccessRule, AccessRuleType, HttpSecurity, SecurityRule -from pyfly.security.method_security import post_authorize, pre_authorize +from pyfly.security.method_security import post_authorize, post_filter, pre_authorize, pre_filter +from pyfly.security.permission import PermissionEvaluator from pyfly.security.role_hierarchy import RoleHierarchy __all__ = [ "AccessRule", "AccessRuleType", "HttpSecurity", + "PermissionEvaluator", "RoleHierarchy", "SecurityContext", "SecurityRule", + "get_permission_evaluator", "get_role_hierarchy", "post_authorize", + "post_filter", "pre_authorize", + "pre_filter", "secure", + "set_permission_evaluator", "set_role_hierarchy", ] @@ -61,8 +72,62 @@ pass try: - from pyfly.security.password import BcryptPasswordEncoder, PasswordEncoder + from pyfly.security.password import ( + Argon2PasswordEncoder, + BcryptPasswordEncoder, + DelegatingPasswordEncoder, + PasswordEncoder, + Pbkdf2PasswordEncoder, + ScryptPasswordEncoder, + create_delegating_password_encoder, + ) - __all__ += ["BcryptPasswordEncoder", "PasswordEncoder"] + __all__ += [ + "Argon2PasswordEncoder", + "BcryptPasswordEncoder", + "DelegatingPasswordEncoder", + "PasswordEncoder", + "Pbkdf2PasswordEncoder", + "ScryptPasswordEncoder", + "create_delegating_password_encoder", + ] +except ImportError: + pass + +try: + from pyfly.security.user_details import ( + InMemoryUserDetailsService, + UserDetails, + UserDetailsService, + ) + + __all__ += ["InMemoryUserDetailsService", "UserDetails", "UserDetailsService"] +except ImportError: + pass + +# AuthenticationProvider/DaoAuthenticationProvider transitively need a +# PasswordEncoder (bcrypt), so guard the import like the other optional pieces. +try: + from pyfly.security.authentication import ( + Authentication, + AuthenticationException, + AuthenticationProvider, + BadCredentialsException, + DaoAuthenticationProvider, + DisabledException, + ProviderManager, + ProviderNotFoundException, + ) + + __all__ += [ + "Authentication", + "AuthenticationException", + "AuthenticationProvider", + "BadCredentialsException", + "DaoAuthenticationProvider", + "DisabledException", + "ProviderManager", + "ProviderNotFoundException", + ] except ImportError: pass diff --git a/src/pyfly/security/authentication.py b/src/pyfly/security/authentication.py new file mode 100644 index 00000000..509c3e45 --- /dev/null +++ b/src/pyfly/security/authentication.py @@ -0,0 +1,158 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Authentication SPI — Spring's ``AuthenticationManager`` / ``AuthenticationProvider``. + +A :class:`ProviderManager` delegates an :class:`Authentication` request to the +first :class:`AuthenticationProvider` that ``supports`` it. The built-in +:class:`DaoAuthenticationProvider` checks a username/password against a +:class:`~pyfly.security.user_details.UserDetailsService` and a +:class:`~pyfly.security.password.PasswordEncoder`. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +from pyfly.kernel.exceptions import SecurityException +from pyfly.security.context import SecurityContext +from pyfly.security.password import PasswordEncoder +from pyfly.security.user_details import UserDetailsService + + +class AuthenticationException(SecurityException): + """Base class for authentication failures.""" + + +class BadCredentialsException(AuthenticationException): + """The supplied credentials were invalid (or the principal is unknown).""" + + def __init__(self, message: str = "Bad credentials") -> None: + super().__init__(message, code="BAD_CREDENTIALS") + + +class DisabledException(AuthenticationException): + """The account exists but is disabled.""" + + def __init__(self, message: str = "Account is disabled") -> None: + super().__init__(message, code="ACCOUNT_DISABLED") + + +class ProviderNotFoundException(AuthenticationException): + """No configured provider could handle the authentication request.""" + + def __init__(self, message: str = "No authentication provider for this request") -> None: + super().__init__(message, code="PROVIDER_NOT_FOUND") + + +@dataclass +class Authentication: + """An authentication request or result (cf. Spring's ``Authentication``). + + Before authentication: ``principal`` + ``credentials`` are the submitted + username/password. After: ``authenticated`` is True, ``authorities`` / + ``roles`` / ``permissions`` are populated and ``credentials`` is erased. + """ + + principal: str + credentials: str | None = None + authorities: list[str] = field(default_factory=list) + authenticated: bool = False + roles: list[str] = field(default_factory=list) + permissions: list[str] = field(default_factory=list) + details: dict[str, Any] = field(default_factory=dict) + + def to_security_context(self) -> SecurityContext: + """Build a :class:`SecurityContext` from this (authenticated) result.""" + return SecurityContext( + user_id=self.principal if self.authenticated else None, + roles=list(self.roles), + permissions=list(self.permissions), + ) + + +@runtime_checkable +class AuthenticationProvider(Protocol): + """Authenticates an :class:`Authentication` it ``supports``.""" + + def supports(self, authentication: Authentication) -> bool: ... + + async def authenticate(self, authentication: Authentication) -> Authentication: ... + + +class DaoAuthenticationProvider: + """Authenticates username/password against a UserDetailsService + PasswordEncoder.""" + + def __init__(self, user_details_service: UserDetailsService, password_encoder: PasswordEncoder) -> None: + self._users = user_details_service + self._encoder = password_encoder + # A throw-away hash so an unknown user still incurs a verify() — equalising + # timing so the endpoint can't be used to enumerate valid usernames. + self._dummy_hash = password_encoder.hash("pyfly-dummy-password") + + def supports(self, authentication: Authentication) -> bool: + return bool(authentication.principal) and authentication.credentials is not None + + async def authenticate(self, authentication: Authentication) -> Authentication: + user = await self._users.load_user_by_username(authentication.principal) + credentials = authentication.credentials or "" + if user is None: + self._encoder.verify(credentials, self._dummy_hash) # constant-time-ish + raise BadCredentialsException() + if not self._encoder.verify(credentials, user.password_hash): + raise BadCredentialsException() + if not user.enabled: + raise DisabledException() + return Authentication( + principal=user.username, + credentials=None, + authenticated=True, + roles=list(user.roles), + permissions=list(user.permissions), + authorities=[*user.roles, *user.permissions], + details=dict(authentication.details), + ) + + +class ProviderManager: + """An :class:`AuthenticationManager` that consults providers in order.""" + + def __init__(self, *providers: AuthenticationProvider) -> None: + self._providers: list[AuthenticationProvider] = list(providers) + + @classmethod + def of(cls, providers: Iterable[AuthenticationProvider]) -> ProviderManager: + return cls(*providers) + + async def authenticate(self, authentication: Authentication) -> Authentication: + last_error: AuthenticationException | None = None + supported = False + for provider in self._providers: + if not provider.supports(authentication): + continue + supported = True + try: + result = await provider.authenticate(authentication) + except AuthenticationException as exc: + last_error = exc + continue + if result.authenticated: + result.credentials = None # erase credentials on success + return result + if last_error is not None: + raise last_error + if not supported: + raise ProviderNotFoundException() + raise BadCredentialsException() diff --git a/src/pyfly/security/auto_configuration.py b/src/pyfly/security/auto_configuration.py index 07f3727f..9c920a64 100644 --- a/src/pyfly/security/auto_configuration.py +++ b/src/pyfly/security/auto_configuration.py @@ -22,9 +22,10 @@ JWTService = object # type: ignore[misc,assignment] try: - from pyfly.security.password import BcryptPasswordEncoder + from pyfly.security.password import BcryptPasswordEncoder, DelegatingPasswordEncoder except ImportError: BcryptPasswordEncoder = object # type: ignore[misc,assignment] + DelegatingPasswordEncoder = object # type: ignore[misc,assignment] try: from pyfly.security.oauth2.resource_server import ( @@ -86,6 +87,61 @@ conditional_on_property, ) from pyfly.core.config import Config +from pyfly.kernel.exceptions import SecurityException + +# The built-in placeholder secret shipped in defaults. Signing tokens with it +# would let anyone who knows the (public) framework default forge tokens, so the +# composition root refuses to start when it is left in place. +_PLACEHOLDER_SECRET = "change-me-in-production" +# Minimum HMAC key length for the HS family (RFC 7518 §3.2: a key of the same +# size as the hash output — 256 bits / 32 bytes — for HS256). +_MIN_HS_SECRET_BYTES = 32 + + +def _resolve_signing_secret(config: Config, key: str, algorithm: str) -> str: + """Read a token-signing secret from *key* and refuse insecure values. + + Raises: + SecurityException: if the secret is unset (the built-in placeholder) or, + for an HMAC (``HS*``) algorithm, shorter than 32 bytes. + """ + secret = str(config.get(key, _PLACEHOLDER_SECRET)) + if secret == _PLACEHOLDER_SECRET: + raise SecurityException( + f"Refusing to start: '{key}' is unset, so the built-in placeholder secret " + f"would be used to sign tokens. Set '{key}' to a strong, randomly-generated " + f'value (e.g. `python -c "import secrets; print(secrets.token_urlsafe(48))"`).', + code="INSECURE_SIGNING_SECRET", + ) + if algorithm.upper().startswith("HS") and len(secret.encode("utf-8")) < _MIN_HS_SECRET_BYTES: + raise SecurityException( + f"Refusing to start: '{key}' must be at least {_MIN_HS_SECRET_BYTES} bytes for " + f"{algorithm} (RFC 7518 §3.2); got {len(secret.encode('utf-8'))} bytes.", + code="WEAK_SIGNING_SECRET", + ) + return secret + + +def _audience(config: Config, key: str) -> str | list[str] | None: + """Read a comma-separated / list audience value (single value collapsed to a + string), or ``None`` when unset.""" + raw = config.get(key) + if raw is None: + return None + if isinstance(raw, (list, tuple)): + values = [str(a).strip() for a in raw if str(a).strip()] + else: + values = [a.strip() for a in str(raw).split(",") if a.strip()] + if not values: + return None + return values[0] if len(values) == 1 else values + + +def _as_bool(value: Any) -> bool: + """Coerce a config value (bool or string like ``"true"``/``"false"``) to bool.""" + if isinstance(value, bool): + return value + return str(value).strip().lower() in ("1", "true", "yes", "on") def _exclude_patterns(config: Config, key: str) -> Sequence[str]: @@ -106,8 +162,16 @@ class JwtAutoConfiguration: @bean def jwt_service(self, config: Config) -> JWTService: - secret = str(config.get("pyfly.security.jwt.secret", "change-me-in-production")) algorithm = str(config.get("pyfly.security.jwt.algorithm", "HS256")) + # The symmetric secret is only enforced when the symmetric JWT filter is + # actually serving requests. A resource-server-only app (the recommended + # setup) enables ``pyfly.security.enabled`` for the JWKS validator and never + # uses this signer, so it must not be forced to invent a symmetric secret. + filter_enabled = str(config.get("pyfly.security.jwt.filter.enabled", "false")).lower() == "true" + if filter_enabled: + secret = _resolve_signing_secret(config, "pyfly.security.jwt.secret", algorithm) + else: + secret = str(config.get("pyfly.security.jwt.secret", _PLACEHOLDER_SECRET)) return JWTService(secret=secret, algorithm=algorithm) @bean @@ -135,6 +199,146 @@ def password_encoder(self, config: Config) -> BcryptPasswordEncoder: rounds = int(config.get("pyfly.security.password.bcrypt-rounds", 12)) return BcryptPasswordEncoder(rounds=rounds) + @bean + @conditional_on_property("pyfly.security.password.delegating.enabled", having_value="true") + def delegating_password_encoder(self, config: Config) -> DelegatingPasswordEncoder: + # Opt-in Spring-style {id}-prefixed encoder (bcrypt default, recognises + # {pbkdf2}/{scrypt}/{argon2}) enabling on-login algorithm migration. + from pyfly.security.password import create_delegating_password_encoder + + rounds = int(config.get("pyfly.security.password.bcrypt-rounds", 12)) + return create_delegating_password_encoder(bcrypt_rounds=rounds) + + +# --------------------------------------------------------------------------- +# HTTP Basic authentication +# --------------------------------------------------------------------------- + + +def _csv_or_list(value: Any) -> list[str]: + """Parse a comma-separated string or a list into a trimmed string list.""" + if value is None: + return [] + if isinstance(value, (list, tuple)): + return [str(v).strip() for v in value if str(v).strip()] + return [s.strip() for s in str(value).split(",") if s.strip()] + + +def _users_from_config(config: Config, key: str) -> list[Any]: + """Build :class:`UserDetails` from a config map of pre-hashed users at *key*.""" + from pyfly.security.user_details import UserDetails + + raw = config.get(key, {}) + users: list[Any] = [] + if isinstance(raw, dict): + for username, props in raw.items(): + if not isinstance(props, dict): + continue + users.append( + UserDetails( + username=str(username), + password_hash=str(props.get("password-hash", "")), + roles=_csv_or_list(props.get("roles")), + permissions=_csv_or_list(props.get("permissions")), + enabled=_as_bool(props.get("enabled", True)), + ) + ) + return users + + +@auto_configuration +@conditional_on_property("pyfly.security.http-basic.enabled", having_value="true") +@conditional_on_class("starlette") +@conditional_on_class("bcrypt") +class HttpBasicAutoConfiguration: + """Auto-configures HTTP Basic authentication from config (opt-in). + + Users are declared (with **pre-hashed** bcrypt passwords) under + ``pyfly.security.http-basic.users``:: + + pyfly: + security: + http-basic: + enabled: true + realm: "PyFly" + error-mode: "401" # or "anonymous" + users: + alice: + password-hash: "$2b$12$..." # never plaintext + roles: "ADMIN,USER" + + Apps needing a dynamic user store register their own + :class:`HttpBasicAuthenticationFilter` (a ``WebFilter`` bean) instead. + """ + + @bean + def http_basic_filter(self, config: Config) -> WebFilter: + from pyfly.security.password import BcryptPasswordEncoder + from pyfly.security.user_details import InMemoryUserDetailsService + from pyfly.web.adapters.starlette.filters.http_basic_filter import HttpBasicAuthenticationFilter + + users = _users_from_config(config, "pyfly.security.http-basic.users") + rounds = int(config.get("pyfly.security.password.bcrypt-rounds", 12)) + return HttpBasicAuthenticationFilter( + InMemoryUserDetailsService(*users), + BcryptPasswordEncoder(rounds=rounds), + realm=str(config.get("pyfly.security.http-basic.realm", "Realm")), + error_mode=str(config.get("pyfly.security.http-basic.error-mode", "anonymous")), + ) + + +@auto_configuration +@conditional_on_property("pyfly.security.form-login.enabled", having_value="true") +@conditional_on_class("starlette") +@conditional_on_class("bcrypt") +class FormLoginAutoConfiguration: + """Auto-configures form login from config (opt-in). + + Declares users (pre-hashed) under ``pyfly.security.form-login.users`` and tunes + URLs/params under ``pyfly.security.form-login.*``. Apps with a dynamic user + store register their own ``FormLoginFilter`` ``WebFilter`` bean instead. + """ + + @bean + def form_login_filter(self, config: Config) -> WebFilter: + from pyfly.security.authentication import DaoAuthenticationProvider, ProviderManager + from pyfly.security.password import BcryptPasswordEncoder + from pyfly.security.user_details import InMemoryUserDetailsService + from pyfly.web.adapters.starlette.filters.form_login_filter import FormLoginFilter + + users = _users_from_config(config, "pyfly.security.form-login.users") + rounds = int(config.get("pyfly.security.password.bcrypt-rounds", 12)) + manager = ProviderManager( + DaoAuthenticationProvider(InMemoryUserDetailsService(*users), BcryptPasswordEncoder(rounds=rounds)) + ) + return FormLoginFilter( + manager, + login_url=str(config.get("pyfly.security.form-login.login-url", "/login")), + username_param=str(config.get("pyfly.security.form-login.username-param", "username")), + password_param=str(config.get("pyfly.security.form-login.password-param", "password")), + success_url=str(config.get("pyfly.security.form-login.success-url", "/")), + failure_url=str(config.get("pyfly.security.form-login.failure-url", "/login?error")), + use_redirect=_as_bool(config.get("pyfly.security.form-login.use-redirect", True)), + ) + + +@auto_configuration +@conditional_on_property("pyfly.security.logout.enabled", having_value="true") +@conditional_on_class("starlette") +class LogoutAutoConfiguration: + """Auto-configures a generic logout filter (opt-in) from ``pyfly.security.logout.*``.""" + + @bean + def logout_filter(self, config: Config) -> WebFilter: + from pyfly.web.adapters.starlette.filters.logout_filter import LogoutFilter + + return LogoutFilter( + logout_url=str(config.get("pyfly.security.logout.logout-url", "/logout")), + logout_success_url=str(config.get("pyfly.security.logout.success-url", "/login?logout")), + delete_cookies=_csv_or_list(config.get("pyfly.security.logout.delete-cookies")), + use_redirect=_as_bool(config.get("pyfly.security.logout.use-redirect", True)), + ) + # --------------------------------------------------------------------------- # OAuth2 Resource Server @@ -200,6 +404,8 @@ def oauth2_resource_server_filter(self, token_validator: JWKSTokenValidator, con token_validator=token_validator, exclude_patterns=props.exclude_pattern_list(), error_mode=props.authenticate_error_mode, + enforce_sender_constraints=props.enforce_sender_constraints, + mtls_cert_header=props.mtls_cert_header, ) @@ -227,7 +433,7 @@ def authorization_server( client_registration_repository: InMemoryClientRegistrationRepository, container: Container, ) -> AuthorizationServer: - secret = str(config.get("pyfly.security.oauth2.authorization-server.secret", "change-me-in-production")) + secret = _resolve_signing_secret(config, "pyfly.security.oauth2.authorization-server.secret", "HS256") issuer = config.get("pyfly.security.oauth2.authorization-server.issuer") access_ttl = int(config.get("pyfly.security.oauth2.authorization-server.access-token-ttl", 3600)) refresh_ttl = int(config.get("pyfly.security.oauth2.authorization-server.refresh-token-ttl", 86400)) @@ -239,6 +445,7 @@ def authorization_server( access_token_ttl=access_ttl, refresh_token_ttl=refresh_ttl, issuer=str(issuer) if issuer is not None else None, + audience=_audience(config, "pyfly.security.oauth2.authorization-server.audience"), ) def _build_token_store(self, config: Config, container: Container, refresh_ttl: int) -> Any: @@ -335,6 +542,12 @@ def client_registration_repository(self, config: Config) -> InMemoryClientRegist jwks_uri=str(props.get("jwks-uri", "")), issuer_uri=str(props.get("issuer-uri", "")), provider_name=str(props.get("provider-name", "")), + # PKCE on by default (RFC 9700 / OAuth 2.1); opt out per + # registration with ``use-pkce: false``. + use_pkce=_as_bool(props.get("use-pkce", True)), + # RFC 9207 iss enforcement (opt-in; iss is validated when + # present regardless). + require_iss=_as_bool(props.get("require-iss", False)), ) ) diff --git a/src/pyfly/security/expression.py b/src/pyfly/security/expression.py index 02a71ab1..4cc93468 100644 --- a/src/pyfly/security/expression.py +++ b/src/pyfly/security/expression.py @@ -36,6 +36,7 @@ from pyfly.kernel.exceptions import SecurityException from pyfly.security.context import SecurityContext +from pyfly.security.permission import PermissionEvaluator from pyfly.security.role_hierarchy import RoleHierarchy _PARAM_RE = re.compile(r"#(\w+)") @@ -45,6 +46,10 @@ # RoleHierarchy bean). Configure once at startup via set_role_hierarchy(). _active_hierarchy: RoleHierarchy | None = None +# Process-wide PermissionEvaluator backing hasPermission(target, perm). When unset, +# hasPermission falls back to a flat permission check on the SecurityContext. +_active_permission_evaluator: PermissionEvaluator | None = None + def set_role_hierarchy(hierarchy: RoleHierarchy | None) -> None: """Install the role hierarchy used by method-security role checks (``None`` disables).""" @@ -57,6 +62,39 @@ def get_role_hierarchy() -> RoleHierarchy | None: return _active_hierarchy +def set_permission_evaluator(evaluator: PermissionEvaluator | None) -> None: + """Install the PermissionEvaluator used by ``hasPermission`` (``None`` disables).""" + global _active_permission_evaluator + _active_permission_evaluator = evaluator + + +def get_permission_evaluator() -> PermissionEvaluator | None: + """Return the currently installed PermissionEvaluator, if any.""" + return _active_permission_evaluator + + +def _eval_permission(ctx: SecurityContext, parts: tuple[Any, ...]) -> bool: + """Resolve a ``hasPermission(...)`` call against the evaluator or the context. + + Argument shapes (Spring parity): + * ``(permission,)`` — flat permission check + * ``(target, permission)`` — domain-object permission + * ``(target_id, target_type, perm)`` — identifier + type permission + """ + if not parts: + return False + evaluator = _active_permission_evaluator + if evaluator is None: + # No ACL evaluator: fall back to the principal's flat permissions. + return ctx.has_permission(str(parts[-1])) + if len(parts) == 1: + return evaluator.has_permission(ctx, None, str(parts[0])) + if len(parts) == 2: + return evaluator.has_permission(ctx, parts[0], str(parts[1])) + target_id, target_type, permission = parts[-3], parts[-2], parts[-1] + return evaluator.has_permission(ctx, target_id, str(permission), target_type=str(target_type)) + + def _effective_roles(ctx: SecurityContext) -> set[str]: """The principal's roles, expanded through the active hierarchy when one is set.""" if _active_hierarchy is None: @@ -102,7 +140,9 @@ def _has_authority(ctx: SecurityContext, authority: Any) -> bool: return _has_role(ctx, name) or ctx.has_permission(name) -def _build_namespace(ctx: SecurityContext, args: dict[str, Any] | None, return_object: Any) -> dict[str, Any]: +def _build_namespace( + ctx: SecurityContext, args: dict[str, Any] | None, return_object: Any, filter_object: Any = None +) -> dict[str, Any]: namespace: dict[str, Any] = { "principal": ctx, "authentication": ctx, @@ -120,10 +160,11 @@ def _build_namespace(ctx: SecurityContext, args: dict[str, Any] | None, return_o "hasAnyRole": _BoolFn(lambda *roles: any(_has_role(ctx, r) for r in roles)), "hasAuthority": _BoolFn(lambda authority: _has_authority(ctx, authority)), "hasAnyAuthority": _BoolFn(lambda *auths: any(_has_authority(ctx, a) for a in auths)), - # 1-arg hasPermission(perm) or 2-arg hasPermission(target, perm) — the last - # argument is the permission (target-based ACLs are not modelled). - "hasPermission": _BoolFn(lambda *parts: ctx.has_permission(str(parts[-1]))), + # hasPermission(perm) / (target, perm) / (id, type, perm) — dispatched to the + # installed PermissionEvaluator, or a flat context check when none is set. + "hasPermission": _BoolFn(lambda *parts: _eval_permission(ctx, parts)), "returnObject": return_object, + "filterObject": filter_object, } for key, value in (args or {}).items(): namespace[_PARAM_PREFIX + key] = value @@ -188,11 +229,15 @@ def evaluate_security_expression( *, args: dict[str, Any] | None = None, return_object: Any = None, + filter_object: Any = None, ) -> bool: - """Evaluate a method-security expression; returns the boolean decision.""" + """Evaluate a method-security expression; returns the boolean decision. + + *filter_object* binds ``filterObject`` for ``@pre_filter`` / ``@post_filter``. + """ translated = _PARAM_RE.sub(lambda m: _PARAM_PREFIX + m.group(1), expression.strip()) try: tree = ast.parse(translated, mode="eval") except SyntaxError as exc: raise SecurityException(f"Invalid security expression syntax: {exc}", code="INVALID_EXPRESSION") from exc - return bool(_eval(tree, _build_namespace(ctx, args, return_object))) + return bool(_eval(tree, _build_namespace(ctx, args, return_object, filter_object))) diff --git a/src/pyfly/security/http_security.py b/src/pyfly/security/http_security.py index 40a467ca..95a3f9c8 100644 --- a/src/pyfly/security/http_security.py +++ b/src/pyfly/security/http_security.py @@ -78,10 +78,22 @@ class SecurityRule: patterns: Glob patterns (fnmatch-style) to match against the request path. An empty list means "any request". rule: The access rule to enforce when a pattern matches. + methods: Upper-case HTTP methods this rule applies to. An empty list + (the default) matches any method. """ patterns: list[str] rule: AccessRule + methods: list[str] = field(default_factory=list) + + +def _normalize_methods(methods: str | list[str] | tuple[str, ...] | None) -> list[str]: + """Coerce a method spec (str / list / None) into a list of upper-case methods.""" + if methods is None: + return [] + if isinstance(methods, str): + return [methods.upper()] + return [m.upper() for m in methods] # --------------------------------------------------------------------------- @@ -92,40 +104,46 @@ class SecurityRule: class _RequestMatcherBuilder: """Intermediate builder returned by ``authorize_requests().request_matchers(...)``.""" - def __init__(self, registry: _AuthorizeRequestsBuilder, patterns: list[str]) -> None: + def __init__( + self, + registry: _AuthorizeRequestsBuilder, + patterns: list[str], + methods: list[str] | None = None, + ) -> None: self._registry = registry self._patterns = patterns + self._methods = methods or [] # -- terminal access-rule methods -- def permit_all(self) -> _AuthorizeRequestsBuilder: """Allow all requests matching the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.PERMIT_ALL)) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.PERMIT_ALL), self._methods) return self._registry def deny_all(self) -> _AuthorizeRequestsBuilder: """Deny all requests matching the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.DENY_ALL)) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.DENY_ALL), self._methods) return self._registry def authenticated(self) -> _AuthorizeRequestsBuilder: """Require an authenticated user for the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.AUTHENTICATED)) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.AUTHENTICATED), self._methods) return self._registry def has_role(self, role: str) -> _AuthorizeRequestsBuilder: """Require the user to have *role* for the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_ROLE, role)) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_ROLE, role), self._methods) return self._registry def has_any_role(self, roles: list[str]) -> _AuthorizeRequestsBuilder: """Require the user to have at least one of *roles* for the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_ANY_ROLE, list(roles))) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_ANY_ROLE, list(roles)), self._methods) return self._registry def has_permission(self, permission: str) -> _AuthorizeRequestsBuilder: """Require the user to have *permission* for the current patterns.""" - self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_PERMISSION, permission)) + self._registry._add_rule(self._patterns, AccessRule(AccessRuleType.HAS_PERMISSION, permission), self._methods) return self._registry @@ -138,26 +156,32 @@ class _AuthorizeRequestsBuilder: def __init__(self, security: HttpSecurity) -> None: self._security = security - def request_matchers(self, *patterns: str) -> _RequestMatcherBuilder: + def request_matchers( + self, *patterns: str, methods: str | list[str] | tuple[str, ...] | None = None + ) -> _RequestMatcherBuilder: """Begin a rule for one or more URL glob patterns. Args: *patterns: fnmatch-style glob patterns (e.g. ``"/api/admin/**"``). + methods: Optional HTTP method(s) the rule applies to (e.g. ``"POST"`` + or ``["PUT", "DELETE"]``). When omitted, the rule matches any + method — mirroring Spring's ``requestMatchers(HttpMethod.X, ...)``. Returns: A :class:`_RequestMatcherBuilder` to set the access rule. """ - return _RequestMatcherBuilder(self, list(patterns)) + return _RequestMatcherBuilder(self, list(patterns), _normalize_methods(methods)) - def any_request(self) -> _RequestMatcherBuilder: + def any_request(self, *, methods: str | list[str] | tuple[str, ...] | None = None) -> _RequestMatcherBuilder: """Begin a catch-all rule that matches any request path. - This should be the **last** rule in the chain. + This should be the **last** rule in the chain. An optional ``methods`` + restricts the catch-all to specific HTTP methods. """ - return _RequestMatcherBuilder(self, []) + return _RequestMatcherBuilder(self, [], _normalize_methods(methods)) - def _add_rule(self, patterns: list[str], rule: AccessRule) -> None: - self._security._rules.append(SecurityRule(patterns=patterns, rule=rule)) + def _add_rule(self, patterns: list[str], rule: AccessRule, methods: list[str] | None = None) -> None: + self._security._rules.append(SecurityRule(patterns=patterns, rule=rule, methods=methods or [])) # --------------------------------------------------------------------------- diff --git a/src/pyfly/security/method_security.py b/src/pyfly/security/method_security.py index 536e9adf..8c9e15f8 100644 --- a/src/pyfly/security/method_security.py +++ b/src/pyfly/security/method_security.py @@ -76,6 +76,104 @@ def _check_expression( ) +_COLLECTION_TYPES = (list, tuple, set) + + +def _filter_collection(expression: str, collection: Any, args: dict[str, Any]) -> Any: + """Return *collection* with only the elements for which *expression* (bound to + ``filterObject``) is True, preserving the collection's concrete type.""" + ctx = _get_security_context() + kept = [item for item in collection if evaluate_security_expression(expression, ctx, args=args, filter_object=item)] + return type(collection)(kept) + + +def _first_collection_param(arguments: dict[str, Any]) -> str | None: + """Name of the first argument (skipping ``self``/``cls``) holding a collection.""" + for name, value in arguments.items(): + if name in ("self", "cls"): + continue + if isinstance(value, _COLLECTION_TYPES): + return name + return None + + +def pre_filter(expression: str, filter_target: str | None = None) -> Callable[[F], F]: + """Filter a collection *argument* before the method runs (Spring ``@PreFilter``). + + Each element is bound to ``filterObject``; elements for which *expression* is + False are removed. ``filter_target`` names the collection parameter; when + omitted, the first collection-valued argument is used. + """ + + def decorator(func: F) -> F: + signature = inspect.signature(func) + + def _filtered_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]: + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + target = filter_target or _first_collection_param(bound.arguments) + if target is None or target not in bound.arguments: + return args, kwargs + collection = bound.arguments[target] + if not isinstance(collection, _COLLECTION_TYPES): + return args, kwargs + bound.arguments[target] = _filter_collection(expression, collection, dict(bound.arguments)) + return bound.args, bound.kwargs + + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + new_args, new_kwargs = _filtered_call(args, kwargs) + return await func(*new_args, **new_kwargs) + + async_wrapper.__pyfly_pre_filter__ = expression # type: ignore[attr-defined] + return async_wrapper # type: ignore[return-value] + + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + new_args, new_kwargs = _filtered_call(args, kwargs) + return func(*new_args, **new_kwargs) + + sync_wrapper.__pyfly_pre_filter__ = expression # type: ignore[attr-defined] + return sync_wrapper # type: ignore[return-value] + + return decorator + + +def post_filter(expression: str) -> Callable[[F], F]: + """Filter the returned collection after the method runs (Spring ``@PostFilter``). + + Each returned element is bound to ``filterObject``; non-collection results are + returned unchanged. + """ + + def decorator(func: F) -> F: + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + result = await func(*args, **kwargs) + if not isinstance(result, _COLLECTION_TYPES): + return result + return _filter_collection(expression, result, _bind_args(func, args, kwargs)) + + async_wrapper.__pyfly_post_filter__ = expression # type: ignore[attr-defined] + return async_wrapper # type: ignore[return-value] + + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + result = func(*args, **kwargs) + if not isinstance(result, _COLLECTION_TYPES): + return result + return _filter_collection(expression, result, _bind_args(func, args, kwargs)) + + sync_wrapper.__pyfly_post_filter__ = expression # type: ignore[attr-defined] + return sync_wrapper # type: ignore[return-value] + + return decorator + + def pre_authorize(expression: str) -> Callable[[F], F]: """Decorator that checks a security expression BEFORE method execution. diff --git a/src/pyfly/security/oauth2/__init__.py b/src/pyfly/security/oauth2/__init__.py index 2a2c88d3..78a55d29 100644 --- a/src/pyfly/security/oauth2/__init__.py +++ b/src/pyfly/security/oauth2/__init__.py @@ -26,13 +26,20 @@ google, keycloak, ) +from pyfly.security.oauth2.endpoints import AuthorizationServerEndpoints from pyfly.security.oauth2.login import OAuth2LoginHandler from pyfly.security.oauth2.properties import ResourceServerProperties -from pyfly.security.oauth2.resource_server import ClaimMappings, JWKSTokenValidator, discover_oidc +from pyfly.security.oauth2.resource_server import ( + ClaimMappings, + JWKSTokenValidator, + OpaqueTokenIntrospector, + discover_oidc, +) from pyfly.security.oauth2.session_security_filter import OAuth2SessionSecurityFilter __all__ = [ "AuthorizationServer", + "AuthorizationServerEndpoints", "ClaimMappings", "ClientRegistration", "ClientRegistrationRepository", @@ -41,6 +48,7 @@ "JWKSTokenValidator", "OAuth2LoginHandler", "OAuth2SessionSecurityFilter", + "OpaqueTokenIntrospector", "ResourceServerProperties", "TokenStore", "discover_oidc", diff --git a/src/pyfly/security/oauth2/authorization_server.py b/src/pyfly/security/oauth2/authorization_server.py index eee3d5cc..3a43010a 100644 --- a/src/pyfly/security/oauth2/authorization_server.py +++ b/src/pyfly/security/oauth2/authorization_server.py @@ -68,12 +68,23 @@ class AuthorizationServer: - refresh_token: exchange a refresh token for a new access token Args: - secret: Secret key for signing tokens (HS256). + secret: Secret key for HMAC signing (used when ``algorithm`` is ``HS*``). client_repository: Repository to look up client registrations. token_store: Store for refresh tokens. access_token_ttl: Access token lifetime in seconds (default: 3600 = 1 hour). refresh_token_ttl: Refresh token lifetime in seconds (default: 86400 = 24 hours). issuer: Token issuer claim (optional). + audience: Audience the issued tokens are restricted to (``aud`` claim). + Accepts a single value or a list. When unset, no ``aud`` is emitted + (backward compatible). Setting it lets resource servers reject tokens + minted for a different API (RFC 9700 / OAuth 2.1 audience restriction). + algorithm: JWS algorithm. ``HS256`` (default) signs with ``secret``; + ``RS256``/``RS384``/``RS512``/``PS*``/``ES256``/``ES384``/``ES512`` + sign with ``private_key`` and publish the matching public key via + :meth:`jwks`, so a resource server can verify AS-minted tokens. + private_key: PEM string/bytes or a cryptography private-key object, required + for asymmetric algorithms. + key_id: ``kid`` placed in the JWT header and the published JWK. """ def __init__( @@ -84,6 +95,10 @@ def __init__( access_token_ttl: int = 3600, refresh_token_ttl: int = 86400, issuer: str | None = None, + audience: str | list[str] | None = None, + algorithm: str = "HS256", + private_key: Any = None, + key_id: str | None = None, ) -> None: self._secret = secret self._client_repository = client_repository @@ -91,6 +106,130 @@ def __init__( self._access_token_ttl = access_token_ttl self._refresh_token_ttl = refresh_token_ttl self._issuer = issuer + self._algorithm = algorithm.upper() + self._is_asymmetric = self._algorithm[:2] in ("RS", "ES", "PS") + self._key_id = key_id + self._private_key: Any = self._coerce_private_key(private_key) if self._is_asymmetric else None + if self._is_asymmetric and self._private_key is None: + raise ValueError(f"algorithm {self._algorithm} requires a private_key") + if audience is None: + self._audience: str | list[str] | None = None + elif isinstance(audience, str): + self._audience = audience + else: + aud_list = [a for a in audience if a] + self._audience = aud_list or None + + @staticmethod + def _coerce_private_key(private_key: Any) -> Any: + """Load a PEM string/bytes into a key object; pass through key objects.""" + if isinstance(private_key, (str, bytes)): + from cryptography.hazmat.primitives.serialization import load_pem_private_key + + data = private_key.encode("utf-8") if isinstance(private_key, str) else private_key + return load_pem_private_key(data, password=None) + return private_key + + def _encode(self, payload: dict[str, Any]) -> str: + """Sign *payload* with the configured algorithm (HMAC or asymmetric+kid).""" + if self._is_asymmetric: + assert self._private_key is not None # guaranteed by __init__ + headers = {"kid": self._key_id} if self._key_id else None + return pyjwt.encode(payload, self._private_key, algorithm=self._algorithm, headers=headers) + return pyjwt.encode(payload, self._secret, algorithm=self._algorithm) + + def jwks(self) -> dict[str, Any]: + """Return the public JWK Set for token verification (empty for HMAC).""" + if not self._is_asymmetric or self._private_key is None: + return {"keys": []} + import json as _json + + assert self._private_key is not None # narrowed for mypy + public_key = self._private_key.public_key() + if self._algorithm[:2] == "ES": + jwk = _json.loads(pyjwt.algorithms.ECAlgorithm.to_jwk(public_key)) + else: + jwk = _json.loads(pyjwt.algorithms.RSAAlgorithm.to_jwk(public_key)) + jwk.update({"use": "sig", "alg": self._algorithm}) + if self._key_id: + jwk["kid"] = self._key_id + return {"keys": [jwk]} + + def authenticate_client(self, client_id: str, client_secret: str) -> ClientRegistration | None: + """Return the registration iff *client_id*/*client_secret* match (constant time). + + Client authentication requires real credentials: an empty client id or + secret — or a registration that has no secret configured — never + authenticates (prevents an empty-credential bypass on the management + endpoints and for any client that is not a confidential client). + """ + if not client_id or not client_secret: + return None + registration = self._client_repository.find_by_registration_id(client_id) + if registration is None or not registration.client_secret: + return None + if not secrets.compare_digest(registration.client_secret.encode("utf-8"), client_secret.encode("utf-8")): + return None + return registration + + def _verification_key(self) -> Any: + return self._private_key.public_key() if self._is_asymmetric else self._secret + + async def introspect( + self, token: str, *, requesting_client_id: str | None = None, allow_any_client: bool = False + ) -> dict[str, Any]: + """RFC 7662 token introspection for an access (JWT) or refresh token. + + When *requesting_client_id* is given and *allow_any_client* is False, a + token owned by a different client is reported as inactive — so a client + cannot scan another client's tokens (information disclosure). Designated + resource-server clients pass ``allow_any_client=True``. + """ + result = await self._introspect(token) + if ( + result.get("active") + and requesting_client_id is not None + and not allow_any_client + and result.get("client_id") != requesting_client_id + ): + return {"active": False} + return result + + async def _introspect(self, token: str) -> dict[str, Any]: + # Access token: a self-contained, signature-verified JWT. + try: + payload = pyjwt.decode( + token, + self._verification_key(), + algorithms=[self._algorithm], + options={"require": ["exp"], "verify_aud": False}, + ) + active: dict[str, Any] = {"active": True, "token_type": "Bearer"} + for claim in ("sub", "scope", "iat", "exp", "iss", "aud"): + if claim in payload: + active[claim] = payload[claim] + active.setdefault("client_id", payload.get("sub")) + return active + except pyjwt.PyJWTError: + pass + + # Refresh token: opaque, looked up in the store; active iff present, + # unused, unexpired, and its family is still active. + data = await self._token_store.find(token) + if data is None or data.get("used") or data.get("exp", 0) < int(time.time()): + return {"active": False} + family_id = data.get("family_id") + if family_id: + family = await self._token_store.find(self._family_key(family_id)) + if family is not None and not family.get("active", True): + return {"active": False} + return { + "active": True, + "token_type": "refresh_token", + "client_id": data.get("client_id"), + "scope": data.get("scope", ""), + "exp": data.get("exp"), + } async def token( self, @@ -99,6 +238,7 @@ async def token( client_secret: str, scope: str = "", refresh_token: str | None = None, + confirmation: dict[str, Any] | None = None, ) -> dict[str, Any]: """Issue tokens based on grant type. @@ -108,6 +248,9 @@ async def token( client_secret: The client's secret scope: Space-separated scopes (for client_credentials) refresh_token: The refresh token (for refresh_token grant) + confirmation: Optional ``cnf`` confirmation claim to bind the access + token to a key (e.g. ``{"jkt": ...}`` for DPoP, ``{"x5t#S256": ...}`` + for mTLS) — sender-constraining per RFC 9449 / RFC 8705. Returns: Token response dict with access_token, token_type, expires_in, @@ -118,10 +261,8 @@ async def token( """ # Authenticate client (constant-time secret comparison to avoid a timing # side-channel that could leak the client secret). - registration = self._client_repository.find_by_registration_id(client_id) - if registration is None or not secrets.compare_digest( - registration.client_secret.encode("utf-8"), client_secret.encode("utf-8") - ): + registration = self.authenticate_client(client_id, client_secret) + if registration is None: raise SecurityException("Invalid client credentials", code="INVALID_CLIENT") if grant_type == "client_credentials": @@ -133,20 +274,36 @@ async def token( f"Client '{client_id}' is not authorized for grant type 'client_credentials'", code="UNAUTHORIZED_CLIENT", ) - return await self._handle_client_credentials(registration, scope) + return await self._handle_client_credentials(registration, scope, confirmation) elif grant_type == "refresh_token": if refresh_token is None: raise SecurityException("Refresh token required", code="INVALID_REQUEST") - return await self._handle_refresh_token(registration, refresh_token) + return await self._handle_refresh_token(registration, refresh_token, confirmation) else: raise SecurityException( f"Unsupported grant type: {grant_type}", code="UNSUPPORTED_GRANT_TYPE", ) - async def _handle_client_credentials(self, registration: ClientRegistration, scope: str) -> dict[str, Any]: + async def _handle_client_credentials( + self, registration: ClientRegistration, scope: str, confirmation: dict[str, Any] | None = None + ) -> dict[str, Any]: now = int(time.time()) - scopes = scope.split() if scope else registration.scopes + # A client may only ever obtain scopes it is registered for. Requesting an + # unregistered scope is rejected wholesale (RFC 6749 §5.2 ``invalid_scope``) + # rather than silently echoed — otherwise any authenticated client could + # mint an arbitrarily-privileged token (e.g. ``admin``) just by asking. + if scope: + requested = scope.split() + unregistered = [s for s in requested if s not in registration.scopes] + if unregistered: + raise SecurityException( + f"Requested scope(s) not permitted for this client: {' '.join(unregistered)}", + code="INVALID_SCOPE", + ) + scopes = requested + else: + scopes = registration.scopes access_payload: dict[str, Any] = { "sub": registration.client_id, @@ -156,31 +313,45 @@ async def _handle_client_credentials(self, registration: ClientRegistration, sco } if self._issuer: access_payload["iss"] = self._issuer + if self._audience is not None: + access_payload["aud"] = self._audience + if confirmation: + access_payload["cnf"] = confirmation - access_token = pyjwt.encode(access_payload, self._secret, algorithm="HS256") + access_token = self._encode(access_payload) - # Generate refresh token - refresh_token_id = secrets.token_urlsafe(32) - refresh_data = { - "client_id": registration.client_id, - "scope": " ".join(scopes), - "exp": now + self._refresh_token_ttl, - } - await self._token_store.store(refresh_token_id, refresh_data) + scope_str = " ".join(scopes) + refresh_token_id = await self._issue_refresh_token(registration.client_id, scope_str) return { "access_token": access_token, "token_type": "Bearer", "expires_in": self._access_token_ttl, "refresh_token": refresh_token_id, - "scope": " ".join(scopes), + "scope": scope_str, } - async def _handle_refresh_token(self, registration: ClientRegistration, refresh_token: str) -> dict[str, Any]: + async def _handle_refresh_token( + self, registration: ClientRegistration, refresh_token: str, confirmation: dict[str, Any] | None = None + ) -> dict[str, Any]: token_data = await self._token_store.find(refresh_token) if token_data is None: raise SecurityException("Invalid refresh token", code="INVALID_GRANT") + family_id = token_data.get("family_id") + family = await self._token_store.find(self._family_key(family_id)) if family_id else None + + # The family was already revoked (e.g. by a previous reuse) — refuse. + if family is not None and not family.get("active", True): + raise SecurityException("Refresh token family revoked", code="INVALID_GRANT") + + # Reuse detection (OAuth 2.1 / RFC 9700): a refresh token that was already + # rotated is being replayed. The legitimate holder cannot do this, so we + # treat it as theft and revoke the entire token family. + if token_data.get("used"): + await self._revoke_family(family_id, family) + raise SecurityException("Refresh token reuse detected", code="INVALID_GRANT") + # Verify client matches if token_data.get("client_id") != registration.client_id: raise SecurityException("Refresh token client mismatch", code="INVALID_GRANT") @@ -190,8 +361,10 @@ async def _handle_refresh_token(self, registration: ClientRegistration, refresh_ await self._token_store.revoke(refresh_token) raise SecurityException("Refresh token expired", code="INVALID_GRANT") - # Revoke old refresh token (rotation) - await self._token_store.revoke(refresh_token) + # Mark the presented token consumed (rotation). It is retained — not + # deleted — so a later replay is detected as reuse rather than "unknown". + token_data["used"] = True + await self._token_store.store(refresh_token, token_data) # Issue new tokens now = int(time.time()) @@ -205,17 +378,14 @@ async def _handle_refresh_token(self, registration: ClientRegistration, refresh_ } if self._issuer: access_payload["iss"] = self._issuer + if self._audience is not None: + access_payload["aud"] = self._audience + if confirmation: + access_payload["cnf"] = confirmation - access_token = pyjwt.encode(access_payload, self._secret, algorithm="HS256") + access_token = self._encode(access_payload) - # New refresh token - new_refresh_id = secrets.token_urlsafe(32) - new_refresh_data = { - "client_id": registration.client_id, - "scope": scope, - "exp": now + self._refresh_token_ttl, - } - await self._token_store.store(new_refresh_id, new_refresh_data) + new_refresh_id = await self._issue_refresh_token(registration.client_id, scope, family_id) return { "access_token": access_token, @@ -225,6 +395,64 @@ async def _handle_refresh_token(self, registration: ClientRegistration, refresh_ "scope": scope, } - async def revoke(self, token_id: str) -> None: - """Revoke a refresh token.""" + # ------------------------------------------------------------------ + # Refresh-token family bookkeeping (rotation + reuse detection) + # ------------------------------------------------------------------ + + @staticmethod + def _family_key(family_id: str | None) -> str: + return f"family:{family_id}" + + async def _issue_refresh_token(self, client_id: str, scope: str, family_id: str | None = None) -> str: + """Mint a refresh token, creating or extending its rotation family.""" + token_id = secrets.token_urlsafe(32) + if family_id is None: + family_id = secrets.token_urlsafe(16) + family: dict[str, Any] = {"client_id": client_id, "active": True, "members": [token_id]} + else: + family = await self._token_store.find(self._family_key(family_id)) or { + "client_id": client_id, + "active": True, + "members": [], + } + family.setdefault("members", []).append(token_id) + token_data = { + "client_id": client_id, + "scope": scope, + "exp": int(time.time()) + self._refresh_token_ttl, + "family_id": family_id, + "used": False, + } + await self._token_store.store(token_id, token_data) + await self._token_store.store(self._family_key(family_id), family) + return token_id + + async def _revoke_family(self, family_id: str | None, family: dict[str, Any] | None = None) -> None: + """Revoke an entire refresh-token family (all rotated descendants).""" + if family_id is None: + return + if family is None: + family = await self._token_store.find(self._family_key(family_id)) + if family is None: + return + family["active"] = False + await self._token_store.store(self._family_key(family_id), family) + for token_id in family.get("members", []): + await self._token_store.revoke(token_id) + + async def revoke(self, token_id: str, *, requesting_client_id: str | None = None) -> None: + """Revoke a refresh token (and, when known, its whole rotation family). + + Per RFC 7009 §2.1, when *requesting_client_id* is given the token is only + revoked if it was issued to that client — a client cannot revoke another + client's tokens. ``requesting_client_id=None`` (internal callers) revokes + unconditionally. + """ + token_data = await self._token_store.find(token_id) + owner = token_data.get("client_id") if isinstance(token_data, dict) else None + if requesting_client_id is not None and owner is not None and owner != requesting_client_id: + return # not the owner — refuse silently (RFC 7009 still returns 200) await self._token_store.revoke(token_id) + family_id = token_data.get("family_id") if isinstance(token_data, dict) else None + if family_id: + await self._revoke_family(family_id) diff --git a/src/pyfly/security/oauth2/client.py b/src/pyfly/security/oauth2/client.py index 31657c3e..615854fa 100644 --- a/src/pyfly/security/oauth2/client.py +++ b/src/pyfly/security/oauth2/client.py @@ -39,9 +39,20 @@ class ClientRegistration: jwks_uri: str = "" issuer_uri: str = "" provider_name: str = "" - # Enable PKCE (RFC 7636, S256) on the authorization_code flow. Recommended for public - # clients (no client_secret); harmless and more secure for confidential clients too. - use_pkce: bool = False + # Enable PKCE (RFC 7636, S256) on the authorization_code flow. On by default — + # RFC 9700 / OAuth 2.1 require PKCE for the authorization code grant for all + # client types. A public client (empty client_secret) always uses PKCE even if + # this is set False, as it has no other defense against code injection. Set + # False only for a confidential client talking to an AS that rejects PKCE. + use_pkce: bool = True + # Require the RFC 9207 ``iss`` authorization-response parameter to be present + # and match ``issuer_uri`` on callback (mix-up-attack defense). When False + # (default) the ``iss`` param is still validated *when present*, but a provider + # that omits it is tolerated. + require_iss: bool = False + # Marks a resource-server client permitted to introspect tokens it does not + # own (RFC 7662). Regular clients may only introspect their own tokens. + allow_introspection: bool = False # --------------------------------------------------------------------------- diff --git a/src/pyfly/security/oauth2/dpop.py b/src/pyfly/security/oauth2/dpop.py new file mode 100644 index 00000000..2508253e --- /dev/null +++ b/src/pyfly/security/oauth2/dpop.py @@ -0,0 +1,168 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sender-constrained access tokens — DPoP (RFC 9449) and mTLS (RFC 8705). + +A bearer token can be replayed by anyone who steals it. Sender-constraining binds +the token to a key the legitimate client holds: + +* **DPoP** — the client signs a per-request *proof* JWT with its private key; the + access token carries ``cnf.jkt`` (the JWK SHA-256 thumbprint, RFC 7638). The + resource server verifies the proof and that its key thumbprint matches ``jkt``. +* **mTLS** — the access token carries ``cnf["x5t#S256"]`` (the client certificate + thumbprint). The resource server compares it to the presented client cert. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import time +from typing import Any +from urllib.parse import urlsplit + +import jwt as pyjwt + +from pyfly.kernel.exceptions import SecurityException + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def jwk_thumbprint(jwk: dict[str, Any]) -> str: + """Compute the RFC 7638 JWK SHA-256 thumbprint (base64url, no padding).""" + kty = jwk.get("kty") + if kty == "RSA": + members = {"e": jwk["e"], "kty": "RSA", "n": jwk["n"]} + elif kty == "EC": + members = {"crv": jwk["crv"], "kty": "EC", "x": jwk["x"], "y": jwk["y"]} + elif kty == "OKP": + members = {"crv": jwk["crv"], "kty": "OKP", "x": jwk["x"]} + else: + raise SecurityException(f"Unsupported JWK key type for thumbprint: {kty!r}", code="INVALID_TOKEN") + canonical = json.dumps(members, separators=(",", ":"), sort_keys=True) + return _b64url(hashlib.sha256(canonical.encode("utf-8")).digest()) + + +def _normalize_htu(url: str) -> str: + """Normalize an HTTP URI for ``htu`` comparison: scheme://host/path (no query/fragment).""" + parts = urlsplit(url) + return f"{parts.scheme}://{parts.netloc}{parts.path}" + + +def access_token_hash(access_token: str) -> str: + """The DPoP ``ath`` value: base64url(SHA-256(access_token)).""" + return _b64url(hashlib.sha256(access_token.encode("ascii")).digest()) + + +class DPoPProofValidator: + """Validates a DPoP proof JWT (RFC 9449 §4.3) and returns its key thumbprint. + + Args: + max_age_seconds: Accepted ``iat`` skew window for the proof. + replay_cache: Optional set-like collection of seen ``jti`` values; when + provided, a repeated ``jti`` is rejected as a replay. (Use a bounded / + TTL-backed set in production.) + """ + + def __init__(self, *, max_age_seconds: int = 60, replay_cache: set[str] | None = None) -> None: + self._max_age = max_age_seconds + self._replay_cache = replay_cache + + def validate( + self, + proof: str, + *, + http_method: str, + http_url: str, + access_token: str | None = None, + ) -> str: + """Verify *proof* for the given request; return the bound key thumbprint (jkt). + + Raises: + SecurityException: if the proof is malformed, mis-signed, stale, replayed, + or does not match the request method/URL (or access token hash). + """ + try: + header = pyjwt.get_unverified_header(proof) + except pyjwt.PyJWTError as exc: + raise SecurityException(f"Malformed DPoP proof: {exc}", code="INVALID_DPOP_PROOF") from exc + + if header.get("typ") != "dpop+jwt": + raise SecurityException("DPoP proof has wrong 'typ'", code="INVALID_DPOP_PROOF") + alg = str(header.get("alg", "")) + if alg[:2] not in ("RS", "ES", "PS") and not alg.startswith("Ed"): + raise SecurityException("DPoP proof must use an asymmetric algorithm", code="INVALID_DPOP_PROOF") + jwk = header.get("jwk") + if not isinstance(jwk, dict): + raise SecurityException("DPoP proof missing embedded 'jwk'", code="INVALID_DPOP_PROOF") + if any(k in jwk for k in ("d", "p", "q", "dp", "dq", "qi")): + raise SecurityException("DPoP proof 'jwk' must not contain private material", code="INVALID_DPOP_PROOF") + + try: + key = pyjwt.PyJWK.from_dict(jwk).key + claims = pyjwt.decode(proof, key, algorithms=[alg], options={"verify_aud": False}) + except pyjwt.PyJWTError as exc: + raise SecurityException(f"DPoP proof signature invalid: {exc}", code="INVALID_DPOP_PROOF") from exc + + if str(claims.get("htm", "")).upper() != http_method.upper(): + raise SecurityException("DPoP 'htm' does not match the request method", code="INVALID_DPOP_PROOF") + if _normalize_htu(str(claims.get("htu", ""))) != _normalize_htu(http_url): + raise SecurityException("DPoP 'htu' does not match the request URL", code="INVALID_DPOP_PROOF") + + iat = claims.get("iat") + if not isinstance(iat, (int, float)) or abs(time.time() - float(iat)) > self._max_age: + raise SecurityException("DPoP proof is stale or missing 'iat'", code="INVALID_DPOP_PROOF") + + jti = claims.get("jti") + if self._replay_cache is not None: + if not jti or jti in self._replay_cache: + raise SecurityException("DPoP proof replayed or missing 'jti'", code="INVALID_DPOP_PROOF") + self._replay_cache.add(str(jti)) + + if access_token is not None and claims.get("ath") != access_token_hash(access_token): + raise SecurityException("DPoP 'ath' does not match the access token", code="INVALID_DPOP_PROOF") + + return jwk_thumbprint(jwk) + + +def confirm_dpop_binding(token_claims: dict[str, Any], jkt: str) -> None: + """Assert the access token is DPoP-bound to *jkt* (its ``cnf.jkt``).""" + bound = (token_claims.get("cnf") or {}).get("jkt") + if not bound: + raise SecurityException("Access token is not DPoP-bound (no cnf.jkt)", code="INVALID_TOKEN") + if bound != jkt: + raise SecurityException("DPoP key does not match the token's cnf.jkt", code="INVALID_TOKEN") + + +def certificate_thumbprint(cert: str | bytes) -> str: + """Return the RFC 8705 ``x5t#S256`` thumbprint (base64url SHA-256 of the DER cert).""" + from cryptography import x509 + + if isinstance(cert, str): + cert = cert.encode("utf-8") + loaded = x509.load_pem_x509_certificate(cert) if b"-----BEGIN" in cert else x509.load_der_x509_certificate(cert) + from cryptography.hazmat.primitives.serialization import Encoding + + return _b64url(hashlib.sha256(loaded.public_bytes(Encoding.DER)).digest()) + + +def confirm_mtls_binding(token_claims: dict[str, Any], cert: str | bytes) -> None: + """Assert the access token is mTLS-bound to *cert* (its ``cnf["x5t#S256"]``).""" + bound = (token_claims.get("cnf") or {}).get("x5t#S256") + if not bound: + raise SecurityException("Access token is not mTLS-bound (no cnf.x5t#S256)", code="INVALID_TOKEN") + if bound != certificate_thumbprint(cert): + raise SecurityException("Client certificate does not match the token's cnf.x5t#S256", code="INVALID_TOKEN") diff --git a/src/pyfly/security/oauth2/endpoints.py b/src/pyfly/security/oauth2/endpoints.py new file mode 100644 index 00000000..fd150675 --- /dev/null +++ b/src/pyfly/security/oauth2/endpoints.py @@ -0,0 +1,156 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OAuth2 Authorization Server HTTP endpoints. + +Exposes, as Starlette routes, the token endpoint plus the standard OAuth2 +management endpoints: + +* ``POST /oauth2/token`` — issue tokens (RFC 6749) +* ``POST /oauth2/introspect`` — token introspection (RFC 7662), client-authenticated +* ``POST /oauth2/revoke`` — token revocation (RFC 7009), client-authenticated +* ``GET /oauth2/jwks`` — public JWK Set (for asymmetric signing) +""" + +from __future__ import annotations + +import base64 +import binascii +from typing import Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route + +from pyfly.kernel.exceptions import SecurityException +from pyfly.security.oauth2.authorization_server import AuthorizationServer +from pyfly.security.oauth2.client import ClientRegistration + +# OAuth2 error codes that map to a 401 (client authentication failed); the rest +# are request/grant errors returned as 400 (RFC 6749 §5.2). +_UNAUTHORIZED_ERRORS = {"INVALID_CLIENT"} + + +class AuthorizationServerEndpoints: + """Builds Starlette routes that expose an :class:`AuthorizationServer`.""" + + def __init__(self, server: AuthorizationServer) -> None: + self._server = server + + def routes(self) -> list[Route]: + return [ + Route("/oauth2/token", self._token, methods=["POST"]), + Route("/oauth2/introspect", self._introspect, methods=["POST"]), + Route("/oauth2/revoke", self._revoke, methods=["POST"]), + Route("/oauth2/jwks", self._jwks, methods=["GET"]), + ] + + # -- token endpoint ---------------------------------------------------- + + async def _token(self, request: Request) -> Response: + form = await request.form() + client_id, client_secret = self._client_credentials(request, form) + # DPoP (RFC 9449): if the client presents a proof on the token request, bind + # the issued access token to its key via a cnf.jkt confirmation claim. + confirmation: dict[str, Any] | None = None + dpop_proof = request.headers.get("dpop") + if dpop_proof: + from pyfly.security.oauth2.dpop import DPoPProofValidator + + try: + jkt = DPoPProofValidator().validate(dpop_proof, http_method="POST", http_url=str(request.url)) + except SecurityException as exc: + return self._error(exc) + confirmation = {"jkt": jkt} + try: + result = await self._server.token( + grant_type=str(form.get("grant_type", "")), + client_id=client_id, + client_secret=client_secret, + scope=str(form.get("scope", "")), + refresh_token=(str(form["refresh_token"]) if form.get("refresh_token") else None), + confirmation=confirmation, + ) + except SecurityException as exc: + return self._error(exc) + return JSONResponse(result) + + # -- introspection (RFC 7662) ----------------------------------------- + + async def _introspect(self, request: Request) -> Response: + form = await request.form() + registration = self._authenticate(request, form) + if registration is None: + return self._error(SecurityException("Invalid client", code="INVALID_CLIENT")) + token = str(form.get("token", "")) + if not token: + return JSONResponse({"active": False}) + result = await self._server.introspect( + token, + requesting_client_id=registration.client_id, + allow_any_client=getattr(registration, "allow_introspection", False), + ) + return JSONResponse(result) + + # -- revocation (RFC 7009) -------------------------------------------- + + async def _revoke(self, request: Request) -> Response: + form = await request.form() + registration = self._authenticate(request, form) + if registration is None: + return self._error(SecurityException("Invalid client", code="INVALID_CLIENT")) + token = str(form.get("token", "")) + if token: + # RFC 7009 §2.1: only the owning client may revoke the token. + await self._server.revoke(token, requesting_client_id=registration.client_id) + # RFC 7009 §2.2: the AS responds 200 regardless of whether the token existed. + return JSONResponse({}) + + # -- JWKS -------------------------------------------------------------- + + async def _jwks(self, request: Request) -> Response: + return JSONResponse(self._server.jwks()) + + # -- helpers ----------------------------------------------------------- + + @staticmethod + def _client_credentials(request: Request, form: Any) -> tuple[str, str]: + """Resolve client credentials from HTTP Basic or form params (post).""" + basic = AuthorizationServerEndpoints._basic_auth(request) + if basic is not None: + return basic + return str(form.get("client_id", "")), str(form.get("client_secret", "")) + + def _authenticate(self, request: Request, form: Any) -> ClientRegistration | None: + client_id, client_secret = self._client_credentials(request, form) + return self._server.authenticate_client(client_id, client_secret) + + @staticmethod + def _basic_auth(request: Request) -> tuple[str, str] | None: + header = request.headers.get("authorization", "") + parts = header.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "basic": + return None + try: + decoded = base64.b64decode(parts[1].strip(), validate=True).decode("utf-8") + except (binascii.Error, ValueError, UnicodeDecodeError): + return None + cid, sep, secret = decoded.partition(":") + return (cid, secret) if sep else None + + @staticmethod + def _error(exc: SecurityException) -> JSONResponse: + code = getattr(exc, "code", "INVALID_REQUEST") or "INVALID_REQUEST" + status = 401 if code in _UNAUTHORIZED_ERRORS else 400 + headers = {"WWW-Authenticate": 'Basic realm="oauth2"'} if status == 401 else None + return JSONResponse({"error": code.lower(), "error_description": str(exc)}, status_code=status, headers=headers) diff --git a/src/pyfly/security/oauth2/login.py b/src/pyfly/security/oauth2/login.py index ed26ce96..a50927e1 100644 --- a/src/pyfly/security/oauth2/login.py +++ b/src/pyfly/security/oauth2/login.py @@ -47,6 +47,16 @@ def _generate_pkce() -> tuple[str, str]: return verifier, challenge +def _uses_pkce(registration: Any) -> bool: + """Whether PKCE applies to this registration's authorization_code flow. + + PKCE is on by default (RFC 9700 / OAuth 2.1). It is always enforced for a + public client (no ``client_secret``) — which has no other defense against + authorization-code injection — even if ``use_pkce`` was explicitly disabled. + """ + return bool(getattr(registration, "use_pkce", True)) or not getattr(registration, "client_secret", "") + + class OAuth2LoginHandler: """Creates Starlette routes for the OAuth2 authorization_code login flow. @@ -110,7 +120,7 @@ async def _handle_authorization(self, request: Request) -> Response: "nonce": nonce, } # PKCE (RFC 7636): stash the verifier in the session, send only the S256 challenge. - if getattr(registration, "use_pkce", False): + if _uses_pkce(registration): verifier, challenge = _generate_pkce() session.set_attribute(_OAUTH2_PKCE_VERIFIER_KEY, verifier) params["code_challenge"] = challenge @@ -168,9 +178,16 @@ async def _handle_callback(self, request: Request) -> Response: status_code=400, ) + # RFC 9207 mix-up defense: validate the issuer that produced this response. + # The ``iss`` param is always rejected on mismatch with the registration's + # ``issuer_uri``; with ``require_iss`` it must also be present. + iss_error = self._validate_iss(registration, request.query_params.get("iss")) + if iss_error is not None: + return iss_error + # PKCE: retrieve and consume the one-time verifier stashed at authorization time. code_verifier = None - if getattr(registration, "use_pkce", False): + if _uses_pkce(registration): code_verifier = session.get_attribute(_OAUTH2_PKCE_VERIFIER_KEY) session.remove_attribute(_OAUTH2_PKCE_VERIFIER_KEY) @@ -254,6 +271,31 @@ async def _handle_logout(self, request: Request) -> Response: # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _validate_iss(registration: Any, received_iss: str | None) -> Response | None: + """Validate the RFC 9207 ``iss`` authorization-response parameter. + + Returns a 400 response on a mismatch, or when ``require_iss`` is set and the + parameter is absent; otherwise ``None`` (validation passed). + """ + expected = getattr(registration, "issuer_uri", "") or "" + require = getattr(registration, "require_iss", False) + if received_iss is None: + if require: + logger.warning("OAuth2 callback missing required 'iss' parameter (RFC 9207)") + return JSONResponse( + {"error": "invalid_iss", "message": "Missing required 'iss' parameter"}, + status_code=400, + ) + return None + if expected and received_iss != expected: + logger.warning("OAuth2 'iss' mismatch: expected %r, got %r", expected, received_iss) + return JSONResponse( + {"error": "invalid_iss", "message": "Issuer (iss) does not match the expected provider"}, + status_code=400, + ) + return None + async def _exchange_code(self, registration: Any, code: str, code_verifier: str | None = None) -> dict[str, Any]: """Exchange an authorization code for tokens via the token endpoint.""" data = { diff --git a/src/pyfly/security/oauth2/properties.py b/src/pyfly/security/oauth2/properties.py index 2bafecf6..5269c292 100644 --- a/src/pyfly/security/oauth2/properties.py +++ b/src/pyfly/security/oauth2/properties.py @@ -91,6 +91,12 @@ class ResourceServerProperties: scope_claim_names: str = "scp,scope" attribute_claims: str = "" + # --- sender-constrained tokens (RFC 9449 DPoP / RFC 8705 mTLS) -------- + # When true, a token carrying a ``cnf`` claim must be accompanied by proof of + # possession (a DPoP proof header, or a client certificate in the mTLS header). + enforce_sender_constraints: bool = False + mtls_cert_header: str = "x-client-cert" + # --- filter ----------------------------------------------------------- exclude_patterns: str = "" # "anonymous" (default, non-breaking): an invalid/missing token yields an diff --git a/src/pyfly/security/oauth2/resource_server.py b/src/pyfly/security/oauth2/resource_server.py index d94ced9e..18e06dbb 100644 --- a/src/pyfly/security/oauth2/resource_server.py +++ b/src/pyfly/security/oauth2/resource_server.py @@ -253,50 +253,116 @@ def to_security_context(self, token: str) -> SecurityContext: payload = self.validate(token) return self._build_context(payload) + def validate_and_context(self, token: str) -> tuple[dict[str, Any], SecurityContext]: + """Validate *token* once and return both the raw claims and the context. + + Lets a filter inspect claims (e.g. ``cnf`` for sender-constraining) without + validating the signature twice.""" + payload = self.validate(token) + return payload, self._build_context(payload) + def _build_context(self, payload: dict[str, Any]) -> SecurityContext: """Map a validated *payload* onto a :class:`SecurityContext` per the configured claim mappings. Subclasses may override for bespoke mapping.""" - m = self._mappings - - # Principal: first non-empty principal claim wins. - user_id: str | None = None - for claim in m.principal_claims: - vals = _flatten_strs(_resolve_claim_path(payload, claim)) - if vals: - user_id = vals[0] - break - - # Authorities/roles: collect across every configured path, de-duplicated - # (order-preserving), with the optional prefix applied. - roles: list[str] = [] - seen: set[str] = set() - for claim in m.authority_claims: - for raw in _flatten_strs(_resolve_claim_path(payload, claim)): - value = f"{m.authority_prefix}{raw}" if m.authority_prefix else raw - if value not in seen: - seen.add(value) - roles.append(value) - - # Permissions/scopes: scope claims are space-delimited strings or lists. - permissions: list[str] = [] - perm_seen: set[str] = set() - for claim in m.scope_claims: - for raw in _flatten_strs(_resolve_claim_path(payload, claim)): - for part in raw.split(): - if part and part not in perm_seen: - perm_seen.add(part) - permissions.append(part) - - # Attributes: copy configured claims verbatim (string-coerced). - attributes: dict[str, str] = {} - for claim in m.attribute_claims: - vals = _flatten_strs(_resolve_claim_path(payload, claim)) - if vals: - attributes[claim] = vals[0] - - return SecurityContext( - user_id=user_id, - roles=roles, - permissions=permissions, - attributes=attributes, - ) + return build_security_context(payload, self._mappings) + + +def build_security_context(payload: dict[str, Any], mappings: ClaimMappings) -> SecurityContext: + """Map a token/introspection *payload* onto a :class:`SecurityContext`. + + Shared by :class:`JWKSTokenValidator` and :class:`OpaqueTokenIntrospector` so + JWT and opaque-token resource servers map claims identically. + """ + m = mappings + + # Principal: first non-empty principal claim wins. + user_id: str | None = None + for claim in m.principal_claims: + vals = _flatten_strs(_resolve_claim_path(payload, claim)) + if vals: + user_id = vals[0] + break + + # Authorities/roles: collect across every configured path, de-duplicated + # (order-preserving), with the optional prefix applied. + roles: list[str] = [] + seen: set[str] = set() + for claim in m.authority_claims: + for raw in _flatten_strs(_resolve_claim_path(payload, claim)): + value = f"{m.authority_prefix}{raw}" if m.authority_prefix else raw + if value not in seen: + seen.add(value) + roles.append(value) + + # Permissions/scopes: scope claims are space-delimited strings or lists. + permissions: list[str] = [] + perm_seen: set[str] = set() + for claim in m.scope_claims: + for raw in _flatten_strs(_resolve_claim_path(payload, claim)): + for part in raw.split(): + if part and part not in perm_seen: + perm_seen.add(part) + permissions.append(part) + + # Attributes: copy configured claims verbatim (string-coerced). + attributes: dict[str, str] = {} + for claim in m.attribute_claims: + vals = _flatten_strs(_resolve_claim_path(payload, claim)) + if vals: + attributes[claim] = vals[0] + + return SecurityContext( + user_id=user_id, + roles=roles, + permissions=permissions, + attributes=attributes, + ) + + +class OpaqueTokenIntrospector: + """Validates opaque access tokens via an RFC 7662 introspection endpoint. + + The resource server posts the token (with its own client credentials) to the + authorization server's ``/introspect`` endpoint and maps the returned claims + onto a :class:`SecurityContext` using the same :class:`ClaimMappings` as the + JWT validator. Use this for opaque (non-JWT) tokens. + """ + + def __init__( + self, + introspection_uri: str, + *, + client_id: str, + client_secret: str, + claim_mappings: ClaimMappings | None = None, + timeout: float = 10.0, + ) -> None: + self._uri = introspection_uri + self._client_id = client_id + self._client_secret = client_secret + self._mappings = claim_mappings or ClaimMappings() + self._timeout = timeout + + def introspect(self, token: str) -> dict[str, Any]: + """Return the introspection claims for *token*, or raise if it is inactive.""" + import httpx + + try: + with httpx.Client(timeout=self._timeout) as client: + resp = client.post( + self._uri, + data={"token": token, "token_type_hint": "access_token"}, + auth=(self._client_id, self._client_secret), + headers={"Accept": "application/json"}, + ) + except httpx.HTTPError as exc: + raise SecurityException(f"Token introspection request failed: {exc}", code="INVALID_TOKEN") from exc + if resp.status_code != 200: + raise SecurityException(f"Token introspection failed (HTTP {resp.status_code})", code="INVALID_TOKEN") + payload: dict[str, Any] = resp.json() + if not payload.get("active"): + raise SecurityException("Token is not active", code="INVALID_TOKEN") + return payload + + def to_security_context(self, token: str) -> SecurityContext: + return build_security_context(self.introspect(token), self._mappings) diff --git a/src/pyfly/security/password.py b/src/pyfly/security/password.py index e8f50d48..df506416 100644 --- a/src/pyfly/security/password.py +++ b/src/pyfly/security/password.py @@ -11,10 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Password encoding port and bcrypt adapter.""" +"""Password encoding port and adapters. + +Provides a :class:`PasswordEncoder` port and several adapters — bcrypt, PBKDF2, +scrypt, Argon2 — plus a :class:`DelegatingPasswordEncoder` that prefixes hashes +with a ``{id}`` so the algorithm can be migrated over time (Spring Security +parity with ``DelegatingPasswordEncoder`` / ``PasswordEncoderFactories``). +""" from __future__ import annotations +import base64 +import hashlib +import hmac +import secrets from typing import Protocol, runtime_checkable import bcrypt as _bcrypt @@ -50,7 +60,177 @@ def hash(self, raw_password: str) -> str: def verify(self, raw_password: str, hashed_password: str) -> bool: """Verify a raw password against a bcrypt hash.""" - return _bcrypt.checkpw( - raw_password.encode("utf-8"), - hashed_password.encode("utf-8"), + try: + return _bcrypt.checkpw( + raw_password.encode("utf-8"), + hashed_password.encode("utf-8"), + ) + except (ValueError, TypeError): + # Malformed / non-bcrypt stored value — treat as a non-match. + return False + + +class Pbkdf2PasswordEncoder: + """PasswordEncoder using PBKDF2-HMAC (stdlib ``hashlib``). + + Produces a self-describing string ``$$$``. + PBKDF2 is FIPS-friendly; defaults to 600k SHA-256 iterations (OWASP 2023). + """ + + def __init__(self, *, iterations: int = 600_000, salt_bytes: int = 16, algorithm: str = "sha256") -> None: + self._iterations = iterations + self._salt_bytes = salt_bytes + self._algorithm = algorithm + + def hash(self, raw_password: str) -> str: + salt = secrets.token_bytes(self._salt_bytes) + digest = hashlib.pbkdf2_hmac(self._algorithm, raw_password.encode("utf-8"), salt, self._iterations) + return ( + f"{self._algorithm}${self._iterations}$" + f"{base64.b64encode(salt).decode('ascii')}${base64.b64encode(digest).decode('ascii')}" + ) + + def verify(self, raw_password: str, hashed_password: str) -> bool: + try: + algorithm, iterations_s, salt_b64, digest_b64 = hashed_password.split("$") + iterations = int(iterations_s) + salt = base64.b64decode(salt_b64) + expected = base64.b64decode(digest_b64) + except (ValueError, TypeError): + return False + actual = hashlib.pbkdf2_hmac(algorithm, raw_password.encode("utf-8"), salt, iterations, dklen=len(expected)) + return hmac.compare_digest(actual, expected) + + +class ScryptPasswordEncoder: + """PasswordEncoder using scrypt (stdlib ``hashlib.scrypt``). + + Produces ``$$

$$``. Memory-hard; defaults follow + common interactive-login parameters (n=2**14, r=8, p=1). + """ + + def __init__(self, *, n: int = 2**14, r: int = 8, p: int = 1, salt_bytes: int = 16, dklen: int = 32) -> None: + self._n = n + self._r = r + self._p = p + self._salt_bytes = salt_bytes + self._dklen = dklen + + def hash(self, raw_password: str) -> str: + salt = secrets.token_bytes(self._salt_bytes) + digest = hashlib.scrypt( + raw_password.encode("utf-8"), salt=salt, n=self._n, r=self._r, p=self._p, dklen=self._dklen + ) + return ( + f"{self._n}${self._r}${self._p}$" + f"{base64.b64encode(salt).decode('ascii')}${base64.b64encode(digest).decode('ascii')}" + ) + + def verify(self, raw_password: str, hashed_password: str) -> bool: + try: + n_s, r_s, p_s, salt_b64, digest_b64 = hashed_password.split("$") + n, r, p = int(n_s), int(r_s), int(p_s) + salt = base64.b64decode(salt_b64) + expected = base64.b64decode(digest_b64) + except (ValueError, TypeError): + return False + try: + actual = hashlib.scrypt(raw_password.encode("utf-8"), salt=salt, n=n, r=r, p=p, dklen=len(expected)) + except ValueError: + return False + return hmac.compare_digest(actual, expected) + + +class Argon2PasswordEncoder: + """PasswordEncoder using Argon2id (OWASP-preferred). Requires ``argon2-cffi``. + + The dependency is imported lazily so the rest of the security module works + without it; install with ``pip install pyfly[argon2]`` to use this encoder. + """ + + def __init__(self, *, time_cost: int = 3, memory_cost: int = 65536, parallelism: int = 4) -> None: + self._time_cost = time_cost + self._memory_cost = memory_cost + self._parallelism = parallelism + + def _hasher(self) -> object: + try: + from argon2 import PasswordHasher # type: ignore[import-not-found, unused-ignore] + except ImportError as exc: # pragma: no cover - exercised only without argon2-cffi + raise ImportError("Argon2PasswordEncoder requires argon2-cffi — `pip install pyfly[argon2]`") from exc + return PasswordHasher(time_cost=self._time_cost, memory_cost=self._memory_cost, parallelism=self._parallelism) + + def hash(self, raw_password: str) -> str: + return str(self._hasher().hash(raw_password)) # type: ignore[attr-defined] + + def verify(self, raw_password: str, hashed_password: str) -> bool: + from argon2.exceptions import ( # type: ignore[import-not-found, unused-ignore] + VerificationError, + VerifyMismatchError, ) + + try: + return bool(self._hasher().verify(hashed_password, raw_password)) # type: ignore[attr-defined] + except (VerifyMismatchError, VerificationError): + return False + + +class DelegatingPasswordEncoder: + """Password encoder that prefixes hashes with ``{id}`` and delegates by id. + + Spring Security parity (``DelegatingPasswordEncoder``): :meth:`hash` produces + ``{}`` using the default encoder; :meth:`verify` + reads the ``{id}`` prefix and dispatches to the matching encoder. A stored + value with an unknown or missing prefix never matches. :meth:`upgrade_encoding` + reports whether a stored hash should be re-hashed with the current default — + enabling transparent on-login migration between algorithms. + """ + + def __init__(self, encoders: dict[str, PasswordEncoder], encoding_id: str) -> None: + if encoding_id not in encoders: + raise ValueError(f"encoding_id {encoding_id!r} is not present in the encoders map") + self._encoders = dict(encoders) + self._encoding_id = encoding_id + + @staticmethod + def _split(stored: str) -> tuple[str | None, str]: + """Return ``(id, remainder)`` for a ``{id}...`` value, or ``(None, stored)``.""" + if stored.startswith("{"): + end = stored.find("}") + if end > 0: + return stored[1:end], stored[end + 1 :] + return None, stored + + def hash(self, raw_password: str) -> str: + inner = self._encoders[self._encoding_id].hash(raw_password) + return f"{{{self._encoding_id}}}{inner}" + + def verify(self, raw_password: str, hashed_password: str) -> bool: + encoding_id, inner = self._split(hashed_password) + encoder = self._encoders.get(encoding_id) if encoding_id is not None else None + if encoder is None: + return False + return encoder.verify(raw_password, inner) + + def upgrade_encoding(self, hashed_password: str) -> bool: + """Whether *hashed_password* should be re-hashed with the current default.""" + encoding_id, _ = self._split(hashed_password) + return encoding_id != self._encoding_id + + +def create_delegating_password_encoder(*, bcrypt_rounds: int = 12) -> DelegatingPasswordEncoder: + """Build a :class:`DelegatingPasswordEncoder` with bcrypt as the default id. + + Mirrors Spring's ``PasswordEncoderFactories.createDelegatingPasswordEncoder()``: + new hashes use bcrypt (``{bcrypt}``), while ``{pbkdf2}``, ``{scrypt}`` and + ``{argon2}`` hashes are still recognised for verification and migration. + """ + return DelegatingPasswordEncoder( + { + "bcrypt": BcryptPasswordEncoder(rounds=bcrypt_rounds), + "pbkdf2": Pbkdf2PasswordEncoder(), + "scrypt": ScryptPasswordEncoder(), + "argon2": Argon2PasswordEncoder(), + }, + encoding_id="bcrypt", + ) diff --git a/src/pyfly/security/permission.py b/src/pyfly/security/permission.py new file mode 100644 index 00000000..0816aa84 --- /dev/null +++ b/src/pyfly/security/permission.py @@ -0,0 +1,44 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PermissionEvaluator — ACL-style ``hasPermission`` SPI (Spring parity). + +Install one via :func:`pyfly.security.expression.set_permission_evaluator` to back +``hasPermission(target, 'perm')`` / ``hasPermission(id, 'Type', 'perm')`` method +-security expressions with domain-object permission checks. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class PermissionEvaluator(Protocol): + """Decides whether the current principal holds *permission* on a target object.""" + + def has_permission( + self, + context: Any, + target: Any, + permission: str, + *, + target_type: str | None = None, + ) -> bool: + """Return whether the principal (in *context*) has *permission* on *target*. + + *target* is the domain object (2-arg form) or its identifier (3-arg form, + where *target_type* names the object type). *context* is the active + :class:`~pyfly.security.context.SecurityContext`. + """ + ... diff --git a/src/pyfly/security/user_details.py b/src/pyfly/security/user_details.py new file mode 100644 index 00000000..7905de32 --- /dev/null +++ b/src/pyfly/security/user_details.py @@ -0,0 +1,55 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UserDetails / UserDetailsService — the credential-lookup SPI. + +Spring Security parity: a :class:`UserDetailsService` resolves a username to a +:class:`UserDetails` (a stored password hash plus authorities), which the HTTP +Basic / form-login filters verify against a :class:`~pyfly.security.password.PasswordEncoder`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class UserDetails: + """A resolved principal: a stored credential plus granted authorities.""" + + username: str + password_hash: str + roles: list[str] = field(default_factory=list) + permissions: list[str] = field(default_factory=list) + enabled: bool = True + + +@runtime_checkable +class UserDetailsService(Protocol): + """Port that resolves a username to its :class:`UserDetails`, or ``None``.""" + + async def load_user_by_username(self, username: str) -> UserDetails | None: ... + + +class InMemoryUserDetailsService: + """A :class:`UserDetailsService` backed by an in-memory dict (dev / testing).""" + + def __init__(self, *users: UserDetails) -> None: + self._users: dict[str, UserDetails] = {u.username: u for u in users} + + async def load_user_by_username(self, username: str) -> UserDetails | None: + return self._users.get(username) + + def add(self, user: UserDetails) -> None: + self._users[user.username] = user diff --git a/src/pyfly/web/adapters/starlette/filters/csrf_filter.py b/src/pyfly/web/adapters/starlette/filters/csrf_filter.py index 0553b081..2916f7fe 100644 --- a/src/pyfly/web/adapters/starlette/filters/csrf_filter.py +++ b/src/pyfly/web/adapters/starlette/filters/csrf_filter.py @@ -73,6 +73,15 @@ class CsrfFilter(OncePerRequestFilter): exclude_patterns = ["/actuator/*", "/health", "/ready"] + def __init__(self, *, cookie_gated: bool = True) -> None: + # ``cookie_gated`` (default): only enforce CSRF on unsafe requests that + # carry cookies — i.e. requests with ambient authority a cross-site forgery + # could abuse. A request with no cookies (a stateless API client) has no + # CSRF surface and is exempt, so CSRF can be on by default without breaking + # token/stateless clients. Set ``cookie_gated=False`` for strict enforcement + # of every unsafe request regardless of cookies. + self._cookie_gated = cookie_gated + async def do_filter(self, request: Any, call_next: CallNext) -> Any: method: str = request.method @@ -91,6 +100,13 @@ async def do_filter(self, request: Any, call_next: CallNext) -> Any: if auth_header and auth_header.startswith("Bearer "): return await call_next(request) + # ----------------------------------------------------------------- + # Cookie-gated exemption — no cookies means no ambient authority for a + # cross-site request to abuse, so there is nothing to protect. + # ----------------------------------------------------------------- + if self._cookie_gated and not request.cookies: + return await call_next(request) + # ----------------------------------------------------------------- # Unsafe methods — validate double-submit cookie. # ----------------------------------------------------------------- diff --git a/src/pyfly/web/adapters/starlette/filters/form_login_filter.py b/src/pyfly/web/adapters/starlette/filters/form_login_filter.py new file mode 100644 index 00000000..8cd68bdf --- /dev/null +++ b/src/pyfly/web/adapters/starlette/filters/form_login_filter.py @@ -0,0 +1,103 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Form-login filter (Spring ``formLogin``). + +Processes a POST of username/password to the login URL, authenticates via an +:class:`~pyfly.security.authentication.ProviderManager`, and on success rotates +the session id (fixation defense) and stores the :class:`SecurityContext` in the +session — where :class:`OAuth2SessionSecurityFilter` restores it on later +requests. Browser (redirect) and API (JSON) responses are both supported. +""" + +from __future__ import annotations + +import logging + +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response + +from pyfly.container.ordering import HIGHEST_PRECEDENCE +from pyfly.security.authentication import Authentication, AuthenticationException, ProviderManager +from pyfly.web.filters import OncePerRequestFilter +from pyfly.web.ports.filter import CallNext + +logger = logging.getLogger(__name__) + +_SECURITY_CONTEXT_KEY = "SECURITY_CONTEXT" + + +class FormLoginFilter(OncePerRequestFilter): + """Authenticates a username/password form POST and establishes a session. + + Runs at ``HIGHEST_PRECEDENCE + 230`` — after the session-restoring filter + (``+225``) so a successful login overrides any prior anonymous context. + """ + + __pyfly_order__ = HIGHEST_PRECEDENCE + 230 + + def __init__( + self, + authentication_manager: ProviderManager, + *, + login_url: str = "/login", + username_param: str = "username", + password_param: str = "password", + success_url: str = "/", + failure_url: str = "/login?error", + use_redirect: bool = True, + ) -> None: + self._manager = authentication_manager + self._login_url = login_url + self._username_param = username_param + self._password_param = password_param + self._success_url = success_url + self._failure_url = failure_url + self._use_redirect = use_redirect + + async def do_filter(self, request: Request, call_next: CallNext) -> Response: + if request.method == "POST" and request.url.path == self._login_url: + return await self._attempt_login(request) + return await call_next(request) # type: ignore[no-any-return] + + async def _attempt_login(self, request: Request) -> Response: + form = await request.form() + username = str(form.get(self._username_param, "") or "") + password = str(form.get(self._password_param, "") or "") + + try: + result = await self._manager.authenticate(Authentication(principal=username, credentials=password)) + except AuthenticationException: + logger.warning("Form login failed for user %r", username) + return self._failure() + + context = result.to_security_context() + session = getattr(getattr(request, "state", None), "session", None) + if session is not None: + # Rotate the session id on authentication to prevent session fixation, + # then bind the authenticated context to the (new) session. + session.rotate_id() + session.set_attribute(_SECURITY_CONTEXT_KEY, context) + request.state.security_context = context + logger.info("Form login successful for user: %s", context.user_id) + return self._success() + + def _success(self) -> Response: + if self._use_redirect: + return RedirectResponse(url=self._success_url, status_code=302) + return JSONResponse({"authenticated": True}) + + def _failure(self) -> Response: + if self._use_redirect: + return RedirectResponse(url=self._failure_url, status_code=302) + return JSONResponse({"error": "invalid_credentials"}, status_code=401) diff --git a/src/pyfly/web/adapters/starlette/filters/http_basic_filter.py b/src/pyfly/web/adapters/starlette/filters/http_basic_filter.py new file mode 100644 index 00000000..9f55207f --- /dev/null +++ b/src/pyfly/web/adapters/starlette/filters/http_basic_filter.py @@ -0,0 +1,138 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HTTP Basic authentication filter (RFC 7617). + +Parses an ``Authorization: Basic`` header, resolves the user via a +:class:`~pyfly.security.user_details.UserDetailsService`, verifies the password +with a :class:`~pyfly.security.password.PasswordEncoder`, and populates the +request :class:`SecurityContext`. + +``error_mode`` mirrors the OAuth2 resource-server filter: + +* ``"anonymous"`` (default): bad/missing credentials yield an anonymous context + and the request proceeds — the ``HttpSecurity`` gate decides. +* ``"401"``: present-but-invalid credentials are rejected here with + ``401 Unauthorized`` and a ``WWW-Authenticate: Basic realm="…"`` challenge. + Missing credentials still fall through to the gate. +""" + +from __future__ import annotations + +import base64 +import binascii +import logging +from typing import cast + +from anyio import to_thread +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from pyfly.container.ordering import HIGHEST_PRECEDENCE, order +from pyfly.context.request_context import RequestContext +from pyfly.security.context import SecurityContext +from pyfly.security.password import PasswordEncoder +from pyfly.security.user_details import UserDetailsService +from pyfly.web.filters import OncePerRequestFilter +from pyfly.web.ports.filter import CallNext + +logger = logging.getLogger(__name__) + +ERROR_MODE_ANONYMOUS = "anonymous" +ERROR_MODE_401 = "401" + + +@order(HIGHEST_PRECEDENCE + 215) +class HttpBasicAuthenticationFilter(OncePerRequestFilter): + """Authenticates ``Authorization: Basic`` credentials against a UserDetailsService. + + Ordered just before the symmetric JWT ``SecurityFilter`` (``+220``) so it can + establish a context for credential-based clients while leaving token-based + auth to the later filters when no Basic header is present. + """ + + def __init__( + self, + user_details_service: UserDetailsService, + password_encoder: PasswordEncoder, + *, + realm: str = "Realm", + error_mode: str = ERROR_MODE_ANONYMOUS, + ) -> None: + self._users = user_details_service + self._encoder = password_encoder + self._realm = realm + self._error_mode = error_mode if error_mode in (ERROR_MODE_ANONYMOUS, ERROR_MODE_401) else ERROR_MODE_ANONYMOUS + + async def do_filter(self, request: Request, call_next: CallNext) -> Response: + credentials = self._extract_basic(request.headers.get("authorization", "")) + + if credentials is None: + # No Basic credentials presented — leave any existing context alone and + # default to anonymous so downstream filters/handlers always have one. + if not hasattr(request.state, "security_context"): + request.state.security_context = SecurityContext.anonymous() + return cast(Response, await call_next(request)) + + username, password = credentials + context = await self._authenticate(username, password) + + if context is None: + logger.warning("HTTP Basic authentication failed for user %r", username) + if self._error_mode == ERROR_MODE_401: + return self._challenge() + context = SecurityContext.anonymous() + + request.state.security_context = context + req_ctx = RequestContext.current() + if req_ctx is not None: + req_ctx.security_context = context + return cast(Response, await call_next(request)) + + async def _authenticate(self, username: str, password: str) -> SecurityContext | None: + user = await self._users.load_user_by_username(username) + if user is None or not user.enabled: + return None + # bcrypt/argon2 verification is CPU-bound; offload so we never block the loop. + ok = await to_thread.run_sync(self._encoder.verify, password, user.password_hash) + if not ok: + return None + return SecurityContext(user_id=user.username, roles=list(user.roles), permissions=list(user.permissions)) + + @staticmethod + def _extract_basic(auth_header: str) -> tuple[str, str] | None: + """Return ``(username, password)`` from a Basic header, or ``None``. + + Returns ``("", "")``-style failures as ``None`` only for a *missing* or + *non-Basic* header; a malformed Basic payload raises through to a 401 by + returning a sentinel the caller treats as an auth failure. + """ + parts = auth_header.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "basic" or not parts[1].strip(): + return None + try: + decoded = base64.b64decode(parts[1].strip(), validate=True).decode("utf-8") + except (binascii.Error, ValueError, UnicodeDecodeError): + # Malformed credentials — treat as a (present) failed attempt. + return ("\x00invalid", "") + username, sep, password = decoded.partition(":") + if not sep: + return ("\x00invalid", "") + return (username, password) + + def _challenge(self) -> Response: + return JSONResponse( + {"error": "invalid_credentials", "error_description": "Authentication failed."}, + status_code=401, + headers={"WWW-Authenticate": f'Basic realm="{self._realm}"'}, + ) diff --git a/src/pyfly/web/adapters/starlette/filters/http_security_filter.py b/src/pyfly/web/adapters/starlette/filters/http_security_filter.py index 118e3341..f4db26bd 100644 --- a/src/pyfly/web/adapters/starlette/filters/http_security_filter.py +++ b/src/pyfly/web/adapters/starlette/filters/http_security_filter.py @@ -85,9 +85,14 @@ async def do_filter(self, request: Request, call_next: CallNext) -> Response: path: str = request.url.path security_context: SecurityContext = getattr(request.state, "security_context", SecurityContext.anonymous()) + method: str = request.method.upper() for security_rule in self._rules: if not _matches(path, security_rule.patterns): continue + # A rule scoped to specific HTTP methods only applies to those methods; + # an empty method list matches any method. + if security_rule.methods and method not in security_rule.methods: + continue rule = security_rule.rule rule_type = rule.rule_type diff --git a/src/pyfly/web/adapters/starlette/filters/logout_filter.py b/src/pyfly/web/adapters/starlette/filters/logout_filter.py new file mode 100644 index 00000000..7da684e3 --- /dev/null +++ b/src/pyfly/web/adapters/starlette/filters/logout_filter.py @@ -0,0 +1,82 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generic logout filter (Spring ``logout`` / ``LogoutConfigurer``). + +Handles a POST to the logout URL by invalidating the HTTP session, clearing the +security context, and deleting configured cookies — independent of OAuth2. Browser +(redirect) and API (204) responses are both supported. +""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence + +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response + +from pyfly.container.ordering import HIGHEST_PRECEDENCE +from pyfly.web.filters import OncePerRequestFilter +from pyfly.web.ports.filter import CallNext + +logger = logging.getLogger(__name__) + +_SECURITY_CONTEXT_KEY = "SECURITY_CONTEXT" + + +class LogoutFilter(OncePerRequestFilter): + """Invalidates the session on a POST to the logout URL. + + Runs at ``HIGHEST_PRECEDENCE + 235`` (after form login). Configure the URL, + success URL, response mode, and cookies to clear. + """ + + __pyfly_order__ = HIGHEST_PRECEDENCE + 235 + + def __init__( + self, + *, + logout_url: str = "/logout", + logout_success_url: str = "/login?logout", + delete_cookies: Sequence[str] = (), + use_redirect: bool = True, + ) -> None: + self._logout_url = logout_url + self._logout_success_url = logout_success_url + self._delete_cookies = list(delete_cookies) + self._use_redirect = use_redirect + + async def do_filter(self, request: Request, call_next: CallNext) -> Response: + if request.method == "POST" and request.url.path == self._logout_url: + return self._logout(request) + return await call_next(request) # type: ignore[no-any-return] + + def _logout(self, request: Request) -> Response: + session = getattr(getattr(request, "state", None), "session", None) + if session is not None: + session.set_attribute(_SECURITY_CONTEXT_KEY, None) + session.invalidate() + if hasattr(request, "state"): + from pyfly.security.context import SecurityContext + + request.state.security_context = SecurityContext.anonymous() + response: Response + if self._use_redirect: + response = RedirectResponse(url=self._logout_success_url, status_code=302) + else: + response = Response(status_code=204) + for cookie in self._delete_cookies: + response.delete_cookie(cookie, path="/") + logger.info("Logout processed for path %s", request.url.path) + return response diff --git a/src/pyfly/web/adapters/starlette/filters/oauth2_resource_filter.py b/src/pyfly/web/adapters/starlette/filters/oauth2_resource_filter.py index 8c42c079..2d1894d2 100644 --- a/src/pyfly/web/adapters/starlette/filters/oauth2_resource_filter.py +++ b/src/pyfly/web/adapters/starlette/filters/oauth2_resource_filter.py @@ -17,7 +17,7 @@ import logging from collections.abc import Sequence -from typing import cast +from typing import Any, cast from anyio import to_thread from starlette.requests import Request @@ -70,23 +70,37 @@ def __init__( exclude_patterns: Sequence[str] = (), *, error_mode: str = ERROR_MODE_ANONYMOUS, + enforce_sender_constraints: bool = False, + dpop_validator: Any = None, + mtls_cert_header: str = "x-client-cert", ) -> None: self._token_validator = token_validator self.exclude_patterns = list(exclude_patterns) self._error_mode = error_mode if error_mode in (ERROR_MODE_ANONYMOUS, ERROR_MODE_401) else ERROR_MODE_ANONYMOUS + self._enforce_sc = enforce_sender_constraints + self._dpop_validator = dpop_validator + self._mtls_cert_header = mtls_cert_header async def do_filter(self, request: Request, call_next: CallNext) -> Response: - token = self._extract_bearer(request.headers.get("authorization", "")) + token = self._extract_token(request.headers.get("authorization", "")) if token is not None: try: # Offload to a worker thread: JWKS key lookup may do blocking # urllib I/O on a cache miss, which would otherwise stall the loop. - security_context = await to_thread.run_sync(self._token_validator.to_security_context, token) + if self._enforce_sc: + payload, security_context = await to_thread.run_sync( + self._token_validator.validate_and_context, token + ) + # Sender-constrained tokens (RFC 9449 DPoP / RFC 8705 mTLS) must be + # accompanied by proof of possession; a stolen token alone is useless. + self._enforce_sender_constraint(request, payload, token) + else: + security_context = await to_thread.run_sync(self._token_validator.to_security_context, token) except SecurityException: # A token was presented but failed validation (bad signature, - # expired, wrong iss/aud, unknown kid, ...). - logger.warning("OAuth2 bearer token rejected (invalid_token)") + # expired, wrong iss/aud, unknown kid, failed proof-of-possession). + logger.warning("OAuth2 token rejected (invalid_token)") if self._error_mode == ERROR_MODE_401: return self._invalid_token_response() security_context = SecurityContext.anonymous() @@ -100,16 +114,42 @@ async def do_filter(self, request: Request, call_next: CallNext) -> Response: req_ctx.security_context = security_context return cast(Response, await call_next(request)) + def _enforce_sender_constraint(self, request: Request, payload: dict[str, Any], token: str) -> None: + """Enforce DPoP/mTLS proof-of-possession when the token carries a ``cnf`` claim.""" + cnf = payload.get("cnf") + if not isinstance(cnf, dict): + return # plain bearer token — nothing to enforce + if "jkt" in cnf: + from urllib.parse import urlsplit, urlunsplit + + from pyfly.security.oauth2.dpop import DPoPProofValidator, confirm_dpop_binding + + proof = request.headers.get("dpop") + if not proof: + raise SecurityException("DPoP proof required for this token", code="INVALID_TOKEN") + validator = self._dpop_validator or DPoPProofValidator() + parts = urlsplit(str(request.url)) + http_url = urlunsplit((parts.scheme, parts.netloc, parts.path, "", "")) + jkt = validator.validate(proof, http_method=request.method, http_url=http_url, access_token=token) + confirm_dpop_binding(payload, jkt) + elif "x5t#S256" in cnf: + from urllib.parse import unquote + + from pyfly.security.oauth2.dpop import confirm_mtls_binding + + cert = request.headers.get(self._mtls_cert_header) + if not cert: + raise SecurityException("Client certificate required for this token", code="INVALID_TOKEN") + confirm_mtls_binding(payload, unquote(cert)) + @staticmethod - def _extract_bearer(auth_header: str) -> str | None: - """Return the token from an ``Authorization`` header, or ``None``. + def _extract_token(auth_header: str) -> str | None: + """Return the token from a ``Bearer`` or ``DPoP`` ``Authorization`` header. - The auth scheme is matched case-insensitively (RFC 7235 §2.1: the scheme - is a case-insensitive token), so ``Bearer``, ``bearer`` and ``BEARER`` - are all accepted. + The auth scheme is matched case-insensitively (RFC 7235 §2.1). """ parts = auth_header.split(" ", 1) - if len(parts) == 2 and parts[0].lower() == "bearer" and parts[1].strip(): + if len(parts) == 2 and parts[0].lower() in ("bearer", "dpop") and parts[1].strip(): return parts[1].strip() return None diff --git a/src/pyfly/web/security_filters_auto_configuration.py b/src/pyfly/web/security_filters_auto_configuration.py index 04b7bf0b..2530dd77 100644 --- a/src/pyfly/web/security_filters_auto_configuration.py +++ b/src/pyfly/web/security_filters_auto_configuration.py @@ -50,15 +50,28 @@ def _exclude_patterns(config: Config, key: str) -> Sequence[str]: @auto_configuration @conditional_on_class("starlette") -@conditional_on_property("pyfly.security.csrf.enabled", having_value="true") +@conditional_on_property("pyfly.security.csrf.enabled", having_value="true", match_if_missing=True) class CsrfFilterAutoConfiguration: - """Registers the double-submit-cookie CSRF filter (opt-in).""" + """Registers the double-submit-cookie CSRF filter. + + Secure by default: active unless ``pyfly.security.csrf.enabled=false``. The + filter runs in cookie-gated mode (``pyfly.security.csrf.cookie-gated``, + default true), so stateless/token (no-cookie) clients are unaffected while + browser/session requests are protected. Set ``cookie-gated: false`` for + strict enforcement of every unsafe request. + """ @bean def csrf_filter(self, config: Config) -> WebFilter: from pyfly.web.adapters.starlette.filters.csrf_filter import CsrfFilter - filter_ = CsrfFilter() + cookie_gated = str(config.get("pyfly.security.csrf.cookie-gated", True)).strip().lower() not in ( + "0", + "false", + "no", + "off", + ) + filter_ = CsrfFilter(cookie_gated=cookie_gated) excludes = _exclude_patterns(config, "pyfly.security.csrf.exclude-patterns") if excludes: filter_.exclude_patterns = list(excludes) diff --git a/tests/config/test_auto.py b/tests/config/test_auto.py index fc70bca6..65d1e8a4 100644 --- a/tests/config/test_auto.py +++ b/tests/config/test_auto.py @@ -46,7 +46,7 @@ def test_detect_messaging_provider(self): class TestDiscoverAutoConfigurations: def test_returns_all_auto_config_classes(self): classes = discover_auto_configurations() - assert len(classes) == 46 + assert len(classes) == 49 def test_all_classes_have_auto_configuration_marker(self): for cls in discover_auto_configurations(): @@ -81,6 +81,9 @@ def test_contains_expected_class_names(self): "ConfigServerAutoConfiguration", "CsrfFilterAutoConfiguration", "HttpSecurityFilterAutoConfiguration", + "HttpBasicAutoConfiguration", + "FormLoginAutoConfiguration", + "LogoutAutoConfiguration", "CqrsAutoConfiguration", "DocumentAutoConfiguration", "EcmAutoConfiguration", diff --git a/tests/idp/test_azure_ad_behavior.py b/tests/idp/test_azure_ad_behavior.py index f8121257..b9faced4 100644 --- a/tests/idp/test_azure_ad_behavior.py +++ b/tests/idp/test_azure_ad_behavior.py @@ -114,9 +114,21 @@ def _adapter() -> AzureAdIdpAdapter: tenant_id=TENANT_ID, client_id=CLIENT_ID, client_secret=CLIENT_SECRET, + allow_password_grant=True, ) +@pytest.mark.asyncio +async def test_login_refused_without_password_grant_optin() -> None: + """ROPC (grant_type=password) is refused unless explicitly enabled (RFC 9700 §2.4).""" + from pyfly.kernel.exceptions import SecurityException + + adapter = AzureAdIdpAdapter(tenant_id=TENANT_ID, client_id=CLIENT_ID, client_secret=CLIENT_SECRET) + with pytest.raises(SecurityException) as exc: + await adapter.login(LoginRequest(username="alice@example.com", password="s3cr3t!")) + assert exc.value.code == "ROPC_DISABLED" + + def _inject(adapter: AzureAdIdpAdapter, fake: FakeClient) -> None: """Make every ``await self._client()`` return the same recording fake.""" diff --git a/tests/idp/test_cognito_behavior.py b/tests/idp/test_cognito_behavior.py index 52aeb426..3ff42570 100644 --- a/tests/idp/test_cognito_behavior.py +++ b/tests/idp/test_cognito_behavior.py @@ -117,9 +117,21 @@ def _adapter(fake: _FakeCognitoClient) -> AwsCognitoIdpAdapter: client_id=CLIENT_ID, region=REGION, client=fake, + allow_password_grant=True, ) +@pytest.mark.asyncio +async def test_login_refused_without_password_grant_optin() -> None: + """ROPC (USER_PASSWORD_AUTH) is refused unless explicitly enabled (RFC 9700 §2.4).""" + from pyfly.kernel.exceptions import SecurityException + + adapter = AwsCognitoIdpAdapter(user_pool_id=USER_POOL_ID, client_id=CLIENT_ID, region=REGION, client=object()) + with pytest.raises(SecurityException) as exc: + await adapter.login(LoginRequest(username="alice", password="hunter2")) + assert exc.value.code == "ROPC_DISABLED" + + # --------------------------------------------------------------------------- # # login — initiate_auth USER_PASSWORD_AUTH → AuthResult # --------------------------------------------------------------------------- # diff --git a/tests/idp/test_keycloak_behavior.py b/tests/idp/test_keycloak_behavior.py index f72743c4..91ca44bf 100644 --- a/tests/idp/test_keycloak_behavior.py +++ b/tests/idp/test_keycloak_behavior.py @@ -103,9 +103,21 @@ def _adapter() -> KeycloakIdpAdapter: realm=REALM, client_id="admin-cli", client_secret="s3cr3t", + allow_password_grant=True, ) +@pytest.mark.asyncio +async def test_login_refused_without_password_grant_optin() -> None: + """ROPC (grant_type=password) is refused unless explicitly enabled (RFC 9700 §2.4).""" + from pyfly.kernel.exceptions import SecurityException + + adapter = KeycloakIdpAdapter(base_url=BASE_URL, realm=REALM, client_id="admin-cli", client_secret="s3cr3t") + with pytest.raises(SecurityException) as exc: + await adapter.login(LoginRequest(username="bob", password="hunter2")) + assert exc.value.code == "ROPC_DISABLED" + + def _inject(adapter: KeycloakIdpAdapter, fake: FakeClient) -> None: """Make every ``await self._client()`` return the same recording fake.""" diff --git a/tests/idp/test_wave_idp_web.py b/tests/idp/test_wave_idp_web.py index 148ca276..a2142e3a 100644 --- a/tests/idp/test_wave_idp_web.py +++ b/tests/idp/test_wave_idp_web.py @@ -97,7 +97,12 @@ async def test_cognito_login_includes_secret_hash() -> None: fake = _FakeBoto() adapter = AwsCognitoIdpAdapter( - user_pool_id="pool", client_id="cid", region="us-east-1", client_secret="shh", client=fake + user_pool_id="pool", + client_id="cid", + region="us-east-1", + client_secret="shh", + client=fake, + allow_password_grant=True, ) await adapter.login(LoginRequest(username="bob", password="pw")) assert "SECRET_HASH" in fake.auth_params # audit #23 diff --git a/tests/security/test_as_asymmetric.py b/tests/security/test_as_asymmetric.py new file mode 100644 index 00000000..868d0fb1 --- /dev/null +++ b/tests/security/test_as_asymmetric.py @@ -0,0 +1,94 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Asymmetric (RS256) authorization-server signing + JWKS publication.""" + +from __future__ import annotations + +import jwt as pyjwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +from pyfly.security.oauth2.authorization_server import AuthorizationServer, InMemoryTokenStore +from pyfly.security.oauth2.client import ClientRegistration, InMemoryClientRegistrationRepository + + +def _rsa_pem() -> str: + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + return key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ).decode("utf-8") + + +def _repo() -> InMemoryClientRegistrationRepository: + return InMemoryClientRegistrationRepository( + ClientRegistration( + registration_id="c", + client_id="c", + client_secret="s3cr3t-value", + authorization_grant_type="client_credentials", + scopes=["read"], + ) + ) + + +def _as_rs256() -> AuthorizationServer: + return AuthorizationServer( + secret="", + client_repository=_repo(), + token_store=InMemoryTokenStore(), + algorithm="RS256", + private_key=_rsa_pem(), + key_id="k1", + issuer="https://as.example.com", + ) + + +class TestAsymmetricSigning: + @pytest.mark.asyncio + async def test_token_verifies_against_published_jwks(self) -> None: + server = _as_rs256() + result = await server.token(grant_type="client_credentials", client_id="c", client_secret="s3cr3t-value") + + jwks = server.jwks() + assert len(jwks["keys"]) == 1 + key = pyjwt.PyJWK.from_dict(jwks["keys"][0]).key + payload = pyjwt.decode(result["access_token"], key, algorithms=["RS256"], issuer="https://as.example.com") + assert payload["sub"] == "c" + assert payload["scope"] == "read" + + @pytest.mark.asyncio + async def test_token_header_carries_kid(self) -> None: + server = _as_rs256() + result = await server.token(grant_type="client_credentials", client_id="c", client_secret="s3cr3t-value") + header = pyjwt.get_unverified_header(result["access_token"]) + assert header["kid"] == "k1" + assert header["alg"] == "RS256" + + def test_jwks_entry_has_kid_use_alg(self) -> None: + jwk = _as_rs256().jwks()["keys"][0] + assert jwk["kid"] == "k1" + assert jwk["use"] == "sig" + assert jwk["alg"] == "RS256" + assert jwk["kty"] == "RSA" + + def test_hs256_jwks_is_empty(self) -> None: + server = AuthorizationServer( + secret="symmetric-secret-key-at-least-32b!!", + client_repository=_repo(), + token_store=InMemoryTokenStore(), + ) + assert server.jwks() == {"keys": []} diff --git a/tests/security/test_authentication.py b/tests/security/test_authentication.py new file mode 100644 index 00000000..84bf4b0e --- /dev/null +++ b/tests/security/test_authentication.py @@ -0,0 +1,92 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AuthenticationManager / AuthenticationProvider SPI.""" + +from __future__ import annotations + +import pytest + +from pyfly.security.authentication import ( + Authentication, + AuthenticationException, + BadCredentialsException, + DaoAuthenticationProvider, + DisabledException, + ProviderManager, +) +from pyfly.security.password import BcryptPasswordEncoder +from pyfly.security.user_details import InMemoryUserDetailsService, UserDetails + +_ENCODER = BcryptPasswordEncoder(rounds=4) + + +def _provider() -> DaoAuthenticationProvider: + service = InMemoryUserDetailsService( + UserDetails(username="alice", password_hash=_ENCODER.hash("pw"), roles=["ADMIN"], permissions=["read"]), + UserDetails(username="bob", password_hash=_ENCODER.hash("pw"), enabled=False), + ) + return DaoAuthenticationProvider(service, _ENCODER) + + +class TestDaoAuthenticationProvider: + @pytest.mark.asyncio + async def test_valid_credentials_authenticates(self) -> None: + result = await _provider().authenticate(Authentication(principal="alice", credentials="pw")) + assert result.authenticated is True + assert result.principal == "alice" + assert "ADMIN" in result.authorities + assert "read" in result.authorities + assert result.credentials is None # erased after authentication + + @pytest.mark.asyncio + async def test_wrong_password_raises_bad_credentials(self) -> None: + with pytest.raises(BadCredentialsException): + await _provider().authenticate(Authentication(principal="alice", credentials="WRONG")) + + @pytest.mark.asyncio + async def test_unknown_user_raises_bad_credentials(self) -> None: + with pytest.raises(BadCredentialsException): + await _provider().authenticate(Authentication(principal="ghost", credentials="pw")) + + @pytest.mark.asyncio + async def test_disabled_user_raises_disabled(self) -> None: + with pytest.raises(DisabledException): + await _provider().authenticate(Authentication(principal="bob", credentials="pw")) + + def test_supports_password_authentication(self) -> None: + assert _provider().supports(Authentication(principal="x", credentials="y")) is True + assert _provider().supports(Authentication(principal="x", credentials=None)) is False + + +class TestProviderManager: + @pytest.mark.asyncio + async def test_delegates_to_supporting_provider(self) -> None: + manager = ProviderManager(_provider()) + result = await manager.authenticate(Authentication(principal="alice", credentials="pw")) + assert result.authenticated is True + assert result.credentials is None + + @pytest.mark.asyncio + async def test_no_supporting_provider_raises(self) -> None: + manager = ProviderManager(_provider()) + with pytest.raises(AuthenticationException): + await manager.authenticate(Authentication(principal="x", credentials=None)) + + @pytest.mark.asyncio + async def test_to_security_context(self) -> None: + manager = ProviderManager(_provider()) + result = await manager.authenticate(Authentication(principal="alice", credentials="pw")) + ctx = result.to_security_context() + assert ctx.user_id == "alice" + assert ctx.is_authenticated diff --git a/tests/security/test_authorization_server.py b/tests/security/test_authorization_server.py index 072b1005..160e70a1 100644 --- a/tests/security/test_authorization_server.py +++ b/tests/security/test_authorization_server.py @@ -108,13 +108,13 @@ async def test_client_credentials_decodes_valid_jwt(self, auth_server: Authoriza assert "exp" in payload @pytest.mark.asyncio - async def test_client_credentials_custom_scope(self, auth_server: AuthorizationServer) -> None: - """Passing a custom scope overrides the registration's default scopes.""" + async def test_client_credentials_requested_scope_subset(self, auth_server: AuthorizationServer) -> None: + """A requested scope that is a subset of the registration's scopes is honoured.""" result = await auth_server.token( grant_type="client_credentials", client_id="test-client", client_secret="test-secret", - scope="admin superuser", + scope="read", ) payload = pyjwt.decode( @@ -123,8 +123,98 @@ async def test_client_credentials_custom_scope(self, auth_server: AuthorizationS algorithms=["HS256"], ) - assert payload["scope"] == "admin superuser" - assert result["scope"] == "admin superuser" + assert payload["scope"] == "read" + assert result["scope"] == "read" + + @pytest.mark.asyncio + async def test_client_credentials_rejects_unregistered_scope(self, auth_server: AuthorizationServer) -> None: + """Requesting a scope the client is not registered for is rejected (RFC 6749 §5.2). + + Prevents privilege escalation: a client registered for ``read write`` must not + be able to mint an ``admin`` token by simply asking for it. + """ + with pytest.raises(SecurityException) as exc_info: + await auth_server.token( + grant_type="client_credentials", + client_id="test-client", + client_secret="test-secret", + scope="admin superuser", + ) + assert exc_info.value.code == "INVALID_SCOPE" + + @pytest.mark.asyncio + async def test_client_credentials_partial_unregistered_scope_rejected( + self, auth_server: AuthorizationServer + ) -> None: + """A request mixing a registered and an unregistered scope is rejected wholesale.""" + with pytest.raises(SecurityException) as exc_info: + await auth_server.token( + grant_type="client_credentials", + client_id="test-client", + client_secret="test-secret", + scope="read admin", + ) + assert exc_info.value.code == "INVALID_SCOPE" + + +# --------------------------------------------------------------------------- +# Audience-restricted tokens +# --------------------------------------------------------------------------- + + +class TestAudienceClaim: + """Tokens carry an ``aud`` claim only when an audience is configured.""" + + @pytest.fixture + def auth_server_with_aud( + self, + client_repo: InMemoryClientRegistrationRepository, + token_store: InMemoryTokenStore, + ) -> AuthorizationServer: + return AuthorizationServer( + secret="test-signing-secret", + client_repository=client_repo, + token_store=token_store, + issuer="https://auth.example.com", + audience="api://lumen", + ) + + @pytest.mark.asyncio + async def test_client_credentials_token_includes_aud(self, auth_server_with_aud: AuthorizationServer) -> None: + result = await auth_server_with_aud.token( + grant_type="client_credentials", + client_id="test-client", + client_secret="test-secret", + ) + payload = pyjwt.decode( + result["access_token"], "test-signing-secret", algorithms=["HS256"], audience="api://lumen" + ) + assert payload["aud"] == "api://lumen" + + @pytest.mark.asyncio + async def test_refreshed_token_includes_aud(self, auth_server_with_aud: AuthorizationServer) -> None: + initial = await auth_server_with_aud.token( + grant_type="client_credentials", client_id="test-client", client_secret="test-secret" + ) + refreshed = await auth_server_with_aud.token( + grant_type="refresh_token", + client_id="test-client", + client_secret="test-secret", + refresh_token=initial["refresh_token"], + ) + payload = pyjwt.decode( + refreshed["access_token"], "test-signing-secret", algorithms=["HS256"], audience="api://lumen" + ) + assert payload["aud"] == "api://lumen" + + @pytest.mark.asyncio + async def test_no_aud_claim_when_audience_not_configured(self, auth_server: AuthorizationServer) -> None: + """Backward-compatible: tokens carry no ``aud`` unless an audience is set.""" + result = await auth_server.token( + grant_type="client_credentials", client_id="test-client", client_secret="test-secret" + ) + payload = pyjwt.decode(result["access_token"], "test-signing-secret", algorithms=["HS256"]) + assert "aud" not in payload # --------------------------------------------------------------------------- @@ -182,10 +272,7 @@ async def test_refresh_token_rotation( refresh_token=old_refresh, ) - # Old refresh token should be revoked - assert await token_store.find(old_refresh) is None - - # Attempting to reuse the old refresh token should fail + # Attempting to reuse the old (rotated) refresh token must fail. with pytest.raises(SecurityException) as exc_info: await auth_server.token( grant_type="refresh_token", @@ -196,6 +283,37 @@ async def test_refresh_token_rotation( assert exc_info.value.code == "INVALID_GRANT" +class TestRefreshTokenReuseDetection: + """OAuth 2.1 / RFC 9700: replaying a rotated refresh token revokes the whole family.""" + + @pytest.mark.asyncio + async def test_reuse_of_rotated_token_revokes_active_descendant(self, auth_server: AuthorizationServer) -> None: + initial = await auth_server.token( + grant_type="client_credentials", client_id="test-client", client_secret="test-secret" + ) + rt1 = initial["refresh_token"] + + # Rotate rt1 -> rt2 (rt2 is the live token). + second = await auth_server.token( + grant_type="refresh_token", client_id="test-client", client_secret="test-secret", refresh_token=rt1 + ) + rt2 = second["refresh_token"] + + # Replay the consumed rt1 -> reuse detected. + with pytest.raises(SecurityException) as exc_info: + await auth_server.token( + grant_type="refresh_token", client_id="test-client", client_secret="test-secret", refresh_token=rt1 + ) + assert exc_info.value.code == "INVALID_GRANT" + + # The whole family is now revoked: the previously-live rt2 no longer works. + with pytest.raises(SecurityException) as exc_info2: + await auth_server.token( + grant_type="refresh_token", client_id="test-client", client_secret="test-secret", refresh_token=rt2 + ) + assert exc_info2.value.code == "INVALID_GRANT" + + # --------------------------------------------------------------------------- # Error cases # --------------------------------------------------------------------------- diff --git a/tests/security/test_csrf.py b/tests/security/test_csrf.py index b45339fe..7b5337ed 100644 --- a/tests/security/test_csrf.py +++ b/tests/security/test_csrf.py @@ -16,6 +16,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock import pytest @@ -90,9 +91,9 @@ async def test_csrf_filter_safe_method_sets_cookie(self) -> None: assert "XSRF-TOKEN" in cookie_header @pytest.mark.asyncio - async def test_csrf_filter_unsafe_method_missing_cookie(self) -> None: - """POST without CSRF cookie returns 403.""" - csrf_filter = CsrfFilter() + async def test_csrf_filter_strict_mode_missing_cookie(self) -> None: + """In strict mode, a POST without the CSRF cookie returns 403.""" + csrf_filter = CsrfFilter(cookie_gated=False) request = _make_request( method="POST", headers={"X-XSRF-TOKEN": "some-token"}, @@ -104,6 +105,34 @@ async def test_csrf_filter_unsafe_method_missing_cookie(self) -> None: assert result.status_code == 403 call_next.assert_not_awaited() + @pytest.mark.asyncio + async def test_csrf_filter_cookie_gated_no_cookies_is_exempt(self) -> None: + """Default (cookie-gated) mode: a POST carrying NO cookies has no ambient + authority to abuse, so it is exempt from CSRF — keeping stateless API + clients working when CSRF is on by default.""" + csrf_filter = CsrfFilter() # cookie_gated=True by default + request = _make_request(method="POST", headers={"X-XSRF-TOKEN": "some-token"}) + response = Response(content="ok", status_code=200) + call_next = AsyncMock(return_value=response) + + result = await csrf_filter.do_filter(request, call_next) + + call_next.assert_awaited_once_with(request) + assert result is response + + @pytest.mark.asyncio + async def test_csrf_filter_cookie_present_requires_token(self) -> None: + """A POST that carries a (session) cookie but no valid CSRF pair is rejected, + even in cookie-gated mode — that is the actual CSRF scenario.""" + csrf_filter = CsrfFilter() + request = _make_request(method="POST", cookies={"SESSION": "abc"}) + call_next = AsyncMock() + + result = await csrf_filter.do_filter(request, call_next) + + assert result.status_code == 403 + call_next.assert_not_awaited() + @pytest.mark.asyncio async def test_csrf_filter_unsafe_method_missing_header(self) -> None: """POST with cookie but no header returns 403.""" @@ -171,3 +200,54 @@ async def test_csrf_filter_bearer_bypass(self) -> None: call_next.assert_awaited_once_with(request) assert result is response + + +class TestCsrfDefaultOn: + """CSRF is wired by default (secure-by-default) unless explicitly disabled.""" + + def _app(self, csrf: dict[str, object] | None = None) -> Any: + import contextlib + from collections.abc import AsyncIterator + + from pyfly.container.stereotypes import rest_controller + from pyfly.context.application_context import ApplicationContext + from pyfly.core.config import Config + from pyfly.web.adapters.starlette.app import create_app + from pyfly.web.mappings import get_mapping, request_mapping + + @rest_controller + @request_mapping("/api/ping") + class _PingController: + @get_mapping("/") + async def ping(self) -> dict: + return {"ok": True} + + security: dict[str, object] = {} + if csrf is not None: + security["csrf"] = csrf + ctx = ApplicationContext(Config({"pyfly": {"security": security}})) + ctx.register_bean(_PingController) + + @contextlib.asynccontextmanager + async def _lifespan(_app: Any) -> AsyncIterator[None]: + await ctx.start() + yield + await ctx.stop() + + return create_app(context=ctx, lifespan=_lifespan) + + def test_get_sets_xsrf_cookie_by_default(self) -> None: + from starlette.testclient import TestClient + + with TestClient(self._app()) as client: + resp = client.get("/api/ping/") + assert resp.status_code == 200 + assert "XSRF-TOKEN" in resp.cookies + + def test_can_be_disabled(self) -> None: + from starlette.testclient import TestClient + + with TestClient(self._app(csrf={"enabled": "false"})) as client: + resp = client.get("/api/ping/") + assert resp.status_code == 200 + assert "XSRF-TOKEN" not in resp.cookies diff --git a/tests/security/test_dpop_mtls.py b/tests/security/test_dpop_mtls.py new file mode 100644 index 00000000..7fc8b4f2 --- /dev/null +++ b/tests/security/test_dpop_mtls.py @@ -0,0 +1,271 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sender-constrained tokens — DPoP (RFC 9449) + mTLS (RFC 8705).""" + +from __future__ import annotations + +import base64 +import datetime +import hashlib +import json +import time + +import jwt as pyjwt +import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.x509.oid import NameOID + +from pyfly.kernel.exceptions import SecurityException +from pyfly.security.oauth2.dpop import ( + DPoPProofValidator, + certificate_thumbprint, + confirm_dpop_binding, + confirm_mtls_binding, + jwk_thumbprint, +) + + +def _ec_key() -> ec.EllipticCurvePrivateKey: + return ec.generate_private_key(ec.SECP256R1()) + + +def _public_jwk(key: ec.EllipticCurvePrivateKey) -> dict: + return json.loads(pyjwt.algorithms.ECAlgorithm.to_jwk(key.public_key())) + + +def _proof(key: ec.EllipticCurvePrivateKey, *, htm: str, htu: str, iat: int | None = None, jti: str = "id1") -> str: + claims = {"htm": htm, "htu": htu, "iat": iat if iat is not None else int(time.time()), "jti": jti} + return pyjwt.encode(claims, key, algorithm="ES256", headers={"typ": "dpop+jwt", "jwk": _public_jwk(key)}) + + +class TestJwkThumbprint: + def test_thumbprint_is_stable_and_base64url(self) -> None: + key = _ec_key() + jwk = _public_jwk(key) + t1 = jwk_thumbprint(jwk) + t2 = jwk_thumbprint(dict(reversed(list(jwk.items())))) # member order must not matter + assert t1 == t2 + assert "=" not in t1 and "+" not in t1 and "/" not in t1 + + +class TestDPoPProofValidator: + def test_valid_proof_returns_jkt(self) -> None: + key = _ec_key() + proof = _proof(key, htm="GET", htu="https://api.example.com/resource") + jkt = DPoPProofValidator().validate(proof, http_method="GET", http_url="https://api.example.com/resource") + assert jkt == jwk_thumbprint(_public_jwk(key)) + + def test_htu_query_is_ignored(self) -> None: + key = _ec_key() + proof = _proof(key, htm="GET", htu="https://api.example.com/resource") + # The request URL may carry a query string; htu compares origin+path only. + jkt = DPoPProofValidator().validate(proof, http_method="GET", http_url="https://api.example.com/resource?a=1") + assert jkt + + def test_method_mismatch_rejected(self) -> None: + key = _ec_key() + proof = _proof(key, htm="GET", htu="https://api.example.com/x") + with pytest.raises(SecurityException): + DPoPProofValidator().validate(proof, http_method="POST", http_url="https://api.example.com/x") + + def test_url_mismatch_rejected(self) -> None: + key = _ec_key() + proof = _proof(key, htm="GET", htu="https://api.example.com/x") + with pytest.raises(SecurityException): + DPoPProofValidator().validate(proof, http_method="GET", http_url="https://api.example.com/y") + + def test_stale_proof_rejected(self) -> None: + key = _ec_key() + proof = _proof(key, htm="GET", htu="https://api.example.com/x", iat=int(time.time()) - 600) + with pytest.raises(SecurityException): + DPoPProofValidator(max_age_seconds=60).validate( + proof, http_method="GET", http_url="https://api.example.com/x" + ) + + def test_replay_rejected(self) -> None: + key = _ec_key() + validator = DPoPProofValidator(replay_cache=set()) + proof = _proof(key, htm="GET", htu="https://api.example.com/x", jti="unique-1") + validator.validate(proof, http_method="GET", http_url="https://api.example.com/x") + with pytest.raises(SecurityException): + validator.validate(proof, http_method="GET", http_url="https://api.example.com/x") + + def test_symmetric_alg_rejected(self) -> None: + # A proof must be signed with an asymmetric key; alg=none/HS* is rejected. + forged = pyjwt.encode( + {"htm": "GET", "htu": "https://api/x", "iat": int(time.time()), "jti": "j"}, + "secret", + algorithm="HS256", + headers={"typ": "dpop+jwt", "jwk": {"kty": "oct"}}, + ) + with pytest.raises(SecurityException): + DPoPProofValidator().validate(forged, http_method="GET", http_url="https://api/x") + + +class TestDPoPBindingConfirmation: + def test_matching_jkt_passes(self) -> None: + confirm_dpop_binding({"cnf": {"jkt": "abc"}}, "abc") + + def test_mismatched_jkt_raises(self) -> None: + with pytest.raises(SecurityException): + confirm_dpop_binding({"cnf": {"jkt": "abc"}}, "different") + + def test_missing_cnf_raises(self) -> None: + with pytest.raises(SecurityException): + confirm_dpop_binding({"sub": "u"}, "abc") + + +def _self_signed_cert() -> bytes: + key = ec.generate_private_key(ec.SECP256R1()) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "client")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime(2020, 1, 1)) + .not_valid_after(datetime.datetime(2040, 1, 1)) + .sign(key, hashes.SHA256()) + ) + return cert.public_bytes(serialization.Encoding.PEM) + + +class TestResourceFilterDPoPEnforcement: + """The resource-server filter enforces proof-of-possession for cnf-bound tokens.""" + + def _filter_and_request(self, jkt: str, *, dpop_header: str | None): + from starlette.requests import Request + + from pyfly.security.context import SecurityContext + from pyfly.web.adapters.starlette.filters.oauth2_resource_filter import ( + ERROR_MODE_401, + OAuth2ResourceServerFilter, + ) + + class _FakeValidator: + def validate_and_context(self, token: str) -> tuple[dict, SecurityContext]: + return {"sub": "u", "cnf": {"jkt": jkt}}, SecurityContext(user_id="u") + + headers: list[tuple[bytes, bytes]] = [(b"authorization", b"DPoP the-access-token")] + if dpop_header is not None: + headers.append((b"dpop", dpop_header.encode("latin-1"))) + scope = { + "type": "http", + "method": "GET", + "path": "/r", + "headers": headers, + "query_string": b"", + "scheme": "https", + "server": ("api.example.com", 443), + } + flt = OAuth2ResourceServerFilter( + _FakeValidator(), # type: ignore[arg-type] + error_mode=ERROR_MODE_401, + enforce_sender_constraints=True, + ) + return flt, Request(scope) + + @pytest.mark.asyncio + async def test_valid_dpop_proof_accepted(self) -> None: + key = _ec_key() + jkt = jwk_thumbprint(_public_jwk(key)) + # ath must match the access token the filter passes ("the-access-token"). + from pyfly.security.oauth2.dpop import access_token_hash + + claims = { + "htm": "GET", + "htu": "https://api.example.com/r", + "iat": int(time.time()), + "jti": "p1", + "ath": access_token_hash("the-access-token"), + } + proof = pyjwt.encode(claims, key, algorithm="ES256", headers={"typ": "dpop+jwt", "jwk": _public_jwk(key)}) + flt, request = self._filter_and_request(jkt, dpop_header=proof) + + captured = {} + + async def call_next(r): + captured["ctx"] = r.state.security_context + from starlette.responses import PlainTextResponse + + return PlainTextResponse("ok") + + resp = await flt.do_filter(request, call_next) + assert resp.status_code == 200 + assert captured["ctx"].user_id == "u" + + @pytest.mark.asyncio + async def test_missing_dpop_proof_rejected(self) -> None: + key = _ec_key() + jkt = jwk_thumbprint(_public_jwk(key)) + flt, request = self._filter_and_request(jkt, dpop_header=None) + + async def call_next(r): + from starlette.responses import PlainTextResponse + + return PlainTextResponse("should not reach") + + resp = await flt.do_filter(request, call_next) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_wrong_key_proof_rejected(self) -> None: + bound_key = _ec_key() + jkt = jwk_thumbprint(_public_jwk(bound_key)) + # Attacker presents a proof signed with a DIFFERENT key. + attacker = _ec_key() + from pyfly.security.oauth2.dpop import access_token_hash + + claims = { + "htm": "GET", + "htu": "https://api.example.com/r", + "iat": int(time.time()), + "jti": "p2", + "ath": access_token_hash("the-access-token"), + } + proof = pyjwt.encode( + claims, attacker, algorithm="ES256", headers={"typ": "dpop+jwt", "jwk": _public_jwk(attacker)} + ) + flt, request = self._filter_and_request(jkt, dpop_header=proof) + + async def call_next(r): + from starlette.responses import PlainTextResponse + + return PlainTextResponse("should not reach") + + resp = await flt.do_filter(request, call_next) + assert resp.status_code == 401 + + +class TestMtlsBinding: + def test_thumbprint_matches_manual_sha256(self) -> None: + pem = _self_signed_cert() + cert = x509.load_pem_x509_certificate(pem) + expected = base64.urlsafe_b64encode(hashlib.sha256(cert.public_bytes(serialization.Encoding.DER)).digest()) + assert certificate_thumbprint(pem) == expected.rstrip(b"=").decode("ascii") + + def test_confirm_matching_cert(self) -> None: + pem = _self_signed_cert() + thumb = certificate_thumbprint(pem) + confirm_mtls_binding({"cnf": {"x5t#S256": thumb}}, pem) + + def test_confirm_mismatched_cert_raises(self) -> None: + pem = _self_signed_cert() + other = _self_signed_cert() + thumb = certificate_thumbprint(other) + with pytest.raises(SecurityException): + confirm_mtls_binding({"cnf": {"x5t#S256": thumb}}, pem) diff --git a/tests/security/test_form_login.py b/tests/security/test_form_login.py new file mode 100644 index 00000000..909ae0c4 --- /dev/null +++ b/tests/security/test_form_login.py @@ -0,0 +1,155 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Form-login filter.""" + +from __future__ import annotations + +from typing import Any +from urllib.parse import urlencode + +import pytest +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response + +from pyfly.security.authentication import DaoAuthenticationProvider, ProviderManager +from pyfly.security.password import BcryptPasswordEncoder +from pyfly.security.user_details import InMemoryUserDetailsService, UserDetails +from pyfly.session.session import HttpSession +from pyfly.web.adapters.starlette.filters.form_login_filter import FormLoginFilter + +_ENCODER = BcryptPasswordEncoder(rounds=4) +_SECURITY_CONTEXT_KEY = "SECURITY_CONTEXT" + + +def _manager() -> ProviderManager: + service = InMemoryUserDetailsService( + UserDetails(username="alice", password_hash=_ENCODER.hash("pw"), roles=["ADMIN"]) + ) + return ProviderManager(DaoAuthenticationProvider(service, _ENCODER)) + + +def _post(path: str, data: dict[str, str]) -> Request: + body = urlencode(data).encode() + + async def receive() -> dict[str, Any]: + return {"type": "http.request", "body": body, "more_body": False} + + scope = { + "type": "http", + "method": "POST", + "path": path, + "headers": [(b"content-type", b"application/x-www-form-urlencoded")], + "query_string": b"", + } + request = Request(scope, receive) + request.state.session = HttpSession("pre-auth-sid", {}) + return request + + +async def _call_next(request: Request) -> Response: + return PlainTextResponse("downstream") + + +class TestFormLoginFilter: + @pytest.mark.asyncio + async def test_valid_login_establishes_session_context(self) -> None: + flt = FormLoginFilter(_manager()) + request = _post("/login", {"username": "alice", "password": "pw"}) + resp = await flt.do_filter(request, _call_next) + assert resp.status_code == 302 + assert resp.headers["location"] == "/" + ctx = request.state.session.get_attribute(_SECURITY_CONTEXT_KEY) + assert ctx is not None and ctx.user_id == "alice" and ctx.has_role("ADMIN") + + @pytest.mark.asyncio + async def test_session_id_is_rotated_on_login(self) -> None: + flt = FormLoginFilter(_manager()) + request = _post("/login", {"username": "alice", "password": "pw"}) + await flt.do_filter(request, _call_next) + assert request.state.session.id != "pre-auth-sid" # fixation defense + + @pytest.mark.asyncio + async def test_invalid_login_redirects_to_failure(self) -> None: + flt = FormLoginFilter(_manager()) + request = _post("/login", {"username": "alice", "password": "WRONG"}) + resp = await flt.do_filter(request, _call_next) + assert resp.status_code == 302 + assert "error" in resp.headers["location"] + assert request.state.session.get_attribute(_SECURITY_CONTEXT_KEY) is None + + @pytest.mark.asyncio + async def test_non_login_request_passes_through(self) -> None: + flt = FormLoginFilter(_manager()) + request = _post("/other", {"x": "y"}) + resp = await flt.do_filter(request, _call_next) + assert resp.body == b"downstream" + + @pytest.mark.asyncio + async def test_json_mode_returns_200_and_401(self) -> None: + flt = FormLoginFilter(_manager(), use_redirect=False) + ok = await flt.do_filter(_post("/login", {"username": "alice", "password": "pw"}), _call_next) + assert ok.status_code == 200 + bad = await flt.do_filter(_post("/login", {"username": "alice", "password": "no"}), _call_next) + assert bad.status_code == 401 + + +class TestFormLoginAndLogoutAutoConfigEndToEnd: + """Form-login and logout auto-configs wire their filters into the live chain.""" + + def _client(self) -> Any: + import contextlib + from collections.abc import AsyncIterator + + from starlette.testclient import TestClient + + from pyfly.context.application_context import ApplicationContext + from pyfly.core.config import Config + from pyfly.web.adapters.starlette.app import create_app + + config = Config( + { + "pyfly": { + "security": { + "csrf": {"enabled": "false"}, + "form-login": { + "enabled": "true", + "use-redirect": "false", + "users": {"alice": {"password-hash": _ENCODER.hash("pw"), "roles": "ADMIN"}}, + }, + "logout": {"enabled": "true", "use-redirect": "false"}, + } + } + } + ) + ctx = ApplicationContext(config) + + @contextlib.asynccontextmanager + async def _lifespan(_app: Any) -> AsyncIterator[None]: + await ctx.start() + yield + await ctx.stop() + + return TestClient(create_app(context=ctx, lifespan=_lifespan)) + + def test_form_login_endpoint_authenticates(self) -> None: + with self._client() as client: + ok = client.post("/login", data={"username": "alice", "password": "pw"}) + assert ok.status_code == 200 and ok.json()["authenticated"] is True + bad = client.post("/login", data={"username": "alice", "password": "WRONG"}) + assert bad.status_code == 401 + + def test_logout_endpoint_wired(self) -> None: + with self._client() as client: + resp = client.post("/logout") + assert resp.status_code == 204 diff --git a/tests/security/test_http_basic.py b/tests/security/test_http_basic.py new file mode 100644 index 00000000..9e0a9982 --- /dev/null +++ b/tests/security/test_http_basic.py @@ -0,0 +1,193 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for HTTP Basic authentication (UserDetailsService + filter).""" + +from __future__ import annotations + +import base64 +from typing import Any + +import pytest +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response + +from pyfly.security.password import BcryptPasswordEncoder +from pyfly.security.user_details import InMemoryUserDetailsService, UserDetails, UserDetailsService +from pyfly.web.adapters.starlette.filters.http_basic_filter import HttpBasicAuthenticationFilter + +_ENCODER = BcryptPasswordEncoder(rounds=4) + + +def _service() -> InMemoryUserDetailsService: + return InMemoryUserDetailsService( + UserDetails(username="alice", password_hash=_ENCODER.hash("s3cret"), roles=["ADMIN"]), + UserDetails(username="bob", password_hash=_ENCODER.hash("hunter2"), roles=["USER"], enabled=False), + ) + + +def _request(auth_header: str | None = None) -> Request: + headers: list[tuple[bytes, bytes]] = [] + if auth_header is not None: + headers.append((b"authorization", auth_header.encode("latin-1"))) + scope: dict[str, Any] = {"type": "http", "method": "GET", "path": "/x", "headers": headers, "query_string": b""} + return Request(scope) + + +def _basic(username: str, password: str) -> str: + token = base64.b64encode(f"{username}:{password}".encode()).decode("ascii") + return f"Basic {token}" + + +async def _call_next(request: Request) -> Response: + return PlainTextResponse("ok") + + +class TestInMemoryUserDetailsService: + @pytest.mark.asyncio + async def test_loads_known_user(self) -> None: + svc = _service() + user = await svc.load_user_by_username("alice") + assert user is not None and user.username == "alice" + + @pytest.mark.asyncio + async def test_unknown_user_is_none(self) -> None: + assert await _service().load_user_by_username("nobody") is None + + def test_protocol_conformance(self) -> None: + assert isinstance(_service(), UserDetailsService) + + +class TestHttpBasicFilter: + @pytest.mark.asyncio + async def test_valid_credentials_set_authenticated_context(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER) + request = _request(_basic("alice", "s3cret")) + response = await f.do_filter(request, _call_next) + assert response.status_code == 200 + ctx = request.state.security_context + assert ctx.is_authenticated + assert ctx.user_id == "alice" + assert ctx.has_role("ADMIN") + + @pytest.mark.asyncio + async def test_wrong_password_401_with_challenge(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401", realm="PyFly") + response = await f.do_filter(_request(_basic("alice", "wrong")), _call_next) + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == 'Basic realm="PyFly"' + + @pytest.mark.asyncio + async def test_unknown_user_401(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401") + response = await f.do_filter(_request(_basic("ghost", "x")), _call_next) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_disabled_user_rejected(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401") + response = await f.do_filter(_request(_basic("bob", "hunter2")), _call_next) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_wrong_password_anonymous_mode_falls_through(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="anonymous") + request = _request(_basic("alice", "wrong")) + response = await f.do_filter(request, _call_next) + assert response.status_code == 200 # gate decides downstream + assert not request.state.security_context.is_authenticated + + @pytest.mark.asyncio + async def test_no_header_is_anonymous(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401") + request = _request(None) + response = await f.do_filter(request, _call_next) + assert response.status_code == 200 # missing creds fall through to the gate + assert not request.state.security_context.is_authenticated + + @pytest.mark.asyncio + async def test_non_basic_scheme_ignored(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401") + request = _request("Bearer sometoken") + response = await f.do_filter(request, _call_next) + assert response.status_code == 200 + assert not request.state.security_context.is_authenticated + + @pytest.mark.asyncio + async def test_malformed_base64_rejected(self) -> None: + f = HttpBasicAuthenticationFilter(_service(), _ENCODER, error_mode="401") + response = await f.do_filter(_request("Basic !!!not-base64!!!"), _call_next) + assert response.status_code == 401 + + +class TestHttpBasicAutoConfigEndToEnd: + """HTTP Basic wired from config, exercised through the full app stack.""" + + def _app(self) -> Any: + import contextlib + from collections.abc import AsyncIterator + + from pyfly.container.stereotypes import rest_controller + from pyfly.context.application_context import ApplicationContext + from pyfly.core.config import Config + from pyfly.web.adapters.starlette.app import create_app + from pyfly.web.mappings import get_mapping, request_mapping + + @rest_controller + @request_mapping("/api/secret") + class _SecretController: + @get_mapping("/") + async def secret(self) -> dict: + return {"ok": True} + + config = Config( + { + "pyfly": { + "security": { + "csrf": {"enabled": "false"}, + "http-basic": { + "enabled": "true", + "realm": "PyFly", + "error-mode": "401", + "users": {"alice": {"password-hash": _ENCODER.hash("s3cret"), "roles": "ADMIN"}}, + }, + } + } + } + ) + ctx = ApplicationContext(config) + ctx.register_bean(_SecretController) + + @contextlib.asynccontextmanager + async def _lifespan(_app: Any) -> AsyncIterator[None]: + await ctx.start() + yield + await ctx.stop() + + return create_app(context=ctx, lifespan=_lifespan) + + def test_valid_basic_credentials_pass(self) -> None: + from starlette.testclient import TestClient + + with TestClient(self._app()) as client: + resp = client.get("/api/secret/", headers={"Authorization": _basic("alice", "s3cret")}) + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + def test_bad_credentials_get_401_challenge(self) -> None: + from starlette.testclient import TestClient + + with TestClient(self._app()) as client: + resp = client.get("/api/secret/", headers={"Authorization": _basic("alice", "WRONG")}) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == 'Basic realm="PyFly"' diff --git a/tests/security/test_logout.py b/tests/security/test_logout.py new file mode 100644 index 00000000..bf16980f --- /dev/null +++ b/tests/security/test_logout.py @@ -0,0 +1,78 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generic logout filter.""" + +from __future__ import annotations + +import pytest +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response + +from pyfly.session.session import HttpSession +from pyfly.web.adapters.starlette.filters.logout_filter import LogoutFilter + + +def _post(path: str) -> Request: + scope = {"type": "http", "method": "POST", "path": path, "headers": [], "query_string": b""} + request = Request(scope) + session = HttpSession("sid", {}) + session.set_attribute("SECURITY_CONTEXT", object()) + request.state.session = session + return request + + +async def _call_next(request: Request) -> Response: + return PlainTextResponse("downstream") + + +class TestLogoutFilter: + @pytest.mark.asyncio + async def test_logout_invalidates_session_and_redirects(self) -> None: + flt = LogoutFilter() + request = _post("/logout") + resp = await flt.do_filter(request, _call_next) + assert resp.status_code == 302 + assert resp.headers["location"] == "/login?logout" + assert request.state.session.invalidated is True + + @pytest.mark.asyncio + async def test_logout_clears_configured_cookies(self) -> None: + flt = LogoutFilter(delete_cookies=["SESSION", "XSRF-TOKEN"]) + resp = await flt.do_filter(_post("/logout"), _call_next) + set_cookie = ( + resp.headers.getlist("set-cookie") if hasattr(resp.headers, "getlist") else [resp.headers["set-cookie"]] + ) + joined = " ".join(set_cookie) + assert "SESSION=" in joined and "XSRF-TOKEN=" in joined + + @pytest.mark.asyncio + async def test_non_logout_passes_through(self) -> None: + flt = LogoutFilter() + resp = await flt.do_filter(_post("/other"), _call_next) + assert resp.body == b"downstream" + + @pytest.mark.asyncio + async def test_json_mode_returns_204(self) -> None: + flt = LogoutFilter(use_redirect=False) + resp = await flt.do_filter(_post("/logout"), _call_next) + assert resp.status_code == 204 + + @pytest.mark.asyncio + async def test_custom_logout_url(self) -> None: + flt = LogoutFilter(logout_url="/sign-out") + resp = await flt.do_filter(_post("/sign-out"), _call_next) + assert resp.status_code == 302 + # The default path is no longer special. + passed = await flt.do_filter(_post("/logout"), _call_next) + assert passed.body == b"downstream" diff --git a/tests/security/test_method_filter.py b/tests/security/test_method_filter.py new file mode 100644 index 00000000..d24db6ef --- /dev/null +++ b/tests/security/test_method_filter.py @@ -0,0 +1,98 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""@pre_filter / @post_filter collection filtering.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from pyfly.context.request_context import RequestContext +from pyfly.security.context import SecurityContext +from pyfly.security.method_security import post_filter, pre_filter + + +@pytest.fixture(autouse=True) +def _clear_request_context() -> Any: + RequestContext.clear() + yield + RequestContext.clear() + + +def _ctx(user: str = "alice", roles: list[str] | None = None) -> None: + ctx = RequestContext.init() + ctx.security_context = SecurityContext(user_id=user, roles=roles or []) + + +def _docs() -> list[SimpleNamespace]: + return [SimpleNamespace(owner="alice"), SimpleNamespace(owner="bob"), SimpleNamespace(owner="alice")] + + +@post_filter("filterObject.owner == principal.user_id") +async def list_docs() -> list[SimpleNamespace]: + return _docs() + + +@post_filter("filterObject.owner == principal.user_id") +def list_docs_sync() -> list[SimpleNamespace]: + return _docs() + + +@pre_filter("filterObject.owner == principal.user_id", filter_target="docs") +async def save_all(docs: list[SimpleNamespace]) -> list[SimpleNamespace]: + return docs + + +@pre_filter("filterObject.owner == principal.user_id") +async def save_first_collection(docs: list[SimpleNamespace]) -> list[SimpleNamespace]: + return docs + + +class TestPostFilter: + @pytest.mark.asyncio + async def test_keeps_only_matching_elements(self) -> None: + _ctx("alice") + result = await list_docs() + assert [d.owner for d in result] == ["alice", "alice"] + + @pytest.mark.asyncio + async def test_preserves_collection_type(self) -> None: + _ctx("alice") + assert isinstance(await list_docs(), list) + + def test_sync_method(self) -> None: + _ctx("bob") + assert [d.owner for d in list_docs_sync()] == ["bob"] + + +class TestPreFilter: + @pytest.mark.asyncio + async def test_filters_named_argument(self) -> None: + _ctx("alice") + result = await save_all(docs=_docs()) + assert [d.owner for d in result] == ["alice", "alice"] + + @pytest.mark.asyncio + async def test_filters_positional_argument(self) -> None: + _ctx("bob") + result = await save_all(_docs()) + assert [d.owner for d in result] == ["bob"] + + @pytest.mark.asyncio + async def test_autodetects_first_collection(self) -> None: + _ctx("alice") + result = await save_first_collection(_docs()) + assert [d.owner for d in result] == ["alice", "alice"] diff --git a/tests/security/test_oauth2_endpoints.py b/tests/security/test_oauth2_endpoints.py new file mode 100644 index 00000000..6f5bcfd0 --- /dev/null +++ b/tests/security/test_oauth2_endpoints.py @@ -0,0 +1,260 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OAuth2 authorization-server HTTP endpoints (token / introspect / revoke / jwks).""" + +from __future__ import annotations + +from typing import Any + +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from pyfly.security.oauth2.authorization_server import AuthorizationServer, InMemoryTokenStore +from pyfly.security.oauth2.client import ClientRegistration, InMemoryClientRegistrationRepository +from pyfly.security.oauth2.endpoints import AuthorizationServerEndpoints + +_SECRET = "authorization-server-secret-32bytes!!" + + +def _server() -> AuthorizationServer: + repo = InMemoryClientRegistrationRepository( + ClientRegistration( + registration_id="svc", + client_id="svc", + client_secret="svc-secret", + authorization_grant_type="client_credentials", + scopes=["read", "write"], + ) + ) + return AuthorizationServer( + secret=_SECRET, client_repository=repo, token_store=InMemoryTokenStore(), issuer="https://as" + ) + + +def _client(server: AuthorizationServer | None = None) -> TestClient: + server = server or _server() + app = Starlette(routes=AuthorizationServerEndpoints(server).routes()) + return TestClient(app) + + +class TestIntrospectMethod: + @pytest.mark.asyncio + async def test_active_access_token(self) -> None: + server = _server() + tok = await server.token(grant_type="client_credentials", client_id="svc", client_secret="svc-secret") + result = await server.introspect(tok["access_token"]) + assert result["active"] is True + assert result["sub"] == "svc" + assert result["scope"] == "read write" + + @pytest.mark.asyncio + async def test_active_refresh_token(self) -> None: + server = _server() + tok = await server.token(grant_type="client_credentials", client_id="svc", client_secret="svc-secret") + result = await server.introspect(tok["refresh_token"]) + assert result["active"] is True + assert result["token_type"] == "refresh_token" + + @pytest.mark.asyncio + async def test_unknown_token_inactive(self) -> None: + assert (await _server().introspect("garbage"))["active"] is False + + +class TestEndpoints: + def test_token_endpoint_issues_token(self) -> None: + resp = _client().post( + "/oauth2/token", + data={"grant_type": "client_credentials", "client_id": "svc", "client_secret": "svc-secret"}, + ) + assert resp.status_code == 200 + assert "access_token" in resp.json() + + def test_token_endpoint_bad_secret(self) -> None: + resp = _client().post( + "/oauth2/token", + data={"grant_type": "client_credentials", "client_id": "svc", "client_secret": "WRONG"}, + ) + assert resp.status_code == 401 + assert resp.json()["error"] == "invalid_client" + + def test_jwks_endpoint(self) -> None: + resp = _client().get("/oauth2/jwks") + assert resp.status_code == 200 + assert resp.json() == {"keys": []} # HS256 server publishes no keys + + def test_introspect_requires_client_auth(self) -> None: + resp = _client().post("/oauth2/introspect", data={"token": "x"}) + assert resp.status_code == 401 + + def test_introspect_active_then_revoke(self) -> None: + server = _server() + client = _client(server) + issued = client.post( + "/oauth2/token", + data={"grant_type": "client_credentials", "client_id": "svc", "client_secret": "svc-secret"}, + ).json() + rt = issued["refresh_token"] + auth = {"client_id": "svc", "client_secret": "svc-secret"} + + introspected = client.post("/oauth2/introspect", data={"token": rt, **auth}) + assert introspected.status_code == 200 + assert introspected.json()["active"] is True + + revoked = client.post("/oauth2/revoke", data={"token": rt, **auth}) + assert revoked.status_code == 200 + + again = client.post("/oauth2/introspect", data={"token": rt, **auth}) + assert again.json()["active"] is False + + +def _two_client_server() -> AuthorizationServer: + repo = InMemoryClientRegistrationRepository( + ClientRegistration( + registration_id="a", + client_id="a", + client_secret="a-secret", + authorization_grant_type="client_credentials", + scopes=["read"], + ), + ClientRegistration( + registration_id="b", + client_id="b", + client_secret="b-secret", + authorization_grant_type="client_credentials", + scopes=["read"], + ), + ClientRegistration( + registration_id="rs", + client_id="rs", + client_secret="rs-secret", + allow_introspection=True, + ), + ) + return AuthorizationServer(secret=_SECRET, client_repository=repo, token_store=InMemoryTokenStore()) + + +class TestEndpointAuthorization: + """RFC 7009/7662: a client may only act on its own tokens; introspection by a + non-owner is allowed only for designated resource-server clients.""" + + def test_introspect_other_clients_token_is_inactive(self) -> None: + server = _two_client_server() + client = _client(server) + b_token = client.post( + "/oauth2/token", data={"grant_type": "client_credentials", "client_id": "b", "client_secret": "b-secret"} + ).json()["access_token"] + # Client 'a' tries to introspect client 'b''s token. + resp = client.post("/oauth2/introspect", data={"token": b_token, "client_id": "a", "client_secret": "a-secret"}) + assert resp.json()["active"] is False + + def test_resource_server_client_can_introspect_any_token(self) -> None: + server = _two_client_server() + client = _client(server) + b_token = client.post( + "/oauth2/token", data={"grant_type": "client_credentials", "client_id": "b", "client_secret": "b-secret"} + ).json()["access_token"] + resp = client.post( + "/oauth2/introspect", data={"token": b_token, "client_id": "rs", "client_secret": "rs-secret"} + ) + assert resp.json()["active"] is True + + def test_revoke_other_clients_token_is_noop(self) -> None: + server = _two_client_server() + client = _client(server) + issued = client.post( + "/oauth2/token", data={"grant_type": "client_credentials", "client_id": "b", "client_secret": "b-secret"} + ).json() + b_refresh = issued["refresh_token"] + # Client 'a' attempts to revoke client 'b''s refresh token. + client.post("/oauth2/revoke", data={"token": b_refresh, "client_id": "a", "client_secret": "a-secret"}) + # 'b''s token is still usable -> it was NOT revoked. + refreshed = client.post( + "/oauth2/token", + data={ + "grant_type": "refresh_token", + "client_id": "b", + "client_secret": "b-secret", + "refresh_token": b_refresh, + }, + ) + assert refreshed.status_code == 200 + + def test_introspect_rejects_empty_credentials(self) -> None: + resp = _client(_two_client_server()).post("/oauth2/introspect", data={"token": "x"}) + assert resp.status_code == 401 + + def test_revoke_rejects_empty_credentials(self) -> None: + resp = _client(_two_client_server()).post("/oauth2/revoke", data={"token": "x"}) + assert resp.status_code == 401 + + +class TestOpaqueTokenIntrospector: + def test_active_token_builds_context(self, monkeypatch: pytest.MonkeyPatch) -> None: + from pyfly.security.oauth2.resource_server import OpaqueTokenIntrospector + + introspector = OpaqueTokenIntrospector( + "https://as/oauth2/introspect", client_id="rs", client_secret="rs-secret" + ) + + class _Resp: + status_code = 200 + + def json(self) -> dict[str, Any]: + return {"active": True, "sub": "user-1", "scope": "read write", "roles": ["ADMIN"]} + + class _C: + def __enter__(self) -> _C: + return self + + def __exit__(self, *a: object) -> None: + return None + + def post(self, *a: Any, **k: Any) -> _Resp: + return _Resp() + + import httpx + + monkeypatch.setattr(httpx, "Client", lambda *a, **k: _C()) + ctx = introspector.to_security_context("opaque-token") + assert ctx.user_id == "user-1" + assert "read" in ctx.permissions + + def test_inactive_token_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + from pyfly.kernel.exceptions import SecurityException + from pyfly.security.oauth2.resource_server import OpaqueTokenIntrospector + + introspector = OpaqueTokenIntrospector("https://as/introspect", client_id="rs", client_secret="s") + + class _Resp: + status_code = 200 + + def json(self) -> dict[str, Any]: + return {"active": False} + + class _C: + def __enter__(self) -> _C: + return self + + def __exit__(self, *a: object) -> None: + return None + + def post(self, *a: Any, **k: Any) -> _Resp: + return _Resp() + + import httpx + + monkeypatch.setattr(httpx, "Client", lambda *a, **k: _C()) + with pytest.raises(SecurityException): + introspector.introspect("opaque-token") diff --git a/tests/security/test_oauth2_mixup.py b/tests/security/test_oauth2_mixup.py new file mode 100644 index 00000000..8be13518 --- /dev/null +++ b/tests/security/test_oauth2_mixup.py @@ -0,0 +1,115 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RFC 9207 ``iss`` authorization-response validation (mix-up defense).""" + +from __future__ import annotations + +import json +from typing import Any + +import pytest +from starlette.requests import Request + +from pyfly.security.oauth2.client import ClientRegistration, InMemoryClientRegistrationRepository +from pyfly.security.oauth2.login import _OAUTH2_STATE_KEY, OAuth2LoginHandler +from pyfly.session.session import HttpSession + + +def _handler(**reg_overrides: Any) -> OAuth2LoginHandler: + base: dict[str, Any] = dict( + registration_id="acme", + client_id="cid", + client_secret="secret", + redirect_uri="https://app/cb", + scopes=["openid"], + authorization_uri="https://idp/auth", + token_uri="https://idp/token", + issuer_uri="https://good.example.com", + use_pkce=False, + ) + base.update(reg_overrides) + return OAuth2LoginHandler(InMemoryClientRegistrationRepository(ClientRegistration(**base))) + + +def _callback(query: str, *, state: str | None = "st") -> Request: + scope: dict[str, Any] = { + "type": "http", + "method": "GET", + "path": "/login/oauth2/code/acme", + "headers": [], + "query_string": query.encode(), + "path_params": {"registration_id": "acme"}, + } + request = Request(scope) + session = HttpSession("sid", {}) + if state is not None: + session.set_attribute(_OAUTH2_STATE_KEY, state) + request.state.session = session + return request + + +def _body(resp: Any) -> dict[str, Any]: + return json.loads(bytes(resp.body).decode("utf-8")) + + +@pytest.mark.asyncio +async def test_callback_aborts_on_iss_mismatch() -> None: + """A returned iss that differs from the registration's issuer aborts (mix-up).""" + handler = _handler() + resp = await handler._handle_callback(_callback("state=st&code=abc&iss=https://evil.example.com")) + assert resp.status_code == 400 + assert _body(resp)["error"] == "invalid_iss" + + +@pytest.mark.asyncio +async def test_callback_requires_iss_when_configured() -> None: + """With require_iss=True, a callback lacking the iss param is rejected.""" + handler = _handler(require_iss=True) + resp = await handler._handle_callback(_callback("state=st&code=abc")) + assert resp.status_code == 400 + assert _body(resp)["error"] == "invalid_iss" + + +@pytest.mark.asyncio +async def test_callback_iss_match_passes_to_token_exchange(monkeypatch: pytest.MonkeyPatch) -> None: + """A matching iss passes validation and proceeds to the token exchange.""" + handler = _handler(require_iss=True) + + async def _fake_exchange(*_a: Any, **_k: Any) -> dict[str, Any]: + return {} # empty -> handler returns 502 token_exchange_failed (proves we got past iss) + + monkeypatch.setattr(handler, "_exchange_code", _fake_exchange) + resp = await handler._handle_callback(_callback("state=st&code=abc&iss=https://good.example.com")) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_callback_no_iss_param_allowed_when_not_required() -> None: + """Default (require_iss=False): a missing iss param does not block the flow.""" + handler = _handler() + + async def _fake_exchange(*_a: Any, **_k: Any) -> dict[str, Any]: + return {} + + monkeypatch_done = False + + async def _patched(*_a: Any, **_k: Any) -> dict[str, Any]: + nonlocal monkeypatch_done + monkeypatch_done = True + return {} + + handler._exchange_code = _patched # type: ignore[assignment] + resp = await handler._handle_callback(_callback("state=st&code=abc")) + assert resp.status_code == 502 + assert monkeypatch_done diff --git a/tests/security/test_oauth2_pkce.py b/tests/security/test_oauth2_pkce.py index d93b7856..53077918 100644 --- a/tests/security/test_oauth2_pkce.py +++ b/tests/security/test_oauth2_pkce.py @@ -42,6 +42,20 @@ def _handler(*, use_pkce: bool) -> OAuth2LoginHandler: return OAuth2LoginHandler(InMemoryClientRegistrationRepository(reg)) +def _reg(**overrides: Any) -> ClientRegistration: + base: dict[str, Any] = dict( + registration_id="acme", + client_id="cid", + client_secret="secret", + redirect_uri="https://app/cb", + scopes=["openid"], + authorization_uri="https://idp/auth", + token_uri="https://idp/token", + ) + base.update(overrides) + return ClientRegistration(**base) + + def _request(rid: str = "acme") -> Request: scope: dict[str, Any] = { "type": "http", @@ -88,6 +102,81 @@ async def test_authorization_omits_pkce_when_disabled() -> None: assert request.state.session.get_attribute(_OAUTH2_PKCE_VERIFIER_KEY) is None +def test_pkce_enabled_by_default() -> None: + """RFC 9700 / OAuth 2.1: PKCE is on by default for the authorization_code flow.""" + reg = ClientRegistration(registration_id="x", client_id="c") + assert reg.use_pkce is True + + +@pytest.mark.asyncio +async def test_authorization_adds_pkce_by_default() -> None: + """A registration that does not mention PKCE still gets a code_challenge.""" + handler = OAuth2LoginHandler(InMemoryClientRegistrationRepository(_reg())) + request = _request() + response = await handler._handle_authorization(request) + query = parse_qs(urlparse(response.headers["location"]).query) + assert query["code_challenge_method"] == ["S256"] + assert request.state.session.get_attribute(_OAUTH2_PKCE_VERIFIER_KEY) + + +@pytest.mark.asyncio +async def test_public_client_forces_pkce_even_if_disabled() -> None: + """A public client (no client_secret) gets PKCE even if it tries to opt out — + it has no other defense against authorization-code injection.""" + handler = OAuth2LoginHandler(InMemoryClientRegistrationRepository(_reg(client_secret="", use_pkce=False))) + request = _request() + response = await handler._handle_authorization(request) + query = parse_qs(urlparse(response.headers["location"]).query) + assert query["code_challenge_method"] == ["S256"] + assert request.state.session.get_attribute(_OAUTH2_PKCE_VERIFIER_KEY) + + +def test_client_autoconfig_enables_pkce_by_default() -> None: + from pyfly.core.config import Config + from pyfly.security.auto_configuration import OAuth2ClientAutoConfiguration + + cfg = Config( + { + "pyfly": { + "security": { + "oauth2": { + "client": { + "enabled": "true", + "registrations": {"acme": {"client-id": "c", "token-uri": "https://idp/token"}}, + } + } + } + } + } + ) + repo = OAuth2ClientAutoConfiguration().client_registration_repository(cfg) + reg = repo.find_by_registration_id("acme") + assert reg is not None and reg.use_pkce is True + + +def test_client_autoconfig_pkce_can_be_disabled() -> None: + from pyfly.core.config import Config + from pyfly.security.auto_configuration import OAuth2ClientAutoConfiguration + + cfg = Config( + { + "pyfly": { + "security": { + "oauth2": { + "client": { + "enabled": "true", + "registrations": {"acme": {"client-id": "c", "client-secret": "s", "use-pkce": "false"}}, + } + } + } + } + } + ) + repo = OAuth2ClientAutoConfiguration().client_registration_repository(cfg) + reg = repo.find_by_registration_id("acme") + assert reg is not None and reg.use_pkce is False + + @pytest.mark.asyncio async def test_exchange_code_sends_verifier(monkeypatch: pytest.MonkeyPatch) -> None: captured: dict[str, Any] = {} diff --git a/tests/security/test_password.py b/tests/security/test_password.py index 86c140a1..ae0312cd 100644 --- a/tests/security/test_password.py +++ b/tests/security/test_password.py @@ -11,11 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for PasswordEncoder protocol and BcryptPasswordEncoder adapter.""" +"""Tests for PasswordEncoder protocol and encoder adapters.""" from __future__ import annotations -from pyfly.security.password import BcryptPasswordEncoder, PasswordEncoder +from pyfly.security.password import ( + BcryptPasswordEncoder, + DelegatingPasswordEncoder, + PasswordEncoder, + Pbkdf2PasswordEncoder, + ScryptPasswordEncoder, + create_delegating_password_encoder, +) class TestBcryptPasswordEncoder: @@ -56,3 +63,115 @@ def test_empty_password_hashes(self): assert hashed.startswith("$2b$") assert encoder.verify("", hashed) is True assert encoder.verify("non-empty", hashed) is False + + +class TestPbkdf2PasswordEncoder: + def test_round_trip(self): + enc = Pbkdf2PasswordEncoder(iterations=10_000) + hashed = enc.hash("pw") + assert enc.verify("pw", hashed) is True + assert enc.verify("nope", hashed) is False + + def test_self_describing_format(self): + enc = Pbkdf2PasswordEncoder(iterations=10_000) + assert enc.hash("pw").startswith("sha256$10000$") + + def test_salt_is_random(self): + enc = Pbkdf2PasswordEncoder(iterations=10_000) + assert enc.hash("pw") != enc.hash("pw") + + def test_protocol_conformance(self): + assert isinstance(Pbkdf2PasswordEncoder(), PasswordEncoder) + + def test_corrupt_hash_returns_false(self): + assert Pbkdf2PasswordEncoder().verify("pw", "not-a-valid-hash") is False + + +class TestScryptPasswordEncoder: + def test_round_trip(self): + enc = ScryptPasswordEncoder(n=2**14) + hashed = enc.hash("pw") + assert enc.verify("pw", hashed) is True + assert enc.verify("bad", hashed) is False + + def test_protocol_conformance(self): + assert isinstance(ScryptPasswordEncoder(), PasswordEncoder) + + def test_corrupt_hash_returns_false(self): + assert ScryptPasswordEncoder().verify("pw", "garbage") is False + + +class TestDelegatingPasswordEncoder: + def _enc(self) -> DelegatingPasswordEncoder: + return DelegatingPasswordEncoder( + {"bcrypt": BcryptPasswordEncoder(rounds=4), "pbkdf2": Pbkdf2PasswordEncoder(iterations=10_000)}, + encoding_id="bcrypt", + ) + + def test_hash_is_prefixed_with_default_id(self): + assert self._enc().hash("pw").startswith("{bcrypt}$2b$") + + def test_verify_round_trip(self): + enc = self._enc() + assert enc.verify("pw", enc.hash("pw")) is True + assert enc.verify("bad", enc.hash("pw")) is False + + def test_verify_dispatches_by_prefix(self): + enc = self._enc() + inner = Pbkdf2PasswordEncoder(iterations=10_000).hash("pw") + assert enc.verify("pw", "{pbkdf2}" + inner) is True + assert enc.verify("bad", "{pbkdf2}" + inner) is False + + def test_unknown_prefix_returns_false(self): + assert self._enc().verify("pw", "{md5}deadbeef") is False + + def test_missing_prefix_returns_false(self): + assert self._enc().verify("pw", "$2b$unprefixed") is False + + def test_upgrade_encoding_true_for_non_default_id(self): + enc = self._enc() + stored = "{pbkdf2}" + Pbkdf2PasswordEncoder(iterations=10_000).hash("pw") + assert enc.upgrade_encoding(stored) is True + + def test_upgrade_encoding_false_for_default_id(self): + enc = self._enc() + assert enc.upgrade_encoding(enc.hash("pw")) is False + + def test_upgrade_encoding_true_for_unprefixed(self): + assert self._enc().upgrade_encoding("$2b$legacy") is True + + def test_unknown_default_encoding_id_rejected(self): + import pytest + + with pytest.raises(ValueError, match="encoding_id"): + DelegatingPasswordEncoder({"bcrypt": BcryptPasswordEncoder(rounds=4)}, encoding_id="pbkdf2") + + def test_protocol_conformance(self): + assert isinstance(self._enc(), PasswordEncoder) + + +class TestPasswordEncoderFactory: + def test_create_delegating_default_is_bcrypt(self): + enc = create_delegating_password_encoder(bcrypt_rounds=4) + hashed = enc.hash("pw") + assert hashed.startswith("{bcrypt}") + assert enc.verify("pw", hashed) is True + + def test_create_delegating_recognizes_pbkdf2_and_scrypt(self): + enc = create_delegating_password_encoder(bcrypt_rounds=4) + pbkdf2 = "{pbkdf2}" + Pbkdf2PasswordEncoder(iterations=10_000).hash("pw") + scrypt = "{scrypt}" + ScryptPasswordEncoder(n=2**14).hash("pw") + assert enc.verify("pw", pbkdf2) is True + assert enc.verify("pw", scrypt) is True + + +class TestDelegatingEncoderAutoConfig: + def test_opt_in_provides_delegating_encoder(self): + from pyfly.core.config import Config + from pyfly.security.auto_configuration import PasswordEncoderAutoConfiguration + + cfg = Config({"pyfly": {"security": {"password": {"delegating": {"enabled": "true"}, "bcrypt-rounds": 4}}}}) + enc = PasswordEncoderAutoConfiguration().delegating_password_encoder(cfg) + hashed = enc.hash("pw") + assert hashed.startswith("{bcrypt}") + assert enc.verify("pw", hashed) is True diff --git a/tests/security/test_permission_evaluator.py b/tests/security/test_permission_evaluator.py new file mode 100644 index 00000000..266ce6b8 --- /dev/null +++ b/tests/security/test_permission_evaluator.py @@ -0,0 +1,68 @@ +# Copyright 2026 Firefly Software Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PermissionEvaluator SPI wired into the method-security expression engine.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import pytest + +from pyfly.security.context import SecurityContext +from pyfly.security.expression import evaluate_security_expression, set_permission_evaluator +from pyfly.security.permission import PermissionEvaluator + + +class _OwnerEvaluator: + def has_permission(self, context: Any, target: Any, permission: str, *, target_type: str | None = None) -> bool: + return target == "owned" and permission == "read" + + +@pytest.fixture(autouse=True) +def _reset_evaluator() -> Iterator[None]: + yield + set_permission_evaluator(None) + + +def test_protocol_conformance() -> None: + assert isinstance(_OwnerEvaluator(), PermissionEvaluator) + + +def test_evaluator_receives_target_and_permission() -> None: + set_permission_evaluator(_OwnerEvaluator()) + ctx = SecurityContext(user_id="u") + assert evaluate_security_expression("hasPermission(#doc, 'read')", ctx, args={"doc": "owned"}) is True + assert evaluate_security_expression("hasPermission(#doc, 'read')", ctx, args={"doc": "other"}) is False + + +def test_without_evaluator_falls_back_to_context_permission() -> None: + ctx = SecurityContext(user_id="u", permissions=["read"]) + # No evaluator installed: target is ignored, the permission is checked on the context. + assert evaluate_security_expression("hasPermission(#doc, 'read')", ctx, args={"doc": "x"}) is True + assert evaluate_security_expression("hasPermission(#doc, 'write')", ctx, args={"doc": "x"}) is False + + +def test_three_arg_form_passes_target_type() -> None: + captured: dict[str, Any] = {} + + class _Capture: + def has_permission(self, context: Any, target: Any, permission: str, *, target_type: str | None = None) -> bool: + captured.update(target=target, permission=permission, target_type=target_type) + return True + + set_permission_evaluator(_Capture()) + ctx = SecurityContext(user_id="u") + assert evaluate_security_expression("hasPermission(#id, 'Document', 'read')", ctx, args={"id": "7"}) is True + assert captured == {"target": "7", "permission": "read", "target_type": "Document"} diff --git a/tests/security/test_security_hardening.py b/tests/security/test_security_hardening.py index 1f7d3398..6983e551 100644 --- a/tests/security/test_security_hardening.py +++ b/tests/security/test_security_hardening.py @@ -32,6 +32,7 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse, Response +from pyfly.core.config import Config from pyfly.kernel.exceptions import SecurityException from pyfly.security.context import SecurityContext from pyfly.security.http_security import HttpSecurity @@ -145,3 +146,110 @@ async def test_any_request_permit_all_restores_open_behavior(self) -> None: async def test_empty_httpsecurity_is_a_noop(self) -> None: response = await HttpSecurity().build().do_filter(self._request("/anything"), self._call_next) assert response.status_code == 200 + + +class TestHttpMethodMatchers: + """URL authorization rules can be scoped to specific HTTP methods (Spring's + ``requestMatchers(HttpMethod.X, ...)``), so a read can be public while a write + on the same path requires a role.""" + + @staticmethod + def _request(path: str, method: str = "GET", ctx: SecurityContext | None = None) -> Request: + scope: dict[str, Any] = {"type": "http", "method": method, "path": path, "headers": [], "query_string": b""} + request = Request(scope) + request.state.security_context = ctx or SecurityContext.anonymous() + return request + + @staticmethod + async def _call_next(request: Request) -> Response: + return PlainTextResponse("ok") + + @pytest.mark.asyncio + async def test_method_specific_rule_only_matches_that_method(self) -> None: + sec = HttpSecurity() + builder = sec.authorize_requests() + builder.request_matchers("/api/**", methods="POST").authenticated() + builder.request_matchers("/api/**").permit_all() + built = sec.build() + # GET falls past the POST rule to the permit-all rule. + assert (await built.do_filter(self._request("/api/x", "GET"), self._call_next)).status_code == 200 + # POST (anonymous) matches the method-scoped authenticated rule. + assert (await built.do_filter(self._request("/api/x", "POST"), self._call_next)).status_code == 401 + + @pytest.mark.asyncio + async def test_method_list_matches_any_listed(self) -> None: + sec = HttpSecurity() + sec.authorize_requests().request_matchers("/api/**", methods=["PUT", "DELETE"]).has_role("ADMIN") + built = sec.build() + # GET matches no rule -> deny-by-default 403. + assert (await built.do_filter(self._request("/api/x", "GET"), self._call_next)).status_code == 403 + # DELETE matches the method-scoped role rule (anonymous -> 401). + assert (await built.do_filter(self._request("/api/x", "DELETE"), self._call_next)).status_code == 401 + + @pytest.mark.asyncio + async def test_no_method_means_any_method(self) -> None: + sec = HttpSecurity() + sec.authorize_requests().request_matchers("/api/**").permit_all() + built = sec.build() + for method in ("GET", "POST", "PUT", "PATCH", "DELETE"): + resp = await built.do_filter(self._request("/api/x", method), self._call_next) + assert resp.status_code == 200 + + +class TestSigningSecretHardening: + """The auto-config composition root refuses to sign tokens with the built-in + placeholder secret or a secret too short for the HMAC algorithm (RFC 7518 §3.2).""" + + def _as_config(self, **overrides: Any) -> Config: + server: dict[str, Any] = {"enabled": "true"} + server.update(overrides) + return Config({"pyfly": {"security": {"oauth2": {"authorization-server": server}}}}) + + def _build_as(self, config: Config) -> AuthorizationServer: + from pyfly.container.container import Container + from pyfly.security.auto_configuration import OAuth2AuthorizationServerAutoConfiguration + + ac = OAuth2AuthorizationServerAutoConfiguration() + repo = InMemoryClientRegistrationRepository() + return ac.authorization_server(config, repo, Container()) + + def test_authorization_server_bean_rejects_placeholder_secret(self) -> None: + with pytest.raises(SecurityException) as exc: + self._build_as(self._as_config()) # no secret -> placeholder default + assert exc.value.code == "INSECURE_SIGNING_SECRET" + + def test_authorization_server_bean_rejects_short_secret(self) -> None: + with pytest.raises(SecurityException) as exc: + self._build_as(self._as_config(secret="too-short")) + assert exc.value.code == "WEAK_SIGNING_SECRET" + + def test_authorization_server_bean_accepts_strong_secret(self) -> None: + server = self._build_as(self._as_config(secret="a" * 32)) + assert isinstance(server, AuthorizationServer) + + def test_jwt_service_bean_rejects_placeholder_secret_when_filter_enabled(self) -> None: + from pyfly.core.config import Config + from pyfly.security.auto_configuration import JwtAutoConfiguration + + cfg = Config({"pyfly": {"security": {"enabled": "true", "jwt": {"filter": {"enabled": "true"}}}}}) + with pytest.raises(SecurityException) as exc: + JwtAutoConfiguration().jwt_service(cfg) + assert exc.value.code == "INSECURE_SIGNING_SECRET" + + def test_jwt_service_without_filter_tolerates_placeholder(self) -> None: + """A resource-server-only app (symmetric JWT filter off) must still boot even + though the symmetric signer is left at its (unused) placeholder secret.""" + from pyfly.core.config import Config + from pyfly.security.auto_configuration import JwtAutoConfiguration + + cfg = Config({"pyfly": {"security": {"enabled": "true"}}}) + svc = JwtAutoConfiguration().jwt_service(cfg) + assert svc is not None + + def test_jwt_service_bean_accepts_strong_secret(self) -> None: + from pyfly.core.config import Config + from pyfly.security.auto_configuration import JwtAutoConfiguration + + cfg = Config({"pyfly": {"security": {"jwt": {"filter": {"enabled": "true"}, "secret": "z" * 40}}}}) + svc = JwtAutoConfiguration().jwt_service(cfg) + assert svc is not None