Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Annotated
from urllib.parse import parse_qs

from fastapi import APIRouter, Depends, Request, Response
from fastapi import APIRouter, BackgroundTasks, Depends, Request, Response
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -167,6 +167,7 @@ def _password_login_requires_verified_email() -> bool:
async def login(
payload: LoginRequest,
request: Request,
background_tasks: BackgroundTasks,
db_session: Annotated[AsyncSession, Depends(get_database_session)],
user_service: Annotated[UserService, Depends(get_user_service)],
token_service: Annotated[TokenService, Depends(get_token_service)],
Expand All @@ -187,8 +188,8 @@ async def login(

if user is None or user.password_hash is None:
user_service.dummy_verify()
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand All @@ -205,8 +206,8 @@ async def login(
try:
await brute_force_service.ensure_not_locked(str(user.id))
except BruteForceProtectionError as exc:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand All @@ -231,8 +232,8 @@ async def login(
ip_address=client_ip,
)
except BruteForceProtectionError as exc:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand All @@ -249,8 +250,8 @@ async def login(
)

if failure_decision.locked:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.locked",
actor_type="user",
success=False,
Expand All @@ -275,8 +276,8 @@ async def login(
"distributed_attack": failure_decision.distributed_attack,
},
)
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand All @@ -292,8 +293,8 @@ async def login(
headers={"Retry-After": str(failure_decision.retry_after or 1)},
)

await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand All @@ -311,8 +312,8 @@ async def login(
if _password_login_requires_verified_email() and not bool(
getattr(user, "email_verified", False)
):
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand Down Expand Up @@ -344,17 +345,17 @@ async def login(
headers=exc.headers,
)

await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.otp_required",
actor_type="user",
success=True,
request=request,
actor_id=challenge.user_id,
metadata={"provider": "password"},
)
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="otp.sent",
actor_type="user",
success=True,
Expand All @@ -375,8 +376,8 @@ async def login(
user_agent=user_agent,
)
except BruteForceProtectionError as exc:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.failure",
actor_type="user",
success=False,
Expand Down Expand Up @@ -429,8 +430,8 @@ async def login(
)
return _error_response(status_code=exc.status_code, detail=exc.detail, code=exc.code)

await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.success",
actor_type="user",
success=True,
Expand All @@ -439,17 +440,17 @@ async def login(
metadata={"provider": "password"},
)
if suspicious_login.suspicious:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.login.suspicious",
actor_type="user",
success=True,
request=request,
actor_id=str(user.id),
metadata={"provider": "password", **suspicious_login.metadata},
)
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="session.created",
actor_type="user",
success=True,
Expand All @@ -467,8 +468,8 @@ async def login(
"provider": "password",
},
)
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="token.issued",
actor_type="user",
success=True,
Expand Down
12 changes: 6 additions & 6 deletions app/routers/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ async def signup(
password=payload.password,
)
except LifecycleServiceError as exc:
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.signup.accepted",
actor_type="user",
success=False,
Expand All @@ -97,8 +97,8 @@ async def signup(
)
return _error_response(status_code=exc.status_code, detail=exc.detail, code=exc.code)

