From 74be0675b2f75d79dc2b4f7d42a5f083f5f952a1 Mon Sep 17 00:00:00 2001 From: cosmin chauciuc Date: Mon, 8 Jun 2026 17:49:11 +0300 Subject: [PATCH] =?UTF-8?q?feat(phase4):=20scheduling,=20distribution=20&?= =?UTF-8?q?=20governance=20(M1=E2=80=93M4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Phase 4 of the platform roadmap full-stack: audit logging, scheduled reports + notification adapters, a data-policy enforcement engine, and cost/usage analytics. Migrations 009–012. M1 — Audit log - AuditEvent model; fire-and-forget audit_service.record (inline, SAVEPOINT-isolated, never raises) wired at login, connection CRUD/rotation, and query executed/blocked - Admin org-scoped read + CSV export API; admin-gated AuditPage M2 — Scheduled reports - Schedule model + schedule_service (cron via croniter [scheduling] extra, threshold alerting, run/deliver pipeline) - In-process scheduler loop: atomic FOR UPDATE SKIP LOCKED claim → run_schedule job via the Phase 0 queue (works for inprocess + arq) - Pluggable app/notifications adapters (email SMTP / Slack webhook / log fallback); magic-link delivery now routed through the notifier (closes the Phase 1 deferral) - /schedules CRUD + run-now; SchedulesPage with target picker M3 — Policy engine - DataPolicy model + policy_service: most-restrictive resolve_effective, fail-closed enforce_sql (sqlglot: allow/block tables, blocked columns, row-filter injection), post-exec column masking, row/runtime caps - Enforced in the query pipeline before the connector on both NL and raw-SQL paths, re-applied across error-handler retries; masking applied in place so no PII reaches the interpreter LLM - /connections/{id}/policies CRUD; admin-gated PoliciesPage M4 — Cost & usage analytics - CostAttribution model + QueryResult.stats connector seam (BigQuery scanned-bytes/slot-ms populated); cost_service (pluggable pricing, fire-and-forget capture, aggregations) recorded after each execution - 
Admin /analytics/{usage,cost,slowest,tables} API; admin-gated AnalyticsPage Tests: +50 unit tests (audit, notifications, schedules, policy, cost); 220 backend tests pass. ruff clean (modulo the standard FastAPI B008). Frontend tsc + eslint + build clean. CI installs the [scheduling] extra. Deferred: human-approval gate + certified-metrics-only mode. Known caveat: the saved-query result cache is not yet role-aware — key it by effective policy before relying on masking/row-filters there. Co-Authored-By: Claude Opus 4.8 (1M context) --- .github/workflows/ci.yml | 2 +- backend/alembic/versions/009_audit_events.py | 74 ++++ backend/alembic/versions/010_schedules.py | 80 ++++ backend/alembic/versions/011_data_policies.py | 68 +++ .../alembic/versions/012_cost_attribution.py | 85 ++++ backend/app/api/v1/endpoints/analytics.py | 65 +++ backend/app/api/v1/endpoints/audit.py | 93 ++++ backend/app/api/v1/endpoints/auth.py | 23 +- backend/app/api/v1/endpoints/policies.py | 72 ++++ backend/app/api/v1/endpoints/schedules.py | 79 ++++ backend/app/api/v1/router.py | 8 + backend/app/api/v1/schemas/analytics.py | 29 ++ backend/app/api/v1/schemas/audit_event.py | 16 + backend/app/api/v1/schemas/data_policy.py | 49 +++ backend/app/api/v1/schemas/schedule.py | 67 +++ backend/app/config.py | 27 ++ backend/app/connectors/base_connector.py | 5 + backend/app/connectors/bigquery/connector.py | 10 + backend/app/db/models/__init__.py | 8 + backend/app/db/models/audit_event.py | 41 ++ backend/app/db/models/cost_attribution.py | 54 +++ backend/app/db/models/data_policy.py | 64 +++ backend/app/db/models/schedule.py | 71 ++++ backend/app/jobs/scheduler.py | 117 +++++ backend/app/jobs/tasks.py | 3 + backend/app/main.py | 6 + backend/app/notifications/__init__.py | 68 +++ backend/app/notifications/base.py | 36 ++ backend/app/notifications/email.py | 39 ++ backend/app/notifications/log.py | 28 ++ backend/app/notifications/slack.py | 21 + backend/app/services/audit_service.py | 139 ++++++ 
backend/app/services/auth_service.py | 21 +- backend/app/services/connection_service.py | 37 ++ backend/app/services/cost_service.py | 208 +++++++++ backend/app/services/policy_service.py | 361 ++++++++++++++++ backend/app/services/query_service.py | 151 ++++++- backend/app/services/schedule_service.py | 400 ++++++++++++++++++ backend/pyproject.toml | 3 + backend/tests/test_audit_service.py | 82 ++++ backend/tests/test_cost_service.py | 55 +++ backend/tests/test_notifications.py | 48 +++ backend/tests/test_policy_service.py | 160 +++++++ backend/tests/test_schedule_service.py | 64 +++ frontend/src/App.tsx | 8 + frontend/src/api/analyticsApi.ts | 15 + frontend/src/api/auditApi.ts | 45 ++ frontend/src/api/policiesApi.ts | 15 + frontend/src/api/schedulesApi.ts | 13 + frontend/src/components/layout/AppLayout.tsx | 10 +- frontend/src/hooks/useAnalytics.ts | 34 ++ frontend/src/hooks/useAudit.ts | 19 + frontend/src/hooks/usePolicies.ts | 36 ++ frontend/src/hooks/useSchedules.ts | 42 ++ frontend/src/pages/AnalyticsPage.tsx | 208 +++++++++ frontend/src/pages/AuditPage.tsx | 208 +++++++++ frontend/src/pages/PoliciesPage.tsx | 314 ++++++++++++++ frontend/src/pages/SchedulesPage.tsx | 342 +++++++++++++++ frontend/src/types/api.ts | 95 +++++ planfull.md | 2 +- 60 files changed, 4523 insertions(+), 20 deletions(-) create mode 100644 backend/alembic/versions/009_audit_events.py create mode 100644 backend/alembic/versions/010_schedules.py create mode 100644 backend/alembic/versions/011_data_policies.py create mode 100644 backend/alembic/versions/012_cost_attribution.py create mode 100644 backend/app/api/v1/endpoints/analytics.py create mode 100644 backend/app/api/v1/endpoints/audit.py create mode 100644 backend/app/api/v1/endpoints/policies.py create mode 100644 backend/app/api/v1/endpoints/schedules.py create mode 100644 backend/app/api/v1/schemas/analytics.py create mode 100644 backend/app/api/v1/schemas/audit_event.py create mode 100644 backend/app/api/v1/schemas/data_policy.py 
create mode 100644 backend/app/api/v1/schemas/schedule.py create mode 100644 backend/app/db/models/audit_event.py create mode 100644 backend/app/db/models/cost_attribution.py create mode 100644 backend/app/db/models/data_policy.py create mode 100644 backend/app/db/models/schedule.py create mode 100644 backend/app/jobs/scheduler.py create mode 100644 backend/app/notifications/__init__.py create mode 100644 backend/app/notifications/base.py create mode 100644 backend/app/notifications/email.py create mode 100644 backend/app/notifications/log.py create mode 100644 backend/app/notifications/slack.py create mode 100644 backend/app/services/audit_service.py create mode 100644 backend/app/services/cost_service.py create mode 100644 backend/app/services/policy_service.py create mode 100644 backend/app/services/schedule_service.py create mode 100644 backend/tests/test_audit_service.py create mode 100644 backend/tests/test_cost_service.py create mode 100644 backend/tests/test_notifications.py create mode 100644 backend/tests/test_policy_service.py create mode 100644 backend/tests/test_schedule_service.py create mode 100644 frontend/src/api/analyticsApi.ts create mode 100644 frontend/src/api/auditApi.ts create mode 100644 frontend/src/api/policiesApi.ts create mode 100644 frontend/src/api/schedulesApi.ts create mode 100644 frontend/src/hooks/useAnalytics.ts create mode 100644 frontend/src/hooks/useAudit.ts create mode 100644 frontend/src/hooks/usePolicies.ts create mode 100644 frontend/src/hooks/useSchedules.ts create mode 100644 frontend/src/pages/AnalyticsPage.tsx create mode 100644 frontend/src/pages/AuditPage.tsx create mode 100644 frontend/src/pages/PoliciesPage.tsx create mode 100644 frontend/src/pages/SchedulesPage.tsx diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b043e95..3658adc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e 
".[llm,dev,observability,lineage]" + run: pip install -e ".[llm,dev,observability,lineage,scheduling]" # Gating: the test suite must pass before any Phase 0+ refactor lands. - name: Tests diff --git a/backend/alembic/versions/009_audit_events.py b/backend/alembic/versions/009_audit_events.py new file mode 100644 index 0000000..fcc226c --- /dev/null +++ b/backend/alembic/versions/009_audit_events.py @@ -0,0 +1,74 @@ +"""Audit events (Phase 4 — Milestone 1) + +Revision ID: 009 +Revises: 008 +Create Date: 2026-06-08 + +Adds ``audit_events`` — an append-only log of security- and governance-relevant +actions (login, connection CRUD, credential rotation, introspection, query +generated/executed/blocked, metric certified, knowledge imported). Written +fire-and-forget so auditing never breaks the audited action. Org-scoped and +exportable. ``actor_id`` / ``workspace_id`` are nullable so system-driven and +pre-auth events can still be recorded. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "009" +down_revision: str = "008" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "audit_events", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("event_type", sa.String(64), nullable=False), + sa.Column( + "actor_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "workspace_id", + UUID(as_uuid=True), + sa.ForeignKey("teams.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("payload", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + ) + op.create_index("ix_audit_events_event_type", "audit_events", ["event_type"]) + op.create_index("ix_audit_events_created_at", "audit_events", ["created_at"]) + # Primary access pattern: an org's events, newest first. 
+ op.create_index( + "ix_audit_events_org_created", + "audit_events", + ["organization_id", "created_at"], + ) + + +def downgrade() -> None: + op.drop_index("ix_audit_events_org_created", table_name="audit_events") + op.drop_index("ix_audit_events_created_at", table_name="audit_events") + op.drop_index("ix_audit_events_event_type", table_name="audit_events") + op.drop_table("audit_events") diff --git a/backend/alembic/versions/010_schedules.py b/backend/alembic/versions/010_schedules.py new file mode 100644 index 0000000..cad4833 --- /dev/null +++ b/backend/alembic/versions/010_schedules.py @@ -0,0 +1,80 @@ +"""Scheduled reports (Phase 4 — Milestone 2) + +Revision ID: 010 +Revises: 009 +Create Date: 2026-06-08 + +Adds ``schedules`` — recurring delivery of a saved query or dashboard on a cron +schedule over a notification channel (email/Slack/log), with optional +alert-on-threshold. Workspace-scoped like dashboards. The scheduler claims due +rows on ``next_run_at``. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "010" +down_revision: str = "009" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "schedules", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "workspace_id", + UUID(as_uuid=True), + sa.ForeignKey("teams.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "owner_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("target_type", sa.String(20), nullable=False), + sa.Column("target_id", UUID(as_uuid=True), nullable=False), + sa.Column("cron", sa.String(120), nullable=False), + sa.Column("channel", sa.String(20), nullable=False, server_default="email"), + sa.Column("recipients", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("params", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column("threshold", JSONB, nullable=True), + sa.Column( + "only_on_threshold", sa.Boolean, nullable=False, server_default=sa.text("false") + ), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("next_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_status", sa.String(20), nullable=True), + sa.Column("last_error", sa.Text, nullable=True), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + sa.Column( + "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + ) + # The scheduler scans for enabled, due rows ordered by next_run_at. 
+ op.create_index("ix_schedules_next_run_at", "schedules", ["next_run_at"]) + op.create_index("ix_schedules_workspace_id", "schedules", ["workspace_id"]) + + +def downgrade() -> None: + op.drop_index("ix_schedules_workspace_id", table_name="schedules") + op.drop_index("ix_schedules_next_run_at", table_name="schedules") + op.drop_table("schedules") diff --git a/backend/alembic/versions/011_data_policies.py b/backend/alembic/versions/011_data_policies.py new file mode 100644 index 0000000..5b16315 --- /dev/null +++ b/backend/alembic/versions/011_data_policies.py @@ -0,0 +1,68 @@ +"""Data policies (Phase 4 — Milestone 3) + +Revision ID: 011 +Revises: 010 +Create Date: 2026-06-08 + +Adds ``data_policies`` — governance rules enforced before a query reaches the +connector: role-scoped row/runtime caps, allow/block table lists, blocked +columns, PII column masking, and row-level filters. Connection-scoped like the +semantic layer. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "011" +down_revision: str = "010" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "data_policies", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("priority", sa.Integer, nullable=False, server_default=sa.text("100")), + sa.Column( + "applies_to_roles", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb") + ), + sa.Column("max_rows", sa.Integer, nullable=True), + sa.Column("max_runtime_seconds", sa.Integer, nullable=True), + sa.Column("allowed_tables", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("blocked_tables", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("blocked_columns", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("masked_columns", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("row_filters", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + sa.Column( + "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + ) + op.create_index("ix_data_policies_connection_id", "data_policies", ["connection_id"]) + + +def downgrade() -> None: + op.drop_index("ix_data_policies_connection_id", table_name="data_policies") + op.drop_table("data_policies") diff --git a/backend/alembic/versions/012_cost_attribution.py 
b/backend/alembic/versions/012_cost_attribution.py new file mode 100644 index 0000000..3883437 --- /dev/null +++ b/backend/alembic/versions/012_cost_attribution.py @@ -0,0 +1,85 @@ +"""Cost & usage attribution (Phase 4 — Milestone 4) + +Revision ID: 012 +Revises: 011 +Create Date: 2026-06-08 + +Adds ``cost_attributions`` — per-execution usage + estimated cost, attributed to +a workspace/user/connection. Powers the usage analytics dashboards (slowest +queries, error rate, most-queried tables, cost per team). Populated best-effort +after each query execution. Byte/slot-ms counters are BIGINT (int32 overflows). +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "012" +down_revision: str = "011" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "cost_attributions", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "workspace_id", + UUID(as_uuid=True), + sa.ForeignKey("teams.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "user_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "query_execution_id", + UUID(as_uuid=True), + sa.ForeignKey("query_executions.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("source_provider", sa.String(50), nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="success"), + sa.Column("execution_time_ms", sa.Float, nullable=True), + sa.Column("row_count", sa.Integer, nullable=True), + sa.Column("scanned_bytes", sa.BigInteger, 
nullable=True), + sa.Column("slot_ms", sa.BigInteger, nullable=True), + sa.Column("dbu", sa.Float, nullable=True), + sa.Column("cost_usd", sa.Float, nullable=False, server_default="0"), + sa.Column("tables", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + ) + op.create_index("ix_cost_attributions_created_at", "cost_attributions", ["created_at"]) + op.create_index( + "ix_cost_attributions_org_created", + "cost_attributions", + ["organization_id", "created_at"], + ) + + +def downgrade() -> None: + op.drop_index("ix_cost_attributions_org_created", table_name="cost_attributions") + op.drop_index("ix_cost_attributions_created_at", table_name="cost_attributions") + op.drop_table("cost_attributions") diff --git a/backend/app/api/v1/endpoints/analytics.py b/backend/app/api/v1/endpoints/analytics.py new file mode 100644 index 0000000..9e9f5bc --- /dev/null +++ b/backend/app/api/v1/endpoints/analytics.py @@ -0,0 +1,65 @@ +"""Cost & usage analytics. 
Admin-only, org-scoped, windowed by ``days``.""" + +from datetime import UTC, datetime, timedelta + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.analytics import ( + CostByEntry, + SlowestQuery, + TableUsage, + UsageSummary, +) +from app.core.auth import AuthContext, get_org_context +from app.db.session import get_db +from app.services import cost_service + +router = APIRouter(prefix="/analytics", tags=["analytics"]) + + +def _since(days: int) -> datetime: + return datetime.now(UTC) - timedelta(days=days) + + +@router.get("/usage", response_model=UsageSummary) +async def usage( + days: int = Query(30, ge=1, le=365), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + ctx.require_role("admin") + return await cost_service.usage_summary(db, ctx.organization_id, _since(days)) + + +@router.get("/cost", response_model=list[CostByEntry]) +async def cost_by( + by: str = Query("workspace", pattern="^(workspace|user|connection)$"), + days: int = Query(30, ge=1, le=365), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + ctx.require_role("admin") + return await cost_service.cost_by(db, ctx.organization_id, by, _since(days)) + + +@router.get("/slowest", response_model=list[SlowestQuery]) +async def slowest( + days: int = Query(30, ge=1, le=365), + limit: int = Query(10, ge=1, le=100), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + ctx.require_role("admin") + return await cost_service.slowest_queries(db, ctx.organization_id, _since(days), limit) + + +@router.get("/tables", response_model=list[TableUsage]) +async def tables( + days: int = Query(30, ge=1, le=365), + limit: int = Query(10, ge=1, le=100), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + ctx.require_role("admin") + return await cost_service.most_queried_tables(db, 
ctx.organization_id, _since(days), limit) diff --git a/backend/app/api/v1/endpoints/audit.py b/backend/app/api/v1/endpoints/audit.py new file mode 100644 index 0000000..ccfaf75 --- /dev/null +++ b/backend/app/api/v1/endpoints/audit.py @@ -0,0 +1,93 @@ +"""Audit-event read API. Admin-only, org-scoped, with CSV export. + +Audit writes happen fire-and-forget at call sites via ``audit_service.record``; +this router is the governance read surface over them. +""" + +import csv +import io +import json +import uuid + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.audit_event import AuditEventResponse +from app.core.auth import AuthContext, get_org_context +from app.db.session import get_db +from app.services import audit_service + +router = APIRouter(prefix="/audit-events", tags=["audit"]) + + +@router.get("", response_model=list[AuditEventResponse]) +async def list_audit_events( + event_type: str | None = Query(None), + actor_id: uuid.UUID | None = Query(None), + limit: int = Query(100, ge=1, le=500), + offset: int = Query(0, ge=0), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + """List the org's audit events, newest first. 
Requires admin.""" + ctx.require_role("admin") + events = await audit_service.list_events( + db, + organization_id=ctx.organization_id, + event_type=event_type, + actor_id=actor_id, + limit=limit, + offset=offset, + ) + return [AuditEventResponse.model_validate(e) for e in events] + + +@router.get("/event-types", response_model=list[str]) +async def list_event_types( + ctx: AuthContext = Depends(get_org_context), +): + """The canonical set of event types, for building a filter UI.""" + ctx.require_role("admin") + return list(audit_service.EVENT_TYPES) + + +@router.get("/export") +async def export_audit_events( + event_type: str | None = Query(None), + actor_id: uuid.UUID | None = Query(None), + limit: int = Query(10000, ge=1, le=100000), + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + """Export the org's audit events as CSV. Requires admin.""" + ctx.require_role("admin") + events = await audit_service.list_events( + db, + organization_id=ctx.organization_id, + event_type=event_type, + actor_id=actor_id, + limit=limit, + offset=0, + ) + + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow(["id", "event_type", "actor_id", "workspace_id", "created_at", "payload"]) + for e in events: + writer.writerow( + [ + str(e.id), + e.event_type, + str(e.actor_id) if e.actor_id else "", + str(e.workspace_id) if e.workspace_id else "", + e.created_at.isoformat(), + json.dumps(e.payload, separators=(",", ":")), + ] + ) + buf.seek(0) + return StreamingResponse( + iter([buf.getvalue()]), + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=audit_events.csv"}, + ) diff --git a/backend/app/api/v1/endpoints/auth.py b/backend/app/api/v1/endpoints/auth.py index 5fae5a4..dd960bc 100644 --- a/backend/app/api/v1/endpoints/auth.py +++ b/backend/app/api/v1/endpoints/auth.py @@ -19,6 +19,7 @@ from app.core.auth_providers import get_auth_provider from app.db.models.user import User from app.db.session import 
get_db +from app.notifications import deliver from app.services import auth_service, identity_service logger = logging.getLogger("querywise.auth") @@ -53,11 +54,23 @@ async def register(body: RegisterRequest, response: Response, db: AsyncSession = @router.post("/magic-link", response_model=MagicLinkResponse) async def request_magic_link(body: MagicLinkRequest, db: AsyncSession = Depends(get_db)): token = await auth_service.request_magic_link(db, body.email) - # Delivery (email/Slack) is wired in Phase 4; for now log it and, outside - # production, surface it so local dev can complete the flow. - frontend = settings.cors_origins[0] if settings.cors_origins else None - verify_url = f"{frontend}/login/verify?token={token}" if frontend else None - logger.info("Magic link issued for %s: %s", body.email, verify_url or token) + base = settings.app_base_url or (settings.cors_origins[0] if settings.cors_origins else None) + verify_url = f"{base}/login/verify?token={token}" if base else None + + # Deliver via the configured channel (Phase 4). With no SMTP host set this + # degrades to logging — so local dev still completes the flow. + await deliver( + "email", + subject="Your QueryWise sign-in link", + text_body=( + f"Click to sign in: {verify_url}\n\n" + f"This link expires in {settings.magic_link_ttl_minutes} minutes." + if verify_url + else f"Your sign-in token: {token}" + ), + recipients=[body.email], + ) + # Outside production, also surface the token so a dev client can complete login. expose = settings.environment != "production" return MagicLinkResponse( sent=True, diff --git a/backend/app/api/v1/endpoints/policies.py b/backend/app/api/v1/endpoints/policies.py new file mode 100644 index 0000000..741cedf --- /dev/null +++ b/backend/app/api/v1/endpoints/policies.py @@ -0,0 +1,72 @@ +"""Data-policy CRUD. 
Connection-scoped; editor+ to modify (admins typically).""" + +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.deps import require_connection_read, require_connection_write +from app.api.v1.schemas.data_policy import ( + DataPolicyCreate, + DataPolicyResponse, + DataPolicyUpdate, +) +from app.core.auth import AuthContext +from app.db.session import get_db +from app.services import policy_service + +router = APIRouter(tags=["policies"]) + +_BASE = "/connections/{connection_id}/policies" + + +@router.get(_BASE, response_model=list[DataPolicyResponse]) +async def list_policies( + connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + return await policy_service.list_policies(db, connection_id) + + +@router.post(_BASE, response_model=DataPolicyResponse, status_code=201) +async def create_policy( + connection_id: uuid.UUID, + body: DataPolicyCreate, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + return await policy_service.create_policy(db, connection_id, ctx, **body.model_dump()) + + +@router.get(_BASE + "/{policy_id}", response_model=DataPolicyResponse) +async def get_policy( + connection_id: uuid.UUID, + policy_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + return await policy_service.get_policy(db, connection_id, policy_id) + + +@router.put(_BASE + "/{policy_id}", response_model=DataPolicyResponse) +async def update_policy( + connection_id: uuid.UUID, + policy_id: uuid.UUID, + body: DataPolicyUpdate, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + return await policy_service.update_policy( + db, connection_id, policy_id, ctx, body.model_dump(exclude_unset=True) + ) + + +@router.delete(_BASE + "/{policy_id}", status_code=204) +async def delete_policy( + 
connection_id: uuid.UUID, + policy_id: uuid.UUID, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + await policy_service.delete_policy(db, connection_id, policy_id, ctx) diff --git a/backend/app/api/v1/endpoints/schedules.py b/backend/app/api/v1/endpoints/schedules.py new file mode 100644 index 0000000..9b1d1c1 --- /dev/null +++ b/backend/app/api/v1/endpoints/schedules.py @@ -0,0 +1,79 @@ +"""Scheduled-report CRUD + manual trigger. Workspace-scoped (like dashboards).""" + +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.schedule import ( + ScheduleCreate, + ScheduleResponse, + ScheduleRunResponse, + ScheduleUpdate, +) +from app.core.auth import AuthContext, get_org_context +from app.db.session import get_db +from app.services import schedule_service + +router = APIRouter(prefix="/schedules", tags=["schedules"]) + + +@router.get("", response_model=list[ScheduleResponse]) +async def list_schedules( + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + return await schedule_service.list_schedules(db, ctx) + + +@router.post("", response_model=ScheduleResponse, status_code=201) +async def create_schedule( + body: ScheduleCreate, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + data = body.model_dump() + if data.get("threshold") is not None: + data["threshold"] = body.threshold.model_dump() + return await schedule_service.create_schedule(db, ctx, **data) + + +@router.get("/{schedule_id}", response_model=ScheduleResponse) +async def get_schedule( + schedule_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + return await schedule_service.get_schedule(db, schedule_id, ctx) + + +@router.put("/{schedule_id}", response_model=ScheduleResponse) +async def update_schedule( + schedule_id: uuid.UUID, + body: 
ScheduleUpdate, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + updates = body.model_dump(exclude_unset=True) + if "threshold" in updates and body.threshold is not None: + updates["threshold"] = body.threshold.model_dump() + return await schedule_service.update_schedule(db, schedule_id, ctx, updates) + + +@router.delete("/{schedule_id}", status_code=204) +async def delete_schedule( + schedule_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + await schedule_service.delete_schedule(db, schedule_id, ctx) + + +@router.post("/{schedule_id}/run", response_model=ScheduleRunResponse) +async def run_schedule_now( + schedule_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + """Trigger a schedule immediately without changing its cron cadence.""" + return await schedule_service.run_now(db, schedule_id, ctx) diff --git a/backend/app/api/v1/router.py b/backend/app/api/v1/router.py index 730faa8..7b8cf2d 100644 --- a/backend/app/api/v1/router.py +++ b/backend/app/api/v1/router.py @@ -1,8 +1,10 @@ from fastapi import APIRouter from app.api.v1.endpoints import ( + analytics, api_keys, assistant, + audit, auth, catalog, connections, @@ -12,10 +14,12 @@ health, knowledge, metrics, + policies, query, query_history, sample_queries, saved_queries, + schedules, schemas, teams, ) @@ -39,3 +43,7 @@ api_router.include_router(query_history.router) api_router.include_router(knowledge.router) api_router.include_router(catalog.router) +api_router.include_router(audit.router) +api_router.include_router(schedules.router) +api_router.include_router(policies.router) +api_router.include_router(analytics.router) diff --git a/backend/app/api/v1/schemas/analytics.py b/backend/app/api/v1/schemas/analytics.py new file mode 100644 index 0000000..60021de --- /dev/null +++ b/backend/app/api/v1/schemas/analytics.py @@ -0,0 +1,29 @@ +from pydantic import 
BaseModel + + +class UsageSummary(BaseModel): + total_queries: int + error_count: int + error_rate: float + total_cost_usd: float + total_scanned_bytes: int + avg_execution_ms: float | None + + +class CostByEntry(BaseModel): + key: str | None + cost_usd: float + query_count: int + + +class SlowestQuery(BaseModel): + query_execution_id: str | None + execution_time_ms: float | None + cost_usd: float + source_provider: str | None + question: str | None + + +class TableUsage(BaseModel): + table: str + query_count: int diff --git a/backend/app/api/v1/schemas/audit_event.py b/backend/app/api/v1/schemas/audit_event.py new file mode 100644 index 0000000..f157178 --- /dev/null +++ b/backend/app/api/v1/schemas/audit_event.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel + + +class AuditEventResponse(BaseModel): + id: UUID + event_type: str + actor_id: UUID | None + workspace_id: UUID | None + payload: dict[str, Any] + created_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/api/v1/schemas/data_policy.py b/backend/app/api/v1/schemas/data_policy.py new file mode 100644 index 0000000..5580064 --- /dev/null +++ b/backend/app/api/v1/schemas/data_policy.py @@ -0,0 +1,49 @@ +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + + +class DataPolicyBase(BaseModel): + name: str + enabled: bool = True + priority: int = 100 + applies_to_roles: list[str] = Field( + default_factory=list, description="Roles this applies to (empty = all)" + ) + max_rows: int | None = None + max_runtime_seconds: int | None = None + allowed_tables: list[str] = Field(default_factory=list) + blocked_tables: list[str] = Field(default_factory=list) + blocked_columns: list[str] = Field(default_factory=list) + masked_columns: list[str] = Field(default_factory=list) + row_filters: dict[str, str] = Field( + default_factory=dict, description="table -> SQL 
boolean condition" + ) + + +class DataPolicyCreate(DataPolicyBase): + pass + + +class DataPolicyUpdate(BaseModel): + name: str | None = None + enabled: bool | None = None + priority: int | None = None + applies_to_roles: list[str] | None = None + max_rows: int | None = None + max_runtime_seconds: int | None = None + allowed_tables: list[str] | None = None + blocked_tables: list[str] | None = None + blocked_columns: list[str] | None = None + masked_columns: list[str] | None = None + row_filters: dict[str, str] | None = None + + +class DataPolicyResponse(DataPolicyBase): + id: UUID + connection_id: UUID + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/api/v1/schemas/schedule.py b/backend/app/api/v1/schemas/schedule.py new file mode 100644 index 0000000..d444fed --- /dev/null +++ b/backend/app/api/v1/schemas/schedule.py @@ -0,0 +1,67 @@ +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field + +from app.db.models.schedule import TARGET_TYPES + + +class ScheduleThreshold(BaseModel): + metric: str = Field("row_count", description="'row_count' or a result column name") + op: str = Field(">", description="One of: > >= < <= == !=") + value: float + + +class ScheduleCreate(BaseModel): + name: str + target_type: str = Field(description=f"One of: {', '.join(TARGET_TYPES)}") + target_id: UUID + cron: str = Field(description="5-field cron expression (UTC)") + channel: str = Field("email", description="email | slack | log") + recipients: list[str] = Field(default_factory=list) + params: dict[str, Any] = Field(default_factory=dict) + threshold: ScheduleThreshold | None = None + only_on_threshold: bool = False + enabled: bool = True + + +class ScheduleUpdate(BaseModel): + name: str | None = None + cron: str | None = None + channel: str | None = None + recipients: list[str] | None = None + params: dict[str, Any] | None = None + threshold: 
ScheduleThreshold | None = None + only_on_threshold: bool | None = None + enabled: bool | None = None + + +class ScheduleResponse(BaseModel): + id: UUID + name: str + target_type: str + target_id: UUID + cron: str + channel: str + recipients: list[str] + params: dict[str, Any] + threshold: dict[str, Any] | None + only_on_threshold: bool + enabled: bool + next_run_at: datetime | None + last_run_at: datetime | None + last_status: str | None + last_error: str | None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class ScheduleRunResponse(BaseModel): + schedule_id: UUID + status: str + delivered: bool + threshold_met: bool | None + error: str | None diff --git a/backend/app/config.py b/backend/app/config.py index 0fbf59a..f2ce738 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -35,6 +35,33 @@ class Settings(BaseSettings): job_backend: str = "inprocess" # inprocess (asyncio) | arq (Redis) redis_url: str = "redis://localhost:6379/0" + # Scheduling & notifications (Phase 4 — Milestone 2) + # The in-process scheduler loop claims due schedules and dispatches report + # jobs. Disable when a separate process owns scheduling, or in tests. + scheduler_enabled: bool = True + scheduler_tick_seconds: int = 60 # how often the loop scans for due schedules + # Public base URL of the frontend, used to build links in delivered reports + # and magic-link emails. Falls back to the first CORS origin when unset. + app_base_url: str | None = None + # Email (SMTP) delivery. When smtp_host is unset, email degrades to logging. + smtp_host: str | None = None + smtp_port: int = 587 + smtp_username: str | None = None + smtp_password: str | None = None + smtp_use_tls: bool = True # STARTTLS + smtp_from: str = "querywise@localhost" + # Slack delivery via an Incoming Webhook URL. Unset → Slack degrades to log. + slack_webhook_url: str | None = None + + # Cost attribution pricing (Phase 4 — Milestone 4). 
Estimates only; tune to + # your contract. BigQuery on-demand is ~$6.25 / TiB scanned. + cost_per_tib_scanned_usd: float = 6.25 + cost_per_slot_ms_usd: float = 0.0 + cost_per_dbu_usd: float = 0.0 + # Fallback when no warehouse stats are reported (e.g. PostgreSQL): a rough + # per-second compute estimate. 0 = no time-based cost. + cost_per_second_usd: float = 0.0 + # Query defaults default_query_timeout_seconds: int = 30 default_max_rows: int = 1000 diff --git a/backend/app/connectors/base_connector.py b/backend/app/connectors/base_connector.py index 7725621..f11347c 100644 --- a/backend/app/connectors/base_connector.py +++ b/backend/app/connectors/base_connector.py @@ -51,6 +51,11 @@ class QueryResult: row_count: int execution_time_ms: float truncated: bool + # Optional connector-reported execution stats for cost attribution + # (e.g. BigQuery ``scanned_bytes`` / ``slot_ms``, Databricks ``dbu``). + # Default empty — connectors populate what they can; cost_service falls + # back to time-based estimation otherwise. + stats: dict[str, Any] = field(default_factory=dict) class BaseConnector(ABC): diff --git a/backend/app/connectors/bigquery/connector.py b/backend/app/connectors/bigquery/connector.py index 53671ff..bbdbaa1 100644 --- a/backend/app/connectors/bigquery/connector.py +++ b/backend/app/connectors/bigquery/connector.py @@ -199,11 +199,20 @@ async def execute_query( truncated = len(rows) > max_rows rows = rows[:max_rows] + # Job stats for cost attribution (bytes scanned / billed, slot time). 
+ job_stats = { + "scanned_bytes": getattr(job, "total_bytes_processed", None), + "billed_bytes": getattr(job, "total_bytes_billed", None), + "slot_ms": getattr(job, "slot_millis", None), + } + job_stats = {k: v for k, v in job_stats.items() if v is not None} + if not rows: return QueryResult( columns=[], column_types=[], rows=[], + stats=job_stats, row_count=0, execution_time_ms=elapsed_ms, truncated=False, @@ -223,6 +232,7 @@ async def execute_query( row_count=len(result_rows), execution_time_ms=elapsed_ms, truncated=truncated, + stats=job_stats, ) async def get_sample_values( diff --git a/backend/app/db/models/__init__.py b/backend/app/db/models/__init__.py index 39758da..01c5d1e 100644 --- a/backend/app/db/models/__init__.py +++ b/backend/app/db/models/__init__.py @@ -1,8 +1,11 @@ from app.db.models.api_key import ApiKey from app.db.models.artifact_dependency import ArtifactDependency +from app.db.models.audit_event import AuditEvent from app.db.models.chart import Chart from app.db.models.connection import DatabaseConnection +from app.db.models.cost_attribution import CostAttribution from app.db.models.dashboard import Dashboard +from app.db.models.data_policy import DataPolicy from app.db.models.dashboard_tile import DashboardTile from app.db.models.dictionary import DictionaryEntry from app.db.models.glossary import GlossaryTerm @@ -14,6 +17,7 @@ from app.db.models.result_snapshot import ResultSnapshot from app.db.models.sample_query import SampleQuery from app.db.models.saved_query import SavedQuery +from app.db.models.schedule import Schedule from app.db.models.schema_cache import CachedColumn, CachedRelationship, CachedTable from app.db.models.semantic_version import SemanticVersion from app.db.models.team import Team @@ -43,4 +47,8 @@ "DashboardTile", "SemanticVersion", "ArtifactDependency", + "AuditEvent", + "Schedule", + "DataPolicy", + "CostAttribution", ] diff --git a/backend/app/db/models/audit_event.py b/backend/app/db/models/audit_event.py new 
file mode 100644 index 0000000..6ee1e0f --- /dev/null +++ b/backend/app/db/models/audit_event.py @@ -0,0 +1,41 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, String, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + + +class AuditEvent(Base): + """An append-only record of a security- or governance-relevant action. + + Written fire-and-forget (see ``app.services.audit_service.record``) so a + failure to audit never breaks the action being audited. Org-scoped and + exportable for compliance. ``actor_id`` is nullable so system-driven events + (startup auto-setup, scheduled jobs) and pre-auth events (failed login) can + still be recorded. + """ + + __tablename__ = "audit_events" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + # Dotted action name, e.g. "connection.created", "query.blocked". See + # audit_service for the canonical set of constants. + event_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + actor_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + # Optional workspace the action occurred in (events like login are org-level). + workspace_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL") + ) + # Free-form structured context: target ids, names, outcome, request id, etc. 
+ payload: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), index=True + ) diff --git a/backend/app/db/models/cost_attribution.py b/backend/app/db/models/cost_attribution.py new file mode 100644 index 0000000..38fbf13 --- /dev/null +++ b/backend/app/db/models/cost_attribution.py @@ -0,0 +1,54 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, Float, ForeignKey, Integer, String, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + + +class CostAttribution(Base): + """Per-execution cost + usage stats, attributed to a workspace/user. + + Written best-effort after each query execution (post-hoc, since warehouse + job stats are only available once the query completes). Powers the usage + analytics dashboards (slowest queries, error rate, most-queried tables, cost + per team). ``cost_usd`` is an estimate from the configured pricing model; + ``scanned_bytes`` / ``slot_ms`` / ``dbu`` are populated when the connector + reports them (BigQuery today), else null. 
+ """ + + __tablename__ = "cost_attributions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + workspace_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL") + ) + connection_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="SET NULL") + ) + user_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + query_execution_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("query_executions.id", ondelete="SET NULL") + ) + + source_provider: Mapped[str | None] = mapped_column(String(50)) + status: Mapped[str] = mapped_column(String(20), nullable=False, default="success") + execution_time_ms: Mapped[float | None] = mapped_column(Float) + row_count: Mapped[int | None] = mapped_column(Integer) + scanned_bytes: Mapped[int | None] = mapped_column(Integer) + slot_ms: Mapped[int | None] = mapped_column(Integer) + dbu: Mapped[float | None] = mapped_column(Float) + cost_usd: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + # Referenced tables ("schema.table" or "table"), for most-queried analytics. 
+ tables: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), index=True + ) diff --git a/backend/app/db/models/data_policy.py b/backend/app/db/models/data_policy.py new file mode 100644 index 0000000..a649bbd --- /dev/null +++ b/backend/app/db/models/data_policy.py @@ -0,0 +1,64 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + + +class DataPolicy(Base): + """A governance rule enforced *before* a query reaches the connector. + + Connection-scoped (like the semantic layer). A policy applies to a request + when the caller's role is in ``applies_to_roles`` (empty = all roles). When + several policies apply they are merged most-restrictively into an effective + policy (see ``policy_service.resolve_effective``): + + * ``max_rows`` / ``max_runtime_seconds`` — tightened to the minimum. + * ``allowed_tables`` — a referenced table must appear in *every* non-empty + allow-list (intersection); empty = no restriction. + * ``blocked_tables`` / ``blocked_columns`` — union; referencing one blocks + the query with an explanation. + * ``masked_columns`` — union; values are redacted in the result. + * ``row_filters`` — ``{table: "SQL boolean condition"}``; injected as a row-level + filter (AND-combined when multiple policies filter the same table). + + Table/column names may be bare (``email``) or schema/table-qualified + (``public.users`` / ``users.email``). 
+ """ + + __tablename__ = "data_policies" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + connection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + name: Mapped[str] = mapped_column(String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + # Lower numbers are reported first when explaining a block; does not change + # the merge (which is always most-restrictive). + priority: Mapped[int] = mapped_column(Integer, nullable=False, default=100) + # Roles this policy applies to (admin|editor|viewer). Empty = all roles. + applies_to_roles: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + + max_rows: Mapped[int | None] = mapped_column(Integer) + max_runtime_seconds: Mapped[int | None] = mapped_column(Integer) + allowed_tables: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + blocked_tables: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + blocked_columns: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + masked_columns: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + row_filters: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) diff --git a/backend/app/db/models/schedule.py b/backend/app/db/models/schedule.py new file mode 100644 index 0000000..b960ccc --- /dev/null +++ b/backend/app/db/models/schedule.py @@ -0,0 +1,71 @@ +import uuid +from datetime import datetime + +from 
sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + +# What a schedule runs. +TARGET_SAVED_QUERY = "saved_query" +TARGET_DASHBOARD = "dashboard" +TARGET_TYPES = (TARGET_SAVED_QUERY, TARGET_DASHBOARD) + +# Last-run outcome. +STATUS_PENDING = "pending" +STATUS_SUCCESS = "success" +STATUS_ERROR = "error" +STATUS_SKIPPED = "skipped" # ran but delivery suppressed (threshold not met) + + +class Schedule(Base): + """A recurring report: run a saved query or dashboard on a cron schedule and + deliver the result over a notification channel. + + Workspace-scoped (like :class:`Dashboard`). ``next_run_at`` is computed from + ``cron`` and is the column the scheduler claims on; ``last_*`` capture the + most recent run for display + audit. + """ + + __tablename__ = "schedules" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + workspace_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=False + ) + owner_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + + name: Mapped[str] = mapped_column(String(255), nullable=False) + target_type: Mapped[str] = mapped_column(String(20), nullable=False) + target_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + # Standard 5-field cron expression (UTC). + cron: Mapped[str] = mapped_column(String(120), nullable=False) + # Delivery channel: "email" | "slack" | "log". 
+ channel: Mapped[str] = mapped_column(String(20), nullable=False, default="email") + recipients: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + # Params supplied to the saved query / dashboard filters at run time. + params: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + + # Optional alert-on-threshold: {"metric": "row_count"|"RESULT_COLUMN_NAME", + # "op": ">"|">="|"<"|"<="|"=="|"!=", "value": NUMBER}. + threshold: Mapped[dict | None] = mapped_column(JSONB) + # When true, deliver only if the threshold is met (otherwise mark skipped). + only_on_threshold: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + next_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), index=True) + last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_status: Mapped[str | None] = mapped_column(String(20)) + last_error: Mapped[str | None] = mapped_column(Text) + + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) diff --git a/backend/app/jobs/scheduler.py b/backend/app/jobs/scheduler.py new file mode 100644 index 0000000..5cf3c02 --- /dev/null +++ b/backend/app/jobs/scheduler.py @@ -0,0 +1,117 @@ +"""In-process scheduler loop for recurring reports. + +A single asyncio loop (started from the FastAPI lifespan) ticks every +``SCHEDULER_TICK_SECONDS``, atomically claims due schedules, and dispatches a +``run_schedule`` job per schedule through the job queue. 
The claim advances +``next_run_at`` to the next cron slot *before* dispatching, using +``FOR UPDATE SKIP LOCKED`` — so concurrent ticks (or multiple replicas) never +double-fire the same schedule, and a worker crash skips at most one run rather +than stalling or looping. 
+ +Dispatch goes through ``get_job_queue()``, so under ``JOB_BACKEND=arq`` the +report runs in the worker process while this loop only does the claiming. +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from datetime import UTC, datetime + +from sqlalchemy import select + +from app.config import settings +from app.db.models.schedule import Schedule +from app.db.session import async_session_factory +from app.jobs import get_job_queue, register_job + +logger = logging.getLogger("querywise") + + +async def run_schedule_job(schedule_id: str) -> None: + """Job body: load one schedule and run it end-to-end (own session).""" + from app.services import schedule_service + + async with async_session_factory() as db: + try: + schedule = await db.get(Schedule, uuid.UUID(str(schedule_id))) + if schedule is None: + logger.warning("run_schedule: schedule %s not found", schedule_id) + return + # next_run_at was already advanced by the claim, so don't reschedule. 
+ await schedule_service.run_one(db, schedule, reschedule=False) + await db.commit() + except Exception: # noqa: BLE001 — never crash the worker + await db.rollback() + logger.exception("run_schedule: schedule %s failed", schedule_id) + + +register_job("run_schedule", run_schedule_job) + + +async def _claim_and_dispatch() -> int: + """Claim due schedules (advancing next_run_at) and dispatch their jobs.""" + now = datetime.now(UTC) + async with async_session_factory() as db: + from app.services import schedule_service + + result = await db.execute( + select(Schedule) + .where( + Schedule.enabled.is_(True), + Schedule.next_run_at.isnot(None), + Schedule.next_run_at <= now, + ) + .with_for_update(skip_locked=True) + ) + due = list(result.scalars().all()) + claimed: list[str] = [] + for s in due: + s.next_run_at = schedule_service.compute_next_run(s.cron, after=now) + claimed.append(str(s.id)) + await db.commit() + + if claimed: + queue = get_job_queue() + for sid in claimed: + queue.submit("run_schedule", sid, name=f"schedule-{sid}") + logger.info("Scheduler dispatched %d due schedule(s)", len(claimed)) + return len(claimed) + + +_task: asyncio.Task | None = None + + +async def _loop() -> None: + interval = max(5, settings.scheduler_tick_seconds) + while True: + try: + await _claim_and_dispatch() + except Exception: # noqa: BLE001 — a bad tick must not kill the loop + logger.exception("Scheduler tick failed") + await asyncio.sleep(interval) + + +def start_scheduler() -> None: + """Start the scheduler loop (no-op if disabled or already running).""" + global _task + if not settings.scheduler_enabled: + logger.info("Scheduler disabled (SCHEDULER_ENABLED=false)") + return + if _task is not None and not _task.done(): + return + _task = asyncio.create_task(_loop(), name="scheduler-loop") + logger.info("Scheduler started (tick=%ss)", settings.scheduler_tick_seconds) + + +async def stop_scheduler() -> None: + """Cancel the scheduler loop on shutdown.""" + global _task + if 
_task is not None: + _task.cancel() + try: + await _task + except asyncio.CancelledError: + pass + _task = None diff --git a/backend/app/jobs/tasks.py b/backend/app/jobs/tasks.py index d29d013..ca08a55 100644 --- a/backend/app/jobs/tasks.py +++ b/backend/app/jobs/tasks.py @@ -8,3 +8,6 @@ # Registers "generate_embeddings". import app.services.setup_service # noqa: F401 + +# Registers "run_schedule". +import app.jobs.scheduler # noqa: F401 diff --git a/backend/app/main.py b/backend/app/main.py index 0c51eed..6e6f943 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -30,8 +30,14 @@ async def lifespan(app: FastAPI): from app.services.setup_service import auto_setup_sample_db await auto_setup_sample_db() + + # Start the recurring-report scheduler loop (registers "run_schedule"). + from app.jobs.scheduler import start_scheduler, stop_scheduler + + start_scheduler() yield # Shutdown + await stop_scheduler() await engine.dispose() diff --git a/backend/app/notifications/__init__.py b/backend/app/notifications/__init__.py new file mode 100644 index 0000000..afafd5d --- /dev/null +++ b/backend/app/notifications/__init__.py @@ -0,0 +1,68 @@ +"""Notification delivery — pluggable per-channel adapters. + +``get_notifier(channel)`` returns the adapter for a channel, falling back to the +:class:`LogNotifier` when the channel is unconfigured so delivery degrades +gracefully (no SMTP host / no Slack webhook → the message is logged). + +``deliver(...)`` is the fire-and-forget convenience used by callers that must +never fail because of a delivery error (e.g. magic-link). Scheduled reports call +``get_notifier(...).send(...)`` directly so they can record success/failure. 
+""" + +from __future__ import annotations + +import logging + +from app.config import settings +from app.notifications.base import NotificationMessage, Notifier +from app.notifications.email import EmailNotifier +from app.notifications.log import LogNotifier +from app.notifications.slack import SlackNotifier + +logger = logging.getLogger("querywise") + +CHANNELS = ("email", "slack", "log") + + +def get_notifier(channel: str) -> Notifier: + """Return the adapter for ``channel``, or a LogNotifier if unconfigured.""" + if channel == "email": + return EmailNotifier() if settings.smtp_host else LogNotifier("email") + if channel == "slack": + return SlackNotifier() if settings.slack_webhook_url else LogNotifier("slack") + return LogNotifier(channel) + + +async def deliver( + channel: str, + *, + subject: str, + text_body: str, + html_body: str | None = None, + recipients: list[str] | None = None, +) -> bool: + """Best-effort delivery. Returns True on success, False on failure (logged).""" + message = NotificationMessage( + subject=subject, + text_body=text_body, + html_body=html_body, + recipients=recipients or [], + ) + try: + await get_notifier(channel).send(message) + return True + except Exception: # noqa: BLE001 — delivery must not crash the caller + logger.exception("Notification delivery on channel '%s' failed", channel) + return False + + +__all__ = [ + "NotificationMessage", + "Notifier", + "EmailNotifier", + "SlackNotifier", + "LogNotifier", + "CHANNELS", + "get_notifier", + "deliver", +] diff --git a/backend/app/notifications/base.py b/backend/app/notifications/base.py new file mode 100644 index 0000000..25a26f0 --- /dev/null +++ b/backend/app/notifications/base.py @@ -0,0 +1,36 @@ +"""Notification delivery seam. + +A ``Notifier`` turns a :class:`NotificationMessage` into a delivered message on +one channel (email, Slack, …). 
Adapters degrade gracefully: when a channel is +not configured, the factory returns :class:`~app.notifications.log.LogNotifier` +so callers (scheduled reports, magic-link delivery) keep working in dev without +SMTP/Slack credentials — mirroring the pre-Phase-4 "log the token" behaviour. + +``send`` raises on a hard delivery failure so the caller can record it; callers +that must never fail (e.g. fire-and-forget) wrap the call themselves. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass +class NotificationMessage: + subject: str + text_body: str + html_body: str | None = None + # Channel-specific destinations (email addresses for email; ignored by + # Slack, which targets the configured webhook's channel). + recipients: list[str] = field(default_factory=list) + + +class Notifier(ABC): + """Delivers a message on one channel.""" + + channel: str = "base" + + @abstractmethod + async def send(self, message: NotificationMessage) -> None: + """Deliver ``message``. Raises on hard failure.""" diff --git a/backend/app/notifications/email.py b/backend/app/notifications/email.py new file mode 100644 index 0000000..cf20235 --- /dev/null +++ b/backend/app/notifications/email.py @@ -0,0 +1,39 @@ +"""SMTP email notifier (stdlib ``smtplib``, no extra dependency). + +``smtplib`` is synchronous, so the blocking send runs in a worker thread via +``asyncio.to_thread`` to avoid stalling the event loop. 
+""" + +from __future__ import annotations + +import asyncio +import smtplib +from email.message import EmailMessage + +from app.config import settings +from app.notifications.base import NotificationMessage, Notifier + + +class EmailNotifier(Notifier): + channel = "email" + + async def send(self, message: NotificationMessage) -> None: + if not message.recipients: + raise ValueError("Email delivery requires at least one recipient.") + await asyncio.to_thread(self._send_sync, message) + + def _send_sync(self, message: NotificationMessage) -> None: + email = EmailMessage() + email["Subject"] = message.subject + email["From"] = settings.smtp_from + email["To"] = ", ".join(message.recipients) + email.set_content(message.text_body) + if message.html_body: + email.add_alternative(message.html_body, subtype="html") + + with smtplib.SMTP(settings.smtp_host, settings.smtp_port, timeout=30) as smtp: + if settings.smtp_use_tls: + smtp.starttls() + if settings.smtp_username: + smtp.login(settings.smtp_username, settings.smtp_password or "") + smtp.send_message(email) diff --git a/backend/app/notifications/log.py b/backend/app/notifications/log.py new file mode 100644 index 0000000..34ad219 --- /dev/null +++ b/backend/app/notifications/log.py @@ -0,0 +1,28 @@ +"""Fallback notifier that logs instead of delivering. + +Used when a channel is unconfigured (no SMTP host / no Slack webhook). Keeps the +scheduled-report and magic-link flows working in local dev without external +credentials — the message (and any link it carries) lands in the logs. 
+""" + +from __future__ import annotations + +import logging + +from app.notifications.base import NotificationMessage, Notifier + +logger = logging.getLogger("querywise") + + +class LogNotifier(Notifier): + def __init__(self, channel: str = "log") -> None: + self.channel = channel + + async def send(self, message: NotificationMessage) -> None: + logger.info( + "[notify:%s] to=%s subject=%r body=%s", + self.channel, + ", ".join(message.recipients) or "-", + message.subject, + message.text_body[:500], + ) diff --git a/backend/app/notifications/slack.py b/backend/app/notifications/slack.py new file mode 100644 index 0000000..d6cce3b --- /dev/null +++ b/backend/app/notifications/slack.py @@ -0,0 +1,21 @@ +"""Slack notifier via an Incoming Webhook (uses the existing httpx dependency).""" + +from __future__ import annotations + +import httpx + +from app.config import settings +from app.notifications.base import NotificationMessage, Notifier + + +class SlackNotifier(Notifier): + channel = "slack" + + async def send(self, message: NotificationMessage) -> None: + if not settings.slack_webhook_url: + raise ValueError("Slack delivery requires SLACK_WEBHOOK_URL to be set.") + # The webhook targets a fixed channel; render subject + body as text. + text = f"*{message.subject}*\n{message.text_body}" + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post(settings.slack_webhook_url, json={"text": text}) + resp.raise_for_status() diff --git a/backend/app/services/audit_service.py b/backend/app/services/audit_service.py new file mode 100644 index 0000000..0bb15be --- /dev/null +++ b/backend/app/services/audit_service.py @@ -0,0 +1,139 @@ +"""Append-only audit log of security- and governance-relevant actions. + +``record`` is **fire-and-forget**: it is wrapped so that a failure to write an +audit row never propagates into — and never fails — the action being audited. 
+This is a deliberate trade-off: an audit miss is logged but tolerated, whereas a +broken login or blocked-query path is not. + +Events are written inline (a small INSERT on the request's own session) rather +than through the job queue, so they are durable the moment the request commits +and need no running worker. The write uses a *nested* transaction (SAVEPOINT) so +that an audit failure can be rolled back without poisoning the caller's outer +transaction. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.telemetry import get_request_id +from app.db.models.audit_event import AuditEvent + +logger = logging.getLogger("uvicorn.error") + +# --- Canonical event types -------------------------------------------------- +# Dotted "." names. Keep this list authoritative so the API can +# expose it as a filter facet and call sites don't drift into typos. +AUTH_LOGIN = "auth.login" +AUTH_LOGIN_FAILED = "auth.login_failed" +AUTH_MAGIC_LINK_REQUESTED = "auth.magic_link_requested" +AUTH_LOGOUT = "auth.logout" + +CONNECTION_CREATED = "connection.created" +CONNECTION_UPDATED = "connection.updated" +CONNECTION_DELETED = "connection.deleted" +CONNECTION_INTROSPECTED = "connection.introspected" +CREDENTIAL_ROTATED = "connection.credential_rotated" + +QUERY_GENERATED = "query.generated" +QUERY_EXECUTED = "query.executed" +QUERY_BLOCKED = "query.blocked" + +METRIC_CERTIFIED = "metric.certified" +KNOWLEDGE_IMPORTED = "knowledge.imported" + +SCHEDULE_CREATED = "schedule.created" +SCHEDULE_UPDATED = "schedule.updated" +SCHEDULE_DELETED = "schedule.deleted" +SCHEDULE_RUN = "schedule.run" +REPORT_DELIVERED = "report.delivered" + +POLICY_CREATED = "policy.created" +POLICY_UPDATED = "policy.updated" +POLICY_DELETED = "policy.deleted" + +EVENT_TYPES: tuple[str, ...] 
async def record(
    db: AsyncSession,
    *,
    organization_id: uuid.UUID,
    event_type: str,
    actor_id: uuid.UUID | None = None,
    workspace_id: uuid.UUID | None = None,
    payload: dict[str, Any] | None = None,
) -> None:
    """Write one audit event. Never raises.

    The caller is responsible for committing its own transaction; this adds the
    event to the session within a SAVEPOINT so a write failure is isolated. The
    current request id (if any) is folded into a *copy* of the payload for
    correlation, so the caller's dict is left untouched.

    Args:
        db: session of the request being audited.
        organization_id: organization the event belongs to.
        event_type: event name stored on the row.
        actor_id: user who performed the action, if known.
        workspace_id: workspace the action happened in, if any.
        payload: extra event details.
    """
    try:
        # Everything that can fail sits inside the guard — payload copying,
        # request-id lookup and model construction included — so the
        # "never raises" contract holds for the whole function, not only for
        # the INSERT.
        data = dict(payload or {})
        rid = get_request_id()
        if rid and rid != "-":
            # An explicit request_id supplied by the caller wins.
            data.setdefault("request_id", rid)

        event = AuditEvent(
            organization_id=organization_id,
            event_type=event_type,
            actor_id=actor_id,
            workspace_id=workspace_id,
            payload=data,
        )
        async with db.begin_nested():
            db.add(event)
    except Exception:  # noqa: BLE001 — auditing must never break the caller
        logger.warning("Failed to record audit event '%s'", event_type, exc_info=True)
stmt.order_by(AuditEvent.created_at.desc()).limit(limit).offset(offset) + result = await db.execute(stmt) + return list(result.scalars().all()) diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index e3d925f..21772b4 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -23,8 +23,9 @@ verify_password, ) from app.db.models.membership import ROLE_VIEWER, Membership +from app.db.models.team import Team from app.db.models.user import User -from app.services import identity_service +from app.services import audit_service, identity_service def _normalize_email(email: str) -> str: @@ -121,6 +122,24 @@ async def verify_magic_link(db: AsyncSession, token: str) -> User: async def _touch_login(db: AsyncSession, user: User) -> None: user.last_login_at = datetime.now(UTC) await db.flush() + # Resolve the user's home org (earliest membership) for the audit record. + result = await db.execute( + select(Team.organization_id, Membership.team_id) + .join(Team, Team.id == Membership.team_id) + .where(Membership.user_id == user.id) + .order_by(Membership.created_at) + .limit(1) + ) + row = result.first() + if row is not None: + await audit_service.record( + db, + organization_id=row[0], + workspace_id=row[1], + actor_id=user.id, + event_type=audit_service.AUTH_LOGIN, + payload={"email": user.email}, + ) def issue_session_token(user: User) -> str: diff --git a/backend/app/services/connection_service.py b/backend/app/services/connection_service.py index c887611..2e25d41 100644 --- a/backend/app/services/connection_service.py +++ b/backend/app/services/connection_service.py @@ -13,6 +13,7 @@ from app.core.secrets import get_secrets_provider from app.db.models.connection import DatabaseConnection from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR +from app.services import audit_service # Encryption of connection strings is delegated to the configured secrets # backend (env/Fernet by default — see 
app.core.secrets). @@ -104,6 +105,14 @@ async def create_connection( ) db.add(conn) await db.flush() + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.CONNECTION_CREATED, + payload={"connection_id": str(conn.id), "name": name, "connector_type": connector_type}, + ) return conn @@ -115,8 +124,10 @@ async def update_connection( ) -> DatabaseConnection: conn = await get_connection(db, connection_id, ctx, write=True) + rotated = False if "connection_string" in updates and updates["connection_string"] is not None: conn.connection_string_encrypted = _encrypt(str(updates.pop("connection_string"))) + rotated = True for key, value in updates.items(): if value is not None and hasattr(conn, key): @@ -125,6 +136,23 @@ async def update_connection( await db.flush() # Invalidate cached connector since config may have changed await remove_connector(str(connection_id)) + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.CONNECTION_UPDATED, + payload={"connection_id": str(conn.id), "name": conn.name}, + ) + if rotated: + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.CREDENTIAL_ROTATED, + payload={"connection_id": str(conn.id), "name": conn.name}, + ) return conn @@ -132,9 +160,18 @@ async def delete_connection( db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext ) -> None: conn = await get_connection(db, connection_id, ctx, write=True) + name = conn.name await remove_connector(str(connection_id)) await db.delete(conn) await db.flush() + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.CONNECTION_DELETED, + payload={"connection_id": 
def compute_cost(stats: dict[str, Any] | None, execution_time_ms: float | None) -> float:
    """Return the estimated USD cost of one execution, rounded to 6 decimals.

    Usage-based charges (bytes, slot time, DBUs) are summed from the
    connector-reported ``stats``. The wall-clock fallback applies only when
    those charges add up to zero, an execution time is known, and a per-second
    rate is configured.
    """
    metrics = stats or {}
    total = 0.0

    # Billed bytes take precedence over scanned bytes when both are reported.
    volume = metrics.get("billed_bytes") or metrics.get("scanned_bytes")
    if volume:
        total += (volume / _TIB) * settings.cost_per_tib_scanned_usd

    slot_ms = metrics.get("slot_ms")
    if slot_ms:
        total += slot_ms * settings.cost_per_slot_ms_usd

    dbu = metrics.get("dbu")
    if dbu:
        total += dbu * settings.cost_per_dbu_usd

    usage_priced = total != 0.0
    if usage_priced or not execution_time_ms or not settings.cost_per_second_usd:
        return round(total, 6)

    total += (execution_time_ms / 1000.0) * settings.cost_per_second_usd
    return round(total, 6)


def _referenced_tables(sql: str | None, connector_type: str | None) -> list[str]:
    """List the tables ``sql`` references, schema-qualified where a schema is known."""
    if not sql:
        return []
    names: list[str] = []
    for ref in extract_refs(sql, dialect_for(connector_type)):
        if ref.ref_kind != REF_TABLE:
            continue
        if ref.schema_name:
            names.append(f"{ref.schema_name}.{ref.table_name}")
        else:
            names.append(ref.table_name)
    return names


async def record_execution_cost(
    db: AsyncSession,
    *,
    execution: QueryExecution,
    ctx: AuthContext,
    connector_type: str | None,
    stats: dict[str, Any] | None,
    final_sql: str | None,
) -> None:
    """Persist one CostAttribution row for ``execution``. Never raises.

    The row is added inside a SAVEPOINT; any failure while building or writing
    it is logged as a warning and otherwise ignored.
    """
    try:
        metrics = stats or {}
        fields = {
            "organization_id": execution.organization_id,
            "workspace_id": ctx.workspace_id,
            "connection_id": execution.connection_id,
            "user_id": execution.user_id,
            "query_execution_id": execution.id,
            "source_provider": connector_type,
            "status": execution.execution_status,
            "execution_time_ms": execution.execution_time_ms,
            "row_count": execution.row_count,
            # Stored volume prefers scanned bytes; pricing prefers billed bytes.
            "scanned_bytes": metrics.get("scanned_bytes") or metrics.get("billed_bytes"),
            "slot_ms": metrics.get("slot_ms"),
            "dbu": metrics.get("dbu"),
            "cost_usd": compute_cost(metrics, execution.execution_time_ms),
            "tables": _referenced_tables(final_sql, connector_type),
        }
        row = CostAttribution(**fields)
        async with db.begin_nested():
            db.add(row)
    except Exception:  # noqa: BLE001 — cost capture must never break the query
        logger.warning("Failed to record cost attribution", exc_info=True)
async def cost_by(
    db: AsyncSession, organization_id: uuid.UUID, dimension: str, since: datetime
) -> list[dict[str, Any]]:
    """Total cost and query count per ``dimension`` value, most expensive first.

    Raises:
        ValueError: if ``dimension`` is not a key of ``_DIMENSIONS``.
    """
    group_col = _DIMENSIONS.get(dimension)
    if group_col is None:
        raise ValueError(f"dimension must be one of {sorted(_DIMENSIONS)}")

    in_window = (
        CostAttribution.organization_id == organization_id,
        CostAttribution.created_at >= since,
    )
    stmt = select(
        group_col.label("key"),
        func.coalesce(func.sum(CostAttribution.cost_usd), 0.0).label("cost"),
        func.count().label("n"),
    )
    stmt = stmt.where(*in_window).group_by(group_col)
    stmt = stmt.order_by(func.sum(CostAttribution.cost_usd).desc())

    result = await db.execute(stmt)
    breakdown: list[dict[str, Any]] = []
    for rec in result.all():
        breakdown.append(
            {
                "key": str(rec.key) if rec.key else None,
                "cost_usd": round(float(rec.cost or 0.0), 6),
                "query_count": rec.n,
            }
        )
    return breakdown


async def slowest_queries(
    db: AsyncSession, organization_id: uuid.UUID, since: datetime, limit: int = 10
) -> list[dict[str, Any]]:
    """Return up to ``limit`` attributions with the longest execution time in the window.

    Rows without an execution time are excluded; the question text comes from
    an outer join to the query execution, so it may be ``None``.
    """
    columns = (
        CostAttribution.query_execution_id,
        CostAttribution.execution_time_ms,
        CostAttribution.cost_usd,
        CostAttribution.source_provider,
        QueryExecution.natural_language,
    )
    stmt = (
        select(*columns)
        .join(QueryExecution, QueryExecution.id == CostAttribution.query_execution_id, isouter=True)
        .where(
            CostAttribution.organization_id == organization_id,
            CostAttribution.created_at >= since,
            CostAttribution.execution_time_ms.isnot(None),
        )
        .order_by(CostAttribution.execution_time_ms.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    slowest: list[dict[str, Any]] = []
    for rec in result.all():
        execution_id = str(rec.query_execution_id) if rec.query_execution_id else None
        slowest.append(
            {
                "query_execution_id": execution_id,
                "execution_time_ms": rec.execution_time_ms,
                "cost_usd": round(float(rec.cost_usd or 0.0), 6),
                "source_provider": rec.source_provider,
                "question": rec.natural_language,
            }
        )
    return slowest


async def most_queried_tables(
    db: AsyncSession, organization_id: uuid.UUID, since: datetime, limit: int = 10
) -> list[dict[str, Any]]:
    """Top referenced tables in the window (aggregated from the ``tables`` lists)."""
    stmt = select(CostAttribution.tables).where(
        CostAttribution.organization_id == organization_id,
        CostAttribution.created_at >= since,
    )
    result = await db.execute(stmt)
    tally: Counter[str] = Counter()
    for record in result.all():
        # Each row is a single-column result; a NULL list counts as empty.
        tally.update(record[0] or [])
    ranking = tally.most_common(limit)
    return [{"table": table, "query_count": hits} for table, hits in ranking]
class PolicyViolationError(Exception):
    """Signals that a data policy refused a query.

    ``reason`` is the user-facing explanation and is also the exception message.
    """

    def __init__(self, reason: str) -> None:
        self.reason = reason
        super().__init__(reason)


def _norm(name: str) -> str:
    """Canonical form of an identifier: stringified, trimmed, lower-cased."""
    text = str(name)
    return text.strip().lower()


def _fqtn(schema: str | None, name: str) -> str:
    """Join schema and table with a dot; a missing schema yields the bare name."""
    if schema:
        return f"{schema}.{name}"
    return name


@dataclass
class EffectivePolicy:
    """The merged, most-restrictive policy for one (connection, role)."""

    # Row and runtime caps; None means the policy imposes no cap.
    max_rows: int | None = None
    max_runtime_seconds: int | None = None
    # None = no allow-restriction; otherwise the *intersection* of allow-lists.
    allowed_tables: set[str] | None = None
    blocked_tables: set[str] = field(default_factory=set)
    blocked_columns: set[str] = field(default_factory=set)
    masked_columns: set[str] = field(default_factory=set)
    # Normalized table name -> SQL predicate.
    row_filters: dict[str, str] = field(default_factory=dict)
    # Names of the policies that were merged, in merge order.
    sources: list[str] = field(default_factory=list)

    def has_sql_rules(self) -> bool:
        """True when any rule requires inspecting or rewriting the SQL text."""
        if self.allowed_tables is not None:
            # An empty allow-list is still a restriction.
            return True
        return bool(self.blocked_tables or self.blocked_columns or self.row_filters)


def _applies(policy: DataPolicy, role: str | None) -> bool:
    """Whether an enabled ``policy`` targets ``role`` (no role list = every role)."""
    if not policy.enabled:
        return False
    roles = policy.applies_to_roles or []
    if not roles:
        return True
    return role in roles


def merge_policies(policies: list[DataPolicy], role: str | None) -> EffectivePolicy | None:
    """Merge applicable policies into one EffectivePolicy, or None if none apply.

    Caps take the minimum, allow-lists intersect, block/mask lists union, and
    row filters on the same table are AND-ed in ascending priority order.
    """
    relevant = [candidate for candidate in policies if _applies(candidate, role)]
    if not relevant:
        return None

    def _lower(current: int | None, incoming: int) -> int:
        # The first cap seen is adopted; later caps can only tighten it.
        return incoming if current is None else min(current, incoming)

    merged = EffectivePolicy()
    relevant.sort(key=lambda candidate: candidate.priority)
    for policy in relevant:
        if policy.max_rows is not None:
            merged.max_rows = _lower(merged.max_rows, policy.max_rows)
        if policy.max_runtime_seconds is not None:
            merged.max_runtime_seconds = _lower(
                merged.max_runtime_seconds, policy.max_runtime_seconds
            )

        if policy.allowed_tables:
            allow = {_norm(table) for table in policy.allowed_tables}
            if merged.allowed_tables is None:
                merged.allowed_tables = allow
            else:
                merged.allowed_tables = merged.allowed_tables & allow

        merged.blocked_tables.update(_norm(table) for table in (policy.blocked_tables or []))
        merged.blocked_columns.update(_norm(col) for col in (policy.blocked_columns or []))
        merged.masked_columns.update(_norm(col) for col in (policy.masked_columns or []))

        for raw_key, predicate in (policy.row_filters or {}).items():
            key = _norm(raw_key)
            if key not in merged.row_filters:
                merged.row_filters[key] = predicate
                continue
            merged.row_filters[key] = f"({merged.row_filters[key]}) AND ({predicate})"

        merged.sources.append(policy.name)
    return merged
# --------------------------------------------------------------------------- #
# Enforcement
# --------------------------------------------------------------------------- #
def _table_matches(
    schema: str | None, name: str, entries: set[str], *, loose: bool = False
) -> bool:
    """Whether table ``schema.name`` is named by ``entries``.

    An entry matches on the bare table name or on the exact ``schema.table``.
    With ``loose=True`` an *unqualified* table also matches any
    ``schema.table`` entry with the same table name; use this for restrictive
    rules, where omitting the schema in the query must not dodge the rule.
    """
    if name in entries or _fqtn(schema, name) in entries:
        return True
    if loose and schema is None:
        return any(entry.rsplit(".", 1)[-1] == name for entry in entries)
    return False


def _column_matches(tbl: tuple[str | None, str] | None, col: str, entries: set[str]) -> bool:
    """Whether ``col`` (optionally owned by ``tbl`` = (schema, table)) is in ``entries``.

    Entries may be ``col``, ``table.col`` or ``schema.table.col``.
    """
    if col in entries:
        return True
    if tbl is None:
        return False
    schema, table = tbl
    if f"{table}.{col}" in entries:
        return True
    return schema is not None and f"{schema}.{table}.{col}" in entries


def _row_filter_for(
    schema: str | None, name: str, row_filters: dict[str, str]
) -> str | None:
    """Combine every row filter that applies to one referenced table.

    A schema-qualified reference picks up its exact key and the bare-name key.
    An unqualified reference picks up the bare key and every ``schema.table``
    key with that table name. Several filters are AND-ed; a single filter is
    returned unchanged.
    """
    if schema:
        candidates = [row_filters.get(_fqtn(schema, name)), row_filters.get(name)]
    else:
        candidates = [
            predicate for key, predicate in row_filters.items()
            if key.rsplit(".", 1)[-1] == name
        ]
    matched = [predicate for predicate in candidates if predicate]
    if not matched:
        return None
    if len(matched) == 1:
        return matched[0]
    return " AND ".join(f"({predicate})" for predicate in matched)


def enforce_sql(eff: EffectivePolicy, sql: str, dialect: str | None) -> str:
    """Enforce SQL-level rules; return the (possibly row-filtered) SQL.

    Raises :class:`PolicyViolationError` on a block or when the SQL can't be analyzed
    while SQL-level rules are in force (fail closed).

    Restrictive rules resolve ambiguity against the query: a column whose
    owning table cannot be pinned down is checked against every referenced
    table, and an unqualified table is matched against schema-qualified
    block-list / row-filter entries. The allow-list stays strict.
    """
    if not eff.has_sql_rules():
        return sql

    try:
        import sqlglot
        from sqlglot import exp
    except ImportError:
        raise PolicyViolationError(
            "A data policy requires SQL analysis (sqlglot), which is not installed."
        ) from None

    try:
        tree = sqlglot.parse_one(sql, dialect=dialect)
    except Exception:
        raise PolicyViolationError("Query could not be analyzed for policy enforcement.") from None
    if tree is None:
        raise PolicyViolationError("Query could not be analyzed for policy enforcement.")

    # Referenced tables + alias map. A name or alias maps to a *set* of tables,
    # so the same table name in two schemas is not silently overwritten.
    referenced: list[tuple[str | None, str]] = []
    alias_map: dict[str, set[tuple[str | None, str]]] = {}
    for t in tree.find_all(exp.Table):
        name = _norm(t.name) if t.name else ""
        if not name:
            continue
        schema = _norm(t.db) if t.db else None
        ref = (schema, name)
        referenced.append(ref)
        alias_map.setdefault(name, set()).add(ref)
        if t.alias:
            alias_map.setdefault(_norm(t.alias), set()).add(ref)

    # allow-list: every referenced table must be allowed (strict matching).
    if eff.allowed_tables is not None:
        for schema, name in referenced:
            if not _table_matches(schema, name, eff.allowed_tables):
                raise PolicyViolationError(
                    f"Table '{_fqtn(schema, name)}' is not in the allowed set for your role."
                )

    # block-list (loose: an unqualified table cannot dodge a qualified entry).
    for schema, name in referenced:
        if _table_matches(schema, name, eff.blocked_tables, loose=True):
            raise PolicyViolationError(
                f"Access to table '{_fqtn(schema, name)}' is blocked by policy."
            )

    # blocked columns (explicit references only).
    if eff.blocked_columns:
        every_table = set(referenced)
        for col in tree.find_all(exp.Column):
            cname = _norm(col.name) if col.name else ""
            if not cname or cname == "*":
                continue
            qual = _norm(col.table) if col.table else ""
            owners = alias_map.get(qual) if qual else None
            if not owners:
                # Unqualified, or qualified by something that is not a base
                # table (e.g. a derived-table alias): the owner is unknown, so
                # any referenced table may supply the column.
                owners = every_table
            if _column_matches(None, cname, eff.blocked_columns) or any(
                _column_matches(owner, cname, eff.blocked_columns) for owner in owners
            ):
                raise PolicyViolationError(f"Column '{cname}' is blocked by policy.")

    # row filters: replace each matching table with a filtered subquery.
    if eff.row_filters:
        # Materialize first: the tree is mutated while tables are replaced.
        for t in list(tree.find_all(exp.Table)):
            name = _norm(t.name) if t.name else ""
            if not name:
                continue
            schema = _norm(t.db) if t.db else None
            filt = _row_filter_for(schema, name, eff.row_filters)
            if not filt:
                continue
            # Keep the name the outer query uses to refer to this table.
            alias = t.alias or t.name
            try:
                inner_table = t.copy()
                inner_table.set("alias", None)
                inner = exp.select("*").from_(inner_table).where(filt)
                subq = exp.Subquery(
                    this=inner, alias=exp.TableAlias(this=exp.to_identifier(alias))
                )
                t.replace(subq)
            except Exception:
                raise PolicyViolationError(
                    f"Row-level filter for table '{name}' could not be applied."
                ) from None

    return tree.sql(dialect=dialect)


def effective_limits(
    eff: EffectivePolicy | None, conn_max_rows: int, conn_timeout: int
) -> tuple[int, int]:
    """Tighten the connection's row/timeout caps with the policy's (min wins)."""
    max_rows = conn_max_rows
    timeout = conn_timeout
    if eff is not None:
        if eff.max_rows is not None:
            max_rows = min(max_rows, eff.max_rows)
        if eff.max_runtime_seconds is not None:
            timeout = min(timeout, eff.max_runtime_seconds)
    return max_rows, timeout


def mask_result(
    eff: EffectivePolicy | None, columns: list[str], rows: list[list]
) -> tuple[list[list], list[str]]:
    """Redact masked columns in ``rows`` by output-column name (star-safe).

    Returns ``(rows, masked_names)``. ``rows`` is unchanged when nothing matches;
    otherwise redacted copies are returned and the input rows are not mutated.
    Matching ignores qualifiers on both sides, so a ``table.col`` policy entry
    and a ``t.col`` output column both compare on ``col``.
    """
    if eff is None or not eff.masked_columns:
        return rows, []
    bare = {entry.split(".")[-1] for entry in eff.masked_columns}

    def _is_masked(column: str) -> bool:
        normalized = _norm(column)
        return normalized in bare or normalized.split(".")[-1] in bare

    idxs = [i for i, c in enumerate(columns) if _is_masked(c)]
    if not idxs:
        return rows, []
    masked_rows = []
    for row in rows:
        new = list(row)
        for i in idxs:
            new[i] = MASK_TOKEN
        masked_rows.append(new)
    return masked_rows, [columns[i] for i in idxs]


# --------------------------------------------------------------------------- #
# CRUD (connection-scoped; authorization is via require_connection_* deps)
# --------------------------------------------------------------------------- #
async def list_policies(db: AsyncSession, connection_id: uuid.UUID) -> list[DataPolicy]:
    """Return the connection's policies ordered by priority, then creation time."""
    result = await db.execute(
        select(DataPolicy)
        .where(DataPolicy.connection_id == connection_id)
        .order_by(DataPolicy.priority, DataPolicy.created_at)
    )
    return list(result.scalars().all())


async def get_policy(
    db: AsyncSession, connection_id: uuid.UUID, policy_id: uuid.UUID
) -> DataPolicy:
    """Load one policy, raising NotFoundError if it is missing or on another connection."""
    policy = await db.get(DataPolicy, policy_id)
    if policy is None or policy.connection_id != connection_id:
        raise NotFoundError("DataPolicy", str(policy_id))
    return policy


async def create_policy(
    db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext, **data: Any
) -> DataPolicy:
    """Create a policy on ``connection_id`` and record a ``policy.created`` audit event.

    The policy inherits its organization from the connection. Requires the
    editor role; raises NotFoundError when the connection does not exist.
    """
    ctx.require_role(ROLE_EDITOR)
    conn = await db.get(DatabaseConnection, connection_id)
    if conn is None:
        raise NotFoundError("DatabaseConnection", str(connection_id))
    policy = DataPolicy(
        organization_id=conn.organization_id, connection_id=connection_id, **data
    )
    db.add(policy)
    await db.flush()
    await db.refresh(policy)
    await audit_service.record(
        db,
        organization_id=conn.organization_id,
        workspace_id=ctx.workspace_id,
        actor_id=ctx.user_id,
        event_type=audit_service.POLICY_CREATED,
        payload={
            "policy_id": str(policy.id),
            "name": policy.name,
            "connection_id": str(connection_id),
        },
    )
    return policy
def update_policy( + db: AsyncSession, + connection_id: uuid.UUID, + policy_id: uuid.UUID, + ctx: AuthContext, + updates: dict[str, Any], +) -> DataPolicy: + ctx.require_role(ROLE_EDITOR) + policy = await get_policy(db, connection_id, policy_id) + for key, value in updates.items(): + setattr(policy, key, value) + await db.flush() + await db.refresh(policy) + await audit_service.record( + db, + organization_id=policy.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.POLICY_UPDATED, + payload={"policy_id": str(policy.id), "name": policy.name}, + ) + return policy + + +async def delete_policy( + db: AsyncSession, connection_id: uuid.UUID, policy_id: uuid.UUID, ctx: AuthContext +) -> None: + ctx.require_role(ROLE_EDITOR) + policy = await get_policy(db, connection_id, policy_id) + name = policy.name + org_id = policy.organization_id + await db.delete(policy) + await db.flush() + await audit_service.record( + db, + organization_id=org_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.POLICY_DELETED, + payload={"policy_id": str(policy_id), "name": name}, + ) diff --git a/backend/app/services/query_service.py b/backend/app/services/query_service.py index ebdeb6f..b76555e 100644 --- a/backend/app/services/query_service.py +++ b/backend/app/services/query_service.py @@ -15,10 +15,51 @@ from app.llm.agents.sql_validator import SQLValidatorAgent, ValidationStatus from app.llm.router import route from app.semantic.context_builder import build_context +from app.services import audit_service, cost_service, policy_service from app.services.connection_service import get_connection, get_decrypted_connection_string +from app.services.lineage_service import dialect_for +from app.services.policy_service import PolicyViolationError from app.utils.sql_sanitizer import check_sql_safety +async def _enforce_policy_sql( + db: AsyncSession, + ctx: AuthContext, + connection_id: uuid.UUID, + eff: 
"policy_service.EffectivePolicy | None", + sql: str, + dialect: str | None, + *, + question: str | None = None, +) -> str: + """Apply a connection's data policy to ``sql`` before execution. + + Returns the (possibly row-filtered) SQL, or — on a policy block — records a + ``query.blocked`` audit event and raises a 403 with the reason. Returns + ``sql`` unchanged when no policy applies. + """ + if eff is None: + return sql + try: + return policy_service.enforce_sql(eff, sql, dialect) + except PolicyViolationError as pv: + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.QUERY_BLOCKED, + payload={ + "connection_id": str(connection_id), + "question": question, + "sql": sql, + "reason": pv.reason, + "policy": True, + }, + ) + raise AppError(f"Blocked by data policy: {pv.reason}", status_code=403) from pv + + async def execute_nl_query( db: AsyncSession, connection_id: uuid.UUID, @@ -99,22 +140,45 @@ async def execute_nl_query( validation = await validator.validate(final_sql, schema_tables) if validation.status == ValidationStatus.UNSAFE: + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.QUERY_BLOCKED, + payload={ + "connection_id": str(connection_id), + "question": question, + "sql": final_sql, + "reason": "; ".join(validation.issues), + }, + ) raise AppError( f"SQL safety violation: {'; '.join(validation.issues)}", status_code=403, ) - # Step 5: Execute query + # Step 5: Execute query (enforcing the connection's data policy first) connector = await get_or_create_connector( str(connection_id), conn.connector_type, connection_string ) + eff_policy = await policy_service.resolve_effective(db, connection_id, ctx.role) + dialect = dialect_for(conn.connector_type) + pol_max_rows, pol_timeout = policy_service.effective_limits( + eff_policy, conn.max_rows, 
conn.max_query_timeout_seconds + ) + # A policy block here is a hard stop (raises 403) — it must happen outside + # the error-handler retry loop so it is never treated as a fixable error. + run_sql = await _enforce_policy_sql( + db, ctx, connection_id, eff_policy, final_sql, dialect, question=question + ) try: with start_span("execute_query", **{"db.dialect": conn.connector_type}): result = await connector.execute_query( - final_sql, - timeout_seconds=conn.max_query_timeout_seconds, - max_rows=conn.max_rows, + run_sql, + timeout_seconds=pol_timeout, + max_rows=pol_max_rows, ) except Exception as e: # Try error handler on execution errors @@ -143,11 +207,16 @@ async def execute_nl_query( if validation.status != ValidationStatus.VALID: continue + # Re-apply the policy to each corrected SQL so row filters / blocks + # can't be bypassed by an LLM rewrite. A block (403) propagates out. + run_sql = await _enforce_policy_sql( + db, ctx, connection_id, eff_policy, final_sql, dialect, question=question + ) try: result = await connector.execute_query( - final_sql, - timeout_seconds=conn.max_query_timeout_seconds, - max_rows=conn.max_rows, + run_sql, + timeout_seconds=pol_timeout, + max_rows=pol_max_rows, ) break except Exception as retry_error: @@ -172,6 +241,10 @@ async def execute_nl_query( await db.flush() raise AppError(f"Query execution failed after {retry_count} retries: {e}") + # Apply policy column masking in place so redacted PII never reaches the + # interpreter LLM, the response, or persisted history. 
+ result.rows, _ = policy_service.mask_result(eff_policy, result.columns, result.rows) + # Step 6: Interpret results summary = None highlights = [] @@ -210,6 +283,30 @@ async def execute_nl_query( db.add(execution) await db.flush() + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.QUERY_EXECUTED, + payload={ + "connection_id": str(connection_id), + "query_execution_id": str(execution.id), + "question": question, + "sql": final_sql, + "row_count": result.row_count, + }, + ) + + await cost_service.record_execution_cost( + db, + execution=execution, + ctx=ctx, + connector_type=conn.connector_type, + stats=result.stats, + final_sql=final_sql, + ) + return { "id": execution.id, "question": question, @@ -272,21 +369,41 @@ async def execute_raw_sql( # Step 1: Safety check safety_issues = check_sql_safety(sql) if safety_issues: + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.QUERY_BLOCKED, + payload={ + "connection_id": str(connection_id), + "sql": sql, + "reason": "; ".join(safety_issues), + }, + ) raise SQLSafetyError("; ".join(safety_issues)) conn = await get_connection(db, connection_id, ctx) connection_string = get_decrypted_connection_string(conn) - # Step 2: Execute query + # Step 2: Enforce the data policy, then execute. 
connector = await get_or_create_connector( str(connection_id), conn.connector_type, connection_string ) + eff_policy = await policy_service.resolve_effective(db, connection_id, ctx.role) + dialect = dialect_for(conn.connector_type) + pol_max_rows, pol_timeout = policy_service.effective_limits( + eff_policy, conn.max_rows, conn.max_query_timeout_seconds + ) + run_sql = await _enforce_policy_sql( + db, ctx, connection_id, eff_policy, sql, dialect, question=original_question + ) try: result = await connector.execute_query( - sql, - timeout_seconds=conn.max_query_timeout_seconds, - max_rows=conn.max_rows, + run_sql, + timeout_seconds=pol_timeout, + max_rows=pol_max_rows, ) except Exception as e: # Save failed execution to history @@ -305,6 +422,9 @@ async def execute_raw_sql( await db.flush() raise AppError(f"Query execution failed: {e}") from e + # Apply policy column masking in place before interpretation / persistence. + result.rows, _ = policy_service.mask_result(eff_policy, result.columns, result.rows) + # Step 3: Interpret results (LLM summary + follow-ups) summary = None highlights = [] @@ -352,6 +472,15 @@ async def execute_raw_sql( db.add(execution) await db.flush() + await cost_service.record_execution_cost( + db, + execution=execution, + ctx=ctx, + connector_type=conn.connector_type, + stats=result.stats, + final_sql=sql, + ) + return { "id": execution.id, "question": question_text, diff --git a/backend/app/services/schedule_service.py b/backend/app/services/schedule_service.py new file mode 100644 index 0000000..8e0afad --- /dev/null +++ b/backend/app/services/schedule_service.py @@ -0,0 +1,400 @@ +"""Scheduled reports — CRUD, cron math, and the run/deliver pipeline. + +A schedule runs a saved query (or dashboard) on a cron cadence and delivers the +result over a notification channel, optionally only when a threshold is met. + +Workspace-scoped like dashboards. 
Runs are executed by the scheduler under a +system-built :class:`AuthContext` for the schedule's own workspace (so the +saved query's connection auth + result cache still apply), not the request user. +""" + +from __future__ import annotations + +import logging +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.core.auth import AuthContext +from app.core.exceptions import NotFoundError, ValidationError +from app.db.models.dashboard import Dashboard +from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR +from app.db.models.saved_query import SavedQuery +from app.db.models.schedule import ( + STATUS_ERROR, + STATUS_SKIPPED, + STATUS_SUCCESS, + TARGET_SAVED_QUERY, + TARGET_TYPES, + Schedule, +) +from app.db.models.user import User +from app.notifications import get_notifier +from app.notifications.base import NotificationMessage +from app.services import audit_service, saved_query_service + +logger = logging.getLogger("querywise") + +_REPORT_ROW_LIMIT = 100 # rows rendered into a delivered report + + +# --------------------------------------------------------------------------- # +# Cron +# --------------------------------------------------------------------------- # +def compute_next_run(cron: str, after: datetime | None = None) -> datetime | None: + """Next fire time (UTC) for ``cron`` strictly after ``after`` (or now). + + Uses ``croniter`` (the optional ``[scheduling]`` extra); returns None and + logs if it isn't installed, so the rest of the feature degrades gracefully. 
+ """ + try: + from croniter import croniter + except ImportError: + logger.warning( + "croniter not installed — install the [scheduling] extra for cron " + "scheduling; schedule '%s' will not auto-run", + cron, + ) + return None + base = after or datetime.now(UTC) + if base.tzinfo is None: + base = base.replace(tzinfo=UTC) + return croniter(cron, base).get_next(datetime) + + +def validate_cron(cron: str) -> None: + """Raise ValidationError if ``cron`` is malformed (no-op without croniter).""" + try: + from croniter import croniter + except ImportError: + return + if not croniter.is_valid(cron): + raise ValidationError(f"Invalid cron expression: {cron!r}") + + +# --------------------------------------------------------------------------- # +# Threshold +# --------------------------------------------------------------------------- # +_OPS = { + ">": lambda a, b: a > b, + ">=": lambda a, b: a >= b, + "<": lambda a, b: a < b, + "<=": lambda a, b: a <= b, + "==": lambda a, b: a == b, + "!=": lambda a, b: a != b, +} + + +def evaluate_threshold(threshold: dict | None, result: dict) -> bool | None: + """Return whether ``threshold`` is met by ``result``; None if not applicable. + + ``threshold`` = ``{"metric": "row_count"|"", "op": ">", "value": N}``. + For a column metric the first row's value in that column is compared. 
+ """ + if not threshold: + return None + op = _OPS.get(threshold.get("op", ">")) + if op is None: + return None + target = threshold.get("value") + metric = threshold.get("metric", "row_count") + + if metric == "row_count": + actual: Any = result.get("row_count", 0) + else: + columns = result.get("columns") or [] + rows = result.get("rows") or [] + if metric not in columns or not rows: + return None + actual = rows[0][columns.index(metric)] + + try: + return op(actual, target) + except TypeError: + return None + + +# --------------------------------------------------------------------------- # +# Access + CRUD (mirrors dashboard_service) +# --------------------------------------------------------------------------- # +def _assert_access(schedule: Schedule, ctx: AuthContext, *, write: bool = False) -> None: + if ( + schedule.organization_id != ctx.organization_id + or schedule.workspace_id != ctx.workspace_id + ): + raise NotFoundError("Schedule", str(schedule.id)) + if write: + ctx.require_role(ROLE_EDITOR) + + +async def _load( + db: AsyncSession, schedule_id: uuid.UUID, ctx: AuthContext, *, write: bool = False +) -> Schedule: + schedule = await db.get(Schedule, schedule_id) + if schedule is None: + raise NotFoundError("Schedule", str(schedule_id)) + _assert_access(schedule, ctx, write=write) + return schedule + + +async def list_schedules(db: AsyncSession, ctx: AuthContext) -> list[Schedule]: + result = await db.execute( + select(Schedule) + .where( + Schedule.organization_id == ctx.organization_id, + Schedule.workspace_id == ctx.workspace_id, + ) + .order_by(Schedule.created_at.desc()) + ) + return list(result.scalars().all()) + + +async def get_schedule(db: AsyncSession, schedule_id: uuid.UUID, ctx: AuthContext) -> Schedule: + return await _load(db, schedule_id, ctx) + + +async def create_schedule(db: AsyncSession, ctx: AuthContext, **data: Any) -> Schedule: + ctx.require_role(ROLE_EDITOR) + target_type = data.get("target_type") + if target_type not in 
TARGET_TYPES: + raise ValidationError(f"target_type must be one of {TARGET_TYPES}") + validate_cron(data["cron"]) + await _assert_target_exists(db, ctx, target_type, data["target_id"]) + + schedule = Schedule( + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + owner_id=ctx.user_id, + **data, + ) + if schedule.enabled: + schedule.next_run_at = compute_next_run(schedule.cron) + db.add(schedule) + await db.flush() + await db.refresh(schedule) + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.SCHEDULE_CREATED, + payload={"schedule_id": str(schedule.id), "name": schedule.name, "cron": schedule.cron}, + ) + return schedule + + +async def update_schedule( + db: AsyncSession, schedule_id: uuid.UUID, ctx: AuthContext, updates: dict[str, Any] +) -> Schedule: + schedule = await _load(db, schedule_id, ctx, write=True) + if "cron" in updates and updates["cron"]: + validate_cron(updates["cron"]) + for key, value in updates.items(): + setattr(schedule, key, value) + # Recompute the next fire time when cadence or enabled state changes. 
+ schedule.next_run_at = compute_next_run(schedule.cron) if schedule.enabled else None + await db.flush() + await db.refresh(schedule) + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.SCHEDULE_UPDATED, + payload={"schedule_id": str(schedule.id), "name": schedule.name}, + ) + return schedule + + +async def delete_schedule(db: AsyncSession, schedule_id: uuid.UUID, ctx: AuthContext) -> None: + schedule = await _load(db, schedule_id, ctx, write=True) + name = schedule.name + await db.delete(schedule) + await db.flush() + await audit_service.record( + db, + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + actor_id=ctx.user_id, + event_type=audit_service.SCHEDULE_DELETED, + payload={"schedule_id": str(schedule_id), "name": name}, + ) + + +async def _assert_target_exists( + db: AsyncSession, ctx: AuthContext, target_type: str, target_id: uuid.UUID +) -> None: + model = SavedQuery if target_type == TARGET_SAVED_QUERY else Dashboard + obj = await db.get(model, target_id) + if obj is None or obj.organization_id != ctx.organization_id: + raise NotFoundError(target_type, str(target_id)) + + +# --------------------------------------------------------------------------- # +# Run + deliver +# --------------------------------------------------------------------------- # +async def context_for_schedule(db: AsyncSession, schedule: Schedule) -> AuthContext: + """Build a system AuthContext bound to the schedule's own workspace. + + The owner (if still present) is the actor so the run's snapshots + query + audit attribute correctly; falls back to the bootstrapped admin. 
+ """ + actor: User | None = None + if schedule.owner_id: + actor = await db.get(User, schedule.owner_id) + if actor is None: + from app.services import identity_service + + _, _, actor = await identity_service.bootstrap_default_identity(db) + return AuthContext( + user=actor, + organization_id=schedule.organization_id, + workspace_id=schedule.workspace_id, + role=ROLE_ADMIN, + ) + + +async def _run_target(db: AsyncSession, schedule: Schedule, ctx: AuthContext) -> dict: + """Execute the schedule's target and return a normalized result dict.""" + if schedule.target_type == TARGET_SAVED_QUERY: + saved = await db.get(SavedQuery, schedule.target_id) + if saved is None or saved.organization_id != ctx.organization_id: + raise NotFoundError("SavedQuery", str(schedule.target_id)) + result = await saved_query_service.run_saved_query( + db, saved, ctx, supplied_params=schedule.params, refresh=True + ) + return {"title": saved.name, **result, "sections": [{"name": saved.name, **result}]} + + # Dashboard: run every tile's saved query, collect a section per tile. 
+ dashboard = await db.execute( + select(Dashboard) + .where(Dashboard.id == schedule.target_id) + .options(selectinload(Dashboard.tiles)) + ) + dash = dashboard.scalar_one_or_none() + if dash is None or dash.organization_id != ctx.organization_id: + raise NotFoundError("Dashboard", str(schedule.target_id)) + + sections = [] + total_rows = 0 + for tile in dash.tiles: + saved = await db.get(SavedQuery, tile.saved_query_id) + if saved is None: + continue + res = await saved_query_service.run_saved_query( + db, saved, ctx, supplied_params=schedule.params, refresh=True + ) + total_rows += res.get("row_count", 0) + sections.append({"name": tile.title or saved.name, **res}) + return {"title": dash.name, "row_count": total_rows, "sections": sections} + + +def _render_report(schedule: Schedule, result: dict, threshold_met: bool | None) -> tuple[str, str]: + """Return (text_body, html_body) for the delivered report.""" + lines = [f"Report: {schedule.name}", f"Target: {result.get('title', '')}", ""] + html_parts = [f"

{schedule.name}

"] + if threshold_met is not None: + flag = "MET" if threshold_met else "not met" + lines.append(f"Threshold {flag}: {schedule.threshold}") + html_parts.append(f"

Threshold {flag}: {schedule.threshold}

") + + for section in result.get("sections", []): + columns = section.get("columns") or [] + rows = (section.get("rows") or [])[:_REPORT_ROW_LIMIT] + lines.append(f"\n## {section['name']} — {section.get('row_count', 0)} rows") + lines.append("\t".join(str(c) for c in columns)) + for row in rows: + lines.append("\t".join("" if v is None else str(v) for v in row)) + + html_parts.append(f"

{section['name']} — {section.get('row_count', 0)} rows

") + head = "".join(f"{c}" for c in columns) + body = "".join( + "" + "".join(f"{'' if v is None else v}" for v in row) + "" + for row in rows + ) + html_parts.append( + f"" + f"{head}{body}
" + ) + return "\n".join(lines), "\n".join(html_parts) + + +async def run_one(db: AsyncSession, schedule: Schedule, *, reschedule: bool = True) -> dict: + """Execute one schedule end-to-end: run target → threshold → deliver. + + Updates ``last_*`` and (when ``reschedule``) ``next_run_at``. Returns a + summary dict. Raises only on a target-execution error after recording it. + """ + ctx = await context_for_schedule(db, schedule) + now = datetime.now(UTC) + delivered = False + threshold_met: bool | None = None + status = STATUS_SUCCESS + error: str | None = None + + try: + result = await _run_target(db, schedule, ctx) + threshold_met = evaluate_threshold(schedule.threshold, result) + + suppress = schedule.only_on_threshold and threshold_met is False + if suppress: + status = STATUS_SKIPPED + else: + text_body, html_body = _render_report(schedule, result, threshold_met) + subject = f"[QueryWise] {schedule.name}" + if threshold_met: + subject = f"[QueryWise] ⚠ {schedule.name} — threshold met" + await get_notifier(schedule.channel).send( + NotificationMessage( + subject=subject, + text_body=text_body, + html_body=html_body, + recipients=list(schedule.recipients or []), + ) + ) + delivered = True + except Exception as e: # noqa: BLE001 — record the failure on the schedule + status = STATUS_ERROR + error = str(e) + logger.exception("Schedule '%s' run failed", schedule.id) + + schedule.last_run_at = now + schedule.last_status = status + schedule.last_error = error + if reschedule and schedule.enabled: + schedule.next_run_at = compute_next_run(schedule.cron, after=now) + await db.flush() + + await audit_service.record( + db, + organization_id=schedule.organization_id, + workspace_id=schedule.workspace_id, + actor_id=schedule.owner_id, + event_type=audit_service.REPORT_DELIVERED if delivered else audit_service.SCHEDULE_RUN, + payload={ + "schedule_id": str(schedule.id), + "name": schedule.name, + "channel": schedule.channel, + "status": status, + "delivered": delivered, + 
"threshold_met": threshold_met, + "error": error, + }, + ) + return { + "schedule_id": schedule.id, + "status": status, + "delivered": delivered, + "threshold_met": threshold_met, + "error": error, + } + + +async def run_now(db: AsyncSession, schedule_id: uuid.UUID, ctx: AuthContext) -> dict: + """Manually trigger a schedule (does not change its cron cadence).""" + schedule = await _load(db, schedule_id, ctx, write=True) + return await run_one(db, schedule, reschedule=False) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6f6dc9b..f76305d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -47,6 +47,9 @@ jobs = [ "arq>=0.26", "redis>=5.0", ] +scheduling = [ + "croniter>=2.0", +] dev = [ "pytest>=8.0", "pytest-asyncio>=0.24", diff --git a/backend/tests/test_audit_service.py b/backend/tests/test_audit_service.py new file mode 100644 index 0000000..d45feab --- /dev/null +++ b/backend/tests/test_audit_service.py @@ -0,0 +1,82 @@ +"""Unit tests for audit_service: payload shaping + the fire-and-forget guarantee. + +No live DB — a FakeSession stands in for AsyncSession (matching the other +service unit tests). The load-bearing property under test is that ``record`` +NEVER raises, even when the underlying write fails: auditing must not be able to +break the action being audited. 
+""" + +import uuid + +from app.services import audit_service + + +class _NestedCtx: + """Stand-in for the awaitable context manager returned by begin_nested().""" + + def __init__(self, session: "FakeSession", fail: bool) -> None: + self._session = session + self._fail = fail + + async def __aenter__(self): + if self._fail: + raise RuntimeError("savepoint boom") + return self + + async def __aexit__(self, *exc): + return False + + +class FakeSession: + def __init__(self, fail: bool = False) -> None: + self.added: list = [] + self._fail = fail + + def begin_nested(self): + return _NestedCtx(self, self._fail) + + def add(self, obj) -> None: + self.added.append(obj) + + +async def test_record_writes_event_with_request_id_folded_in(): + db = FakeSession() + org = uuid.uuid4() + actor = uuid.uuid4() + + await audit_service.record( + db, + organization_id=org, + actor_id=actor, + event_type=audit_service.CONNECTION_CREATED, + payload={"name": "warehouse"}, + ) + + assert len(db.added) == 1 + event = db.added[0] + assert event.organization_id == org + assert event.actor_id == actor + assert event.event_type == "connection.created" + assert event.payload["name"] == "warehouse" + + +async def test_record_never_raises_when_write_fails(): + db = FakeSession(fail=True) + + # Must return normally despite the savepoint blowing up — fire-and-forget. + await audit_service.record( + db, + organization_id=uuid.uuid4(), + event_type=audit_service.QUERY_BLOCKED, + payload={"reason": "DDL not allowed"}, + ) + + assert db.added == [] + + +def test_event_type_constants_are_unique_and_listed(): + assert len(audit_service.EVENT_TYPES) == len(set(audit_service.EVENT_TYPES)) + # Every module-level dotted constant is advertised in EVENT_TYPES. + for name, value in vars(audit_service).items(): + if name.isupper() and isinstance(value, str) and "." 
in value: + assert value in audit_service.EVENT_TYPES diff --git a/backend/tests/test_cost_service.py b/backend/tests/test_cost_service.py new file mode 100644 index 0000000..9080750 --- /dev/null +++ b/backend/tests/test_cost_service.py @@ -0,0 +1,55 @@ +"""Unit tests for cost_service pure helpers: compute_cost + table extraction. + +Aggregations are DB-backed and exercised via integration; here we cover the +pricing math and the SQL table extraction (skips without sqlglot). +""" + +import pytest + +from app.config import settings +from app.services import cost_service as svc + + +def test_compute_cost_scanned_bytes(): + # 1 TiB scanned at the default $6.25/TiB. + assert svc.compute_cost({"scanned_bytes": 1024**4}, 100) == 6.25 + + +def test_compute_cost_prefers_billed_over_scanned(): + cost = svc.compute_cost({"scanned_bytes": 2 * 1024**4, "billed_bytes": 1024**4}, 0) + assert cost == 6.25 + + +def test_compute_cost_zero_without_stats_or_time_price(): + assert svc.compute_cost({}, 5000) == 0.0 + assert svc.compute_cost(None, None) == 0.0 + + +def test_compute_cost_time_fallback(monkeypatch): + monkeypatch.setattr(settings, "cost_per_second_usd", 0.01) + # 2000 ms -> 2 s * $0.01 = $0.02, only when no warehouse stats present. + assert svc.compute_cost({}, 2000) == 0.02 + # Warehouse stats present -> time fallback is NOT added. 
+ assert svc.compute_cost({"scanned_bytes": 1024**4}, 2000) == 6.25 + + +def test_compute_cost_slot_and_dbu(monkeypatch): + monkeypatch.setattr(settings, "cost_per_slot_ms_usd", 0.001) + monkeypatch.setattr(settings, "cost_per_dbu_usd", 0.5) + assert svc.compute_cost({"slot_ms": 1000}, 0) == 1.0 + assert svc.compute_cost({"dbu": 4}, 0) == 2.0 + + +# --- table extraction (needs sqlglot) -------------------------------------- +pytest.importorskip("sqlglot") + + +def test_referenced_tables_extracts_tables(): + sql = "SELECT id FROM public.orders o JOIN users u ON u.id = o.uid" + tables = svc._referenced_tables(sql, "postgresql") + assert "public.orders" in tables + assert "users" in tables + + +def test_referenced_tables_empty_on_none(): + assert svc._referenced_tables(None, "postgresql") == [] diff --git a/backend/tests/test_notifications.py b/backend/tests/test_notifications.py new file mode 100644 index 0000000..a7a5558 --- /dev/null +++ b/backend/tests/test_notifications.py @@ -0,0 +1,48 @@ +"""Unit tests for the notification adapters: factory fallback + deliver().""" + +import app.notifications as notifications +from app.config import settings +from app.notifications import deliver, get_notifier +from app.notifications.email import EmailNotifier +from app.notifications.log import LogNotifier +from app.notifications.slack import SlackNotifier + + +def test_email_channel_falls_back_to_log_when_unconfigured(monkeypatch): + monkeypatch.setattr(settings, "smtp_host", None) + assert isinstance(get_notifier("email"), LogNotifier) + + +def test_email_channel_uses_email_when_configured(monkeypatch): + monkeypatch.setattr(settings, "smtp_host", "smtp.example.com") + assert isinstance(get_notifier("email"), EmailNotifier) + + +def test_slack_channel_falls_back_to_log_when_unconfigured(monkeypatch): + monkeypatch.setattr(settings, "slack_webhook_url", None) + assert isinstance(get_notifier("slack"), LogNotifier) + + +def 
test_slack_channel_uses_slack_when_configured(monkeypatch): + monkeypatch.setattr(settings, "slack_webhook_url", "https://hooks.slack.com/x") + assert isinstance(get_notifier("slack"), SlackNotifier) + + +def test_unknown_channel_is_log(): + assert isinstance(get_notifier("carrier-pigeon"), LogNotifier) + + +async def test_deliver_returns_true_on_success(monkeypatch): + monkeypatch.setattr(settings, "smtp_host", None) # -> LogNotifier, always succeeds + ok = await deliver("email", subject="hi", text_body="body", recipients=["a@b.c"]) + assert ok is True + + +async def test_deliver_swallows_failure_and_returns_false(monkeypatch): + class BoomNotifier: + async def send(self, message): + raise RuntimeError("smtp down") + + monkeypatch.setattr(notifications, "get_notifier", lambda channel: BoomNotifier()) + ok = await deliver("email", subject="hi", text_body="body", recipients=["a@b.c"]) + assert ok is False diff --git a/backend/tests/test_policy_service.py b/backend/tests/test_policy_service.py new file mode 100644 index 0000000..d813b97 --- /dev/null +++ b/backend/tests/test_policy_service.py @@ -0,0 +1,160 @@ +"""Unit tests for policy_service: merge, enforcement, masking. + +Pure functions only — DataPolicy is stubbed with SimpleNamespace (the merge +reads attributes, not the ORM). enforce_sql tests skip cleanly without sqlglot. 
+""" + +from types import SimpleNamespace + +import pytest + +from app.services import policy_service as svc +from app.services.policy_service import PolicyViolationError + + +def _policy(**kw): + base = dict( + enabled=True, + priority=100, + applies_to_roles=[], + max_rows=None, + max_runtime_seconds=None, + allowed_tables=[], + blocked_tables=[], + blocked_columns=[], + masked_columns=[], + row_filters={}, + name="p", + ) + base.update(kw) + return SimpleNamespace(**base) + + +# --- merge ------------------------------------------------------------------ +def test_merge_none_when_no_policies(): + assert svc.merge_policies([], "viewer") is None + + +def test_merge_skips_inapplicable_roles(): + p = _policy(applies_to_roles=["admin"], max_rows=10) + assert svc.merge_policies([p], "viewer") is None + eff = svc.merge_policies([p], "admin") + assert eff is not None and eff.max_rows == 10 + + +def test_merge_takes_minimum_limits(): + eff = svc.merge_policies( + [_policy(max_rows=100, max_runtime_seconds=30), _policy(max_rows=50)], "viewer" + ) + assert eff.max_rows == 50 + assert eff.max_runtime_seconds == 30 + + +def test_merge_allowed_tables_intersection(): + eff = svc.merge_policies( + [_policy(allowed_tables=["a", "b"]), _policy(allowed_tables=["b", "c"])], "viewer" + ) + assert eff.allowed_tables == {"b"} + + +def test_merge_unions_blocked_and_masked(): + eff = svc.merge_policies( + [_policy(blocked_columns=["ssn"], masked_columns=["email"]), + _policy(blocked_columns=["dob"])], + "viewer", + ) + assert eff.blocked_columns == {"ssn", "dob"} + assert eff.masked_columns == {"email"} + + +def test_merge_row_filters_anded(): + eff = svc.merge_policies( + [_policy(row_filters={"orders": "region = 'EU'"}), + _policy(row_filters={"orders": "amount > 0"})], + "viewer", + ) + assert "AND" in eff.row_filters["orders"] + + +# --- effective_limits ------------------------------------------------------- +def test_effective_limits_tightens_only(): + eff = 
svc.merge_policies([_policy(max_rows=10)], "viewer") + assert svc.effective_limits(eff, 1000, 30) == (10, 30) + # Policy looser than connection → connection wins. + eff2 = svc.merge_policies([_policy(max_rows=5000)], "viewer") + assert svc.effective_limits(eff2, 1000, 30) == (1000, 30) + + +def test_effective_limits_none_policy(): + assert svc.effective_limits(None, 1000, 30) == (1000, 30) + + +# --- masking ---------------------------------------------------------------- +def test_mask_result_redacts_by_output_name(): + eff = svc.merge_policies([_policy(masked_columns=["users.email"])], "viewer") + cols = ["id", "email"] + rows = [[1, "a@b.c"], [2, "d@e.f"]] + masked, names = svc.mask_result(eff, cols, rows) + assert names == ["email"] + assert masked == [[1, svc.MASK_TOKEN], [2, svc.MASK_TOKEN]] + # Original rows untouched (new list returned). + assert rows[0][1] == "a@b.c" + + +def test_mask_result_noop_when_no_match(): + eff = svc.merge_policies([_policy(masked_columns=["ssn"])], "viewer") + rows = [[1, "x"]] + masked, names = svc.mask_result(eff, ["id", "name"], rows) + assert names == [] and masked is rows + + +def test_mask_result_none_policy(): + rows = [[1]] + assert svc.mask_result(None, ["id"], rows) == (rows, []) + + +# --- enforcement (needs sqlglot) ------------------------------------------- +sqlglot = pytest.importorskip("sqlglot") + + +def test_enforce_no_rules_returns_sql_unchanged(): + eff = svc.merge_policies([_policy(max_rows=10)], "viewer") # limits only + assert svc.enforce_sql(eff, "SELECT 1", "postgres") == "SELECT 1" + + +def test_enforce_blocked_table(): + eff = svc.merge_policies([_policy(blocked_tables=["secrets"])], "viewer") + with pytest.raises(PolicyViolationError, match="secrets"): + svc.enforce_sql(eff, "SELECT * FROM secrets", "postgres") + + +def test_enforce_allowed_table_violation(): + eff = svc.merge_policies([_policy(allowed_tables=["orders"])], "viewer") + with pytest.raises(PolicyViolationError, match="users"): + 
svc.enforce_sql(eff, "SELECT * FROM users", "postgres") + # Allowed table passes through. + assert "orders" in svc.enforce_sql(eff, "SELECT id FROM orders", "postgres") + + +def test_enforce_blocked_column(): + eff = svc.merge_policies([_policy(blocked_columns=["ssn"])], "viewer") + with pytest.raises(PolicyViolationError, match="ssn"): + svc.enforce_sql(eff, "SELECT ssn FROM people", "postgres") + + +def test_enforce_injects_row_filter(): + eff = svc.merge_policies([_policy(row_filters={"orders": "region = 'EU'"})], "viewer") + out = svc.enforce_sql(eff, "SELECT id FROM orders", "postgres").lower() + assert "region" in out and "eu" in out + + +def test_enforce_fails_closed_when_sql_cannot_be_parsed(monkeypatch): + # When SQL-level rules are in force but the query can't be analyzed, the + # engine must block (fail closed), never pass the query through unfiltered. + def _boom(*a, **k): + raise ValueError("parse error") + + monkeypatch.setattr(sqlglot, "parse_one", _boom) + eff = svc.merge_policies([_policy(blocked_tables=["x"])], "viewer") + with pytest.raises(PolicyViolationError): + svc.enforce_sql(eff, "SELECT * FROM whatever", "postgres") diff --git a/backend/tests/test_schedule_service.py b/backend/tests/test_schedule_service.py new file mode 100644 index 0000000..7280f19 --- /dev/null +++ b/backend/tests/test_schedule_service.py @@ -0,0 +1,64 @@ +"""Unit tests for schedule_service: cron math + threshold evaluation. + +Pure functions only — no DB. Cron tests skip cleanly when croniter (the +[scheduling] extra) isn't installed, mirroring the lineage/sqlglot pattern. 
+""" + +from datetime import UTC, datetime + +import pytest + +from app.core.exceptions import ValidationError +from app.services import schedule_service as svc + +croniter = pytest.importorskip("croniter") + + +# --- cron ------------------------------------------------------------------- +def test_compute_next_run_advances_to_next_slot(): + base = datetime(2026, 6, 8, 10, 0, tzinfo=UTC) + nxt = svc.compute_next_run("0 9 * * *", after=base) # daily 09:00 + assert nxt == datetime(2026, 6, 9, 9, 0, tzinfo=UTC) + + +def test_compute_next_run_naive_base_treated_as_utc(): + nxt = svc.compute_next_run("*/15 * * * *", after=datetime(2026, 6, 8, 10, 7)) + assert nxt == datetime(2026, 6, 8, 10, 15, tzinfo=UTC) + + +def test_validate_cron_rejects_garbage(): + with pytest.raises(ValidationError): + svc.validate_cron("not a cron") + + +# --- threshold -------------------------------------------------------------- +def test_threshold_none_returns_none(): + assert svc.evaluate_threshold(None, {"row_count": 5}) is None + + +def test_threshold_row_count_met(): + th = {"metric": "row_count", "op": ">", "value": 10} + assert svc.evaluate_threshold(th, {"row_count": 15}) is True + assert svc.evaluate_threshold(th, {"row_count": 5}) is False + + +def test_threshold_column_value_first_row(): + th = {"metric": "amount", "op": ">=", "value": 100} + result = {"columns": ["id", "amount"], "rows": [[1, 250], [2, 50]]} + assert svc.evaluate_threshold(th, result) is True + + +def test_threshold_missing_column_returns_none(): + th = {"metric": "ghost", "op": ">", "value": 1} + assert svc.evaluate_threshold(th, {"columns": ["a"], "rows": [[1]]}) is None + + +def test_threshold_non_numeric_compare_returns_none(): + th = {"metric": "name", "op": ">", "value": 5} + result = {"columns": ["name"], "rows": [["alice"]]} + assert svc.evaluate_threshold(th, result) is None + + +def test_threshold_unknown_op_returns_none(): + th = {"metric": "row_count", "op": "≈", "value": 5} + assert 
svc.evaluate_threshold(th, {"row_count": 5}) is None diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index dbff764..4989f4e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -14,6 +14,10 @@ import { DictionaryPage } from './pages/DictionaryPage'; import { KnowledgePage } from './pages/KnowledgePage'; import { CatalogPage } from './pages/CatalogPage'; import { HistoryPage } from './pages/HistoryPage'; +import { AuditPage } from './pages/AuditPage'; +import { SchedulesPage } from './pages/SchedulesPage'; +import { PoliciesPage } from './pages/PoliciesPage'; +import { AnalyticsPage } from './pages/AnalyticsPage'; export default function App() { return ( @@ -37,6 +41,10 @@ export default function App() { } /> } /> } /> + } /> + } /> + } /> + } /> diff --git a/frontend/src/api/analyticsApi.ts b/frontend/src/api/analyticsApi.ts new file mode 100644 index 0000000..3f9c898 --- /dev/null +++ b/frontend/src/api/analyticsApi.ts @@ -0,0 +1,15 @@ +import { api } from './client'; +import type { CostByEntry, SlowestQuery, TableUsage, UsageSummary } from '../types/api'; + +export const analyticsApi = { + usage: (days: number) => + api.get('/analytics/usage', { params: { days } }).then((r) => r.data), + cost: (by: string, days: number) => + api.get('/analytics/cost', { params: { by, days } }).then((r) => r.data), + slowest: (days: number, limit = 10) => + api + .get('/analytics/slowest', { params: { days, limit } }) + .then((r) => r.data), + tables: (days: number, limit = 10) => + api.get('/analytics/tables', { params: { days, limit } }).then((r) => r.data), +}; diff --git a/frontend/src/api/auditApi.ts b/frontend/src/api/auditApi.ts new file mode 100644 index 0000000..925e2fd --- /dev/null +++ b/frontend/src/api/auditApi.ts @@ -0,0 +1,45 @@ +import { api } from './client'; +import type { AuditEvent } from '../types/api'; + +export interface AuditListParams { + event_type?: string; + actor_id?: string; + limit?: number; + offset?: number; +} + +export const 
auditApi = { + list: (params: AuditListParams) => + api + .get('/audit-events', { + params: { + event_type: params.event_type || undefined, + actor_id: params.actor_id || undefined, + limit: params.limit ?? 100, + offset: params.offset ?? 0, + }, + }) + .then((r) => r.data), + + eventTypes: () => api.get('/audit-events/event-types').then((r) => r.data), + + // Fetch the CSV through axios (carries the session cookie) and trigger a + // client-side download, rather than navigating to the raw URL. + exportCsv: async (params: Pick) => { + const res = await api.get('/audit-events/export', { + params: { + event_type: params.event_type || undefined, + actor_id: params.actor_id || undefined, + }, + responseType: 'blob', + }); + const url = URL.createObjectURL(res.data as Blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'audit_events.csv'; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); + }, +}; diff --git a/frontend/src/api/policiesApi.ts b/frontend/src/api/policiesApi.ts new file mode 100644 index 0000000..76514b2 --- /dev/null +++ b/frontend/src/api/policiesApi.ts @@ -0,0 +1,15 @@ +import { api } from './client'; +import type { DataPolicy } from '../types/api'; + +const base = (connectionId: string) => `/connections/${connectionId}/policies`; + +export const policiesApi = { + list: (connectionId: string) => + api.get(base(connectionId)).then((r) => r.data), + create: (connectionId: string, data: Partial) => + api.post(base(connectionId), data).then((r) => r.data), + update: (connectionId: string, id: string, data: Partial) => + api.put(`${base(connectionId)}/${id}`, data).then((r) => r.data), + remove: (connectionId: string, id: string) => + api.delete(`${base(connectionId)}/${id}`).then((r) => r.data), +}; diff --git a/frontend/src/api/schedulesApi.ts b/frontend/src/api/schedulesApi.ts new file mode 100644 index 0000000..fcab18c --- /dev/null +++ b/frontend/src/api/schedulesApi.ts @@ -0,0 +1,13 @@ 
+import { api } from './client'; +import type { Schedule, ScheduleRunResult } from '../types/api'; + +export const schedulesApi = { + list: () => api.get('/schedules').then((r) => r.data), + create: (data: Partial) => + api.post('/schedules', data).then((r) => r.data), + update: (id: string, data: Partial) => + api.put(`/schedules/${id}`, data).then((r) => r.data), + remove: (id: string) => api.delete(`/schedules/${id}`).then((r) => r.data), + run: (id: string) => + api.post(`/schedules/${id}/run`).then((r) => r.data), +}; diff --git a/frontend/src/components/layout/AppLayout.tsx b/frontend/src/components/layout/AppLayout.tsx index a95df3d..82d21d3 100644 --- a/frontend/src/components/layout/AppLayout.tsx +++ b/frontend/src/components/layout/AppLayout.tsx @@ -22,6 +22,10 @@ import { IconBookmark, IconLayoutDashboard, IconBook2, + IconShieldLock, + IconClockHour4, + IconLockCog, + IconChartHistogram, } from '@tabler/icons-react'; import { Outlet, useLocation, useNavigate } from 'react-router-dom'; import { EmbeddingStatusBanner } from '../common/EmbeddingStatusBanner'; @@ -38,7 +42,11 @@ const NAV_ITEMS = [ { label: 'Dictionary', path: '/dictionary', icon: IconVocabulary }, { label: 'Knowledge', path: '/knowledge', icon: IconFileText }, { label: 'Catalog', path: '/catalog', icon: IconBook2 }, + { label: 'Schedules', path: '/schedules', icon: IconClockHour4 }, { label: 'History', path: '/history', icon: IconHistory }, + { label: 'Usage & Cost', path: '/analytics', icon: IconChartHistogram, adminOnly: true }, + { label: 'Policies', path: '/policies', icon: IconLockCog, adminOnly: true }, + { label: 'Audit Log', path: '/audit', icon: IconShieldLock, adminOnly: true }, ]; const ROLE_COLOR: Record = { @@ -113,7 +121,7 @@ export function AppLayout() { - {NAV_ITEMS.map((item) => ( + {NAV_ITEMS.filter((item) => !item.adminOnly || role === 'admin').map((item) => ( analyticsApi.usage(days), + enabled, + }); +} + +export function useCostBy(by: string, days: number, enabled = 
true) { + return useQuery({ + queryKey: ['analytics', 'cost', by, days], + queryFn: () => analyticsApi.cost(by, days), + enabled, + }); +} + +export function useSlowestQueries(days: number, enabled = true) { + return useQuery({ + queryKey: ['analytics', 'slowest', days], + queryFn: () => analyticsApi.slowest(days), + enabled, + }); +} + +export function useTableUsage(days: number, enabled = true) { + return useQuery({ + queryKey: ['analytics', 'tables', days], + queryFn: () => analyticsApi.tables(days), + enabled, + }); +} diff --git a/frontend/src/hooks/useAudit.ts b/frontend/src/hooks/useAudit.ts new file mode 100644 index 0000000..1f7662e --- /dev/null +++ b/frontend/src/hooks/useAudit.ts @@ -0,0 +1,19 @@ +import { useQuery } from '@tanstack/react-query'; +import { auditApi, type AuditListParams } from '../api/auditApi'; + +export function useAuditEvents(params: AuditListParams, enabled = true) { + return useQuery({ + queryKey: ['audit', 'events', params], + queryFn: () => auditApi.list(params), + enabled, + }); +} + +export function useAuditEventTypes(enabled = true) { + return useQuery({ + queryKey: ['audit', 'event-types'], + queryFn: () => auditApi.eventTypes(), + staleTime: Infinity, + enabled, + }); +} diff --git a/frontend/src/hooks/usePolicies.ts b/frontend/src/hooks/usePolicies.ts new file mode 100644 index 0000000..5322921 --- /dev/null +++ b/frontend/src/hooks/usePolicies.ts @@ -0,0 +1,36 @@ +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { policiesApi } from '../api/policiesApi'; +import type { DataPolicy } from '../types/api'; + +export function usePolicies(connectionId: string | undefined) { + return useQuery({ + queryKey: ['policies', connectionId], + queryFn: () => policiesApi.list(connectionId!), + enabled: !!connectionId, + }); +} + +export function useCreatePolicy(connectionId: string) { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (data: Partial) => 
policiesApi.create(connectionId, data), + onSuccess: () => qc.invalidateQueries({ queryKey: ['policies', connectionId] }), + }); +} + +export function useUpdatePolicy(connectionId: string) { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ id, data }: { id: string; data: Partial }) => + policiesApi.update(connectionId, id, data), + onSuccess: () => qc.invalidateQueries({ queryKey: ['policies', connectionId] }), + }); +} + +export function useDeletePolicy(connectionId: string) { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (id: string) => policiesApi.remove(connectionId, id), + onSuccess: () => qc.invalidateQueries({ queryKey: ['policies', connectionId] }), + }); +} diff --git a/frontend/src/hooks/useSchedules.ts b/frontend/src/hooks/useSchedules.ts new file mode 100644 index 0000000..cf1c324 --- /dev/null +++ b/frontend/src/hooks/useSchedules.ts @@ -0,0 +1,42 @@ +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { schedulesApi } from '../api/schedulesApi'; +import type { Schedule } from '../types/api'; + +const KEY = ['schedules']; + +export function useSchedules() { + return useQuery({ queryKey: KEY, queryFn: schedulesApi.list }); +} + +export function useCreateSchedule() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (data: Partial) => schedulesApi.create(data), + onSuccess: () => qc.invalidateQueries({ queryKey: KEY }), + }); +} + +export function useUpdateSchedule() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ id, data }: { id: string; data: Partial }) => + schedulesApi.update(id, data), + onSuccess: () => qc.invalidateQueries({ queryKey: KEY }), + }); +} + +export function useDeleteSchedule() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (id: string) => schedulesApi.remove(id), + onSuccess: () => qc.invalidateQueries({ queryKey: KEY }), + }); +} + +export function useRunSchedule() { + const qc = 
useQueryClient(); + return useMutation({ + mutationFn: (id: string) => schedulesApi.run(id), + onSuccess: () => qc.invalidateQueries({ queryKey: KEY }), + }); +} diff --git a/frontend/src/pages/AnalyticsPage.tsx b/frontend/src/pages/AnalyticsPage.tsx new file mode 100644 index 0000000..490b4c3 --- /dev/null +++ b/frontend/src/pages/AnalyticsPage.tsx @@ -0,0 +1,208 @@ +import { useState } from 'react'; +import { + Stack, + Title, + Group, + Text, + Paper, + SimpleGrid, + Select, + SegmentedControl, + Table, + Alert, + Loader, +} from '@mantine/core'; +import { useAuth } from '../context/auth'; +import { + useCostBy, + useSlowestQueries, + useTableUsage, + useUsageSummary, +} from '../hooks/useAnalytics'; + +function formatBytes(n: number): string { + if (!n) return '0 B'; + const units = ['B', 'KB', 'MB', 'GB', 'TB']; + const i = Math.min(units.length - 1, Math.floor(Math.log(n) / Math.log(1024))); + return `${(n / 1024 ** i).toFixed(1)} ${units[i]}`; +} + +function StatCard({ label, value }: { label: string; value: string }) { + return ( + + + {label} + + + {value} + + + ); +} + +export function AnalyticsPage() { + const { role } = useAuth(); + const isAdmin = role === 'admin'; + const [days, setDays] = useState('30'); + const [by, setBy] = useState('workspace'); + const d = Number(days); + + const usage = useUsageSummary(d, isAdmin); + const cost = useCostBy(by, d, isAdmin); + const slowest = useSlowestQueries(d, isAdmin); + const tables = useTableUsage(d, isAdmin); + + if (!isAdmin) { + return ( + + Usage & Cost + + Usage analytics are restricted to workspace administrators. + + + ); + } + + const s = usage.data; + + return ( + + + Usage & Cost + { + setEventType(v); + setPage(0); + }} + /> + + + {isLoading ? ( + + + + ) : !events || events.length === 0 ? ( + No audit events for this filter. + ) : ( + <> + + + + + Time + Event + Actor + Details + + + + {events.map((e) => ( + + ))} + +
+ + + + Page {page + 1} + + + + + + + + + + )} +
+ ); +} diff --git a/frontend/src/pages/PoliciesPage.tsx b/frontend/src/pages/PoliciesPage.tsx new file mode 100644 index 0000000..1fe5570 --- /dev/null +++ b/frontend/src/pages/PoliciesPage.tsx @@ -0,0 +1,314 @@ +import { useState } from 'react'; +import { + Stack, + Title, + Group, + Text, + Badge, + Button, + Modal, + TextInput, + Textarea, + MultiSelect, + NumberInput, + Switch, + Table, + ActionIcon, + Select, + Alert, + Loader, + Code, +} from '@mantine/core'; +import { useForm } from '@mantine/form'; +import { IconPencil, IconTrash, IconPlus } from '@tabler/icons-react'; +import { useConnections } from '../hooks/useConnections'; +import { + usePolicies, + useCreatePolicy, + useUpdatePolicy, + useDeletePolicy, +} from '../hooks/usePolicies'; +import type { DataPolicy } from '../types/api'; + +// "table: condition" per line <-> { table: condition } +function filtersToText(f: Record): string { + return Object.entries(f) + .map(([k, v]) => `${k}: ${v}`) + .join('\n'); +} +function textToFilters(text: string): Record { + const out: Record = {}; + for (const line of text.split('\n')) { + const idx = line.indexOf(':'); + if (idx > 0) { + const k = line.slice(0, idx).trim(); + const v = line.slice(idx + 1).trim(); + if (k && v) out[k] = v; + } + } + return out; +} +const csv = (s: string) => s.split(',').map((x) => x.trim()).filter(Boolean); + +export function PoliciesPage() { + const { data: connections } = useConnections(); + const [connectionId, setConnectionId] = useState(null); + const connOptions = (connections ?? []).map((c) => ({ value: c.id, label: c.name })); + if (!connectionId && connOptions.length > 0) setConnectionId(connOptions[0].value); + + const { data: policies, isLoading } = usePolicies(connectionId ?? undefined); + const del = useDeletePolicy(connectionId ?? ''); + const [editing, setEditing] = useState(null); + const [open, setOpen] = useState(false); + + return ( + + + Data Policies + + + + +