await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.signup.accepted",
actor_type="user",
success=True,
Expand All @@ -121,8 +121,8 @@ async def signup(
to_email=user.email,
verification_link=signup_result.verification_link,
)
await audit_service.record(
db=db_session,
audit_service.enqueue_record(
background_tasks,
event_type="user.created",
actor_type="user",
success=True,
Expand Down
100 changes: 92 additions & 8 deletions app/services/audit_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@

import ipaddress
import re
from copy import deepcopy
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any
from uuid import NAMESPACE_URL, UUID, uuid5

import structlog
from fastapi import Request
from fastapi import BackgroundTasks, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import reloadable_singleton
from app.core.client_ip import extract_client_ip as extract_trusted_client_ip
from app.db.session import get_session_factory
from app.models.audit_event import AuditActorType, AuditEvent
from app.services.pagination import CursorPage, apply_created_at_cursor, build_page, decode_cursor
from app.services.pagination import (
CursorPage,
apply_created_at_cursor,
build_page,
decode_cursor,
)

logger = structlog.get_logger(__name__)

Expand All @@ -35,6 +43,40 @@
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")


@dataclass(frozen=True)
class AuditRequestSnapshot:
"""Serializable request context captured for deferred audit writes."""

headers: dict[str, str]
client_host: str | None
correlation_id: str | None

@classmethod
def capture(cls, request: Request) -> AuditRequestSnapshot:
"""Capture the minimum request fields needed for a later audit write."""
raw_correlation_id = getattr(getattr(request, "state", None), "correlation_id", None)
if raw_correlation_id is None:
raw_correlation_id = request.headers.get("x-correlation-id")

return cls(
headers={str(key).lower(): value for key, value in request.headers.items()},
client_host=getattr(getattr(request, "client", None), "host", None),
correlation_id=None if raw_correlation_id is None else str(raw_correlation_id),
)

def to_request_like(self) -> Any:
"""Build a lightweight request-like object for existing extraction helpers."""
client = None
if self.client_host is not None:
client = SimpleNamespace(host=self.client_host)

return SimpleNamespace(
headers=self.headers,
client=client,
state=SimpleNamespace(correlation_id=self.correlation_id),
)


def _is_sensitive_key(key: str) -> bool:
"""Return True when metadata key likely contains sensitive data."""
normalized = key.strip().lower().replace("-", "_")
Expand Down Expand Up @@ -114,6 +156,44 @@ def _sanitize_metadata(metadata: dict[str, Any] | None) -> dict[str, Any] | None
class AuditService:
"""Persist immutable audit events without affecting auth outcomes."""

def enqueue_record(
self,
background_tasks: BackgroundTasks,
*,
event_type: str,
actor_type: str,
success: bool,
request: Request,
actor_id: str | None = None,
target_id: str | None = None,
target_type: str | None = None,
failure_reason: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Schedule one audit write to run after the response has been sent."""
try:
background_tasks.add_task(
self.record,
db=None,
event_type=event_type,
actor_type=actor_type,
success=success,
request=AuditRequestSnapshot.capture(request),
actor_id=actor_id,
target_id=target_id,
target_type=target_type,
failure_reason=failure_reason,
metadata=deepcopy(metadata) if metadata is not None else None,
)
except Exception as exc:
logger.error(
"audit_enqueue_failed",
event_type=event_type,
actor_type=actor_type,
success=success,
error=str(exc),
)

async def list_events_page(
self,
db_session: AsyncSession,
Expand Down Expand Up @@ -150,18 +230,22 @@ async def list_events_page(

async def record(
self,
db: AsyncSession,
db: AsyncSession | Any | None,
event_type: str,
actor_type: str,
success: bool,
request: Request,
request: Request | AuditRequestSnapshot,
actor_id: str | None = None,
target_id: str | None = None,
target_type: str | None = None,
failure_reason: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Write one append-only audit row and swallow write failures."""
request_like = (
request.to_request_like() if isinstance(request, AuditRequestSnapshot) else request
)

try:
normalized_actor_type = AuditActorType(actor_type)
except ValueError:
Expand All @@ -173,16 +257,16 @@ async def record(
actor_type=normalized_actor_type,
target_id=_coerce_uuid(target_id),
target_type=target_type.strip() if target_type else None,
ip_address=_coerce_ip(extract_trusted_client_ip(request)),
user_agent=request.headers.get("user-agent"),
correlation_id=_extract_correlation_id(request),
ip_address=_coerce_ip(extract_trusted_client_ip(request_like)),
user_agent=request_like.headers.get("user-agent"),
correlation_id=_extract_correlation_id(request_like),
success=success,
failure_reason=failure_reason.strip() if failure_reason else None,
event_metadata=_sanitize_metadata(metadata),
)

try:
if isinstance(db, AsyncSession):
if db is None or isinstance(db, AsyncSession):
session_factory = get_session_factory()
async with session_factory() as audit_db:
audit_db.add(audit_event)
Expand Down
Loading