diff --git a/CLAUDE.md b/CLAUDE.md index 3f4e0ed..82b9903 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -31,6 +31,8 @@ python backend/scripts/seed_ifrs9_metadata.py The sample-db contains an **IFRS 9 banking schema** with 6 tables: `counterparties`, `facilities`, `exposures`, `ecl_provisions`, `collateral`, `staging_history`. Connection string (from Docker): `postgresql://sample:sample_dev@sample-db:5432/sampledb`. +The same container hosts a second database, **`opsdb`** — a deliberately hostile operational-style schema (no FKs, `tenant_id` scoping, soft deletes, int status codes, lookup tables, business-logic views, a dead `customers_bak` table) used to exercise the semantic layer compiler. Connection string: `postgresql://sample:sample_dev@sample-db:5432/opsdb`. The container runs with `pg_stat_statements` preloaded; populate query logs with `python backend/scripts/run_ops_workload.py`. Fixtures: `backend/tests/fixtures/ops_seed.sql` (+ `ops_extensions.sql`). Init scripts only apply on a fresh volume (`docker compose down -v`). + **Auto-setup** (`AUTO_SETUP_SAMPLE_DB=true`, default): On first `docker compose up`, the backend automatically creates the connection, introspects the schema, seeds all metadata (10 glossary terms, 8 metrics, 43 dictionary entries across 12 columns, 1 knowledge document), and launches background embedding generation. Logic in `app/services/setup_service.py`, called from `main.py` lifespan hook. Idempotent — safe to restart. **Startup sequence** (in `main.py` lifespan): @@ -270,6 +272,17 @@ Makes the semantic layer discoverable and trustworthy. Two milestones; migration - **Endpoints:** `/connections/{id}/catalog/{search,facets,lineage}`, plus `/status`, `/versions`, `/versions/{v}`, and `/lineage` sub-resources on the metric/glossary/sample-query/saved-query routers. - **Heads-up:** existing rows migrate to `status='draft'`, `version=1`. The saved-query PUT routes any `status` change through the governed lifecycle (no raw status writes). sqlglot is a new optional dep — install the `[lineage]` extra (or rebuild the backend image) for lineage to populate. +## Semantic layer compiler (Slice 1) + +Attacks the cold-start problem: point QueryWise at an operational DB with an empty semantic layer and get reviewable draft objects. Migration `013`. + +- **Engine** (`app/semantic_compiler/`): self-contained package (dataclasses + pure functions, no FastAPI/ORM imports — standalone-CLI extractable). Collectors gather evidence (catalog via the connector, `pg_stats`/CHECK/enums/unique indexes, `pg_get_viewdef`, `pg_stat_statements`); `sqlmeta.py` (sqlglot, graceful degradation) extracts join pairs/aggregates/GROUP BY/WHERE; inference modules emit `Finding`s with evidence + confidence: **join inference without FKs** (naming + value-overlap probe + log co-occurrence; failed probe kills the candidate), dictionaries (enum/CHECK/lookup-table labels/most_common_vals — note pg_stats `n_distinct` is negative when it scales with rows), view→metric extraction, recurring log aggregates, dead tables, tenant scoping (call-weighted log confirmation required), PII (name + sampled value shape), fan-out warnings (1:N parent-measure double-count). The LLM pass (`app/llm/agents/semantic_annotator.py`) only names/describes — output merges onto naming fields, never structure; runs fine without a provider. Output is hard-capped per kind (`Thresholds`) — review fatigue kills draft tools. +- **Staging, not drafts** (`CompilationRun`/`CompilationFinding`, `app/services/compilation_service.py`): findings never touch semantic tables until accepted (draft metrics/glossary feed the context builder today). Accept dispatches per kind through existing creation paths (embed + lineage), landing as `status='draft'`; policies (`PII masking`, `dead tables`, row filters) are created **disabled**; fan-out guidance becomes a knowledge doc (so the prompt assembler picks it up via RAG). Runs as a background job (`semantic_compilation`) with progress (`compilation_progress.py`). +- **Rematerialization:** `introspect_and_cache` wipes cached tables (cascading to inferred relationships + dictionary entries), so accepted findings are **name-keyed** and `rematerialize_accepted` re-creates them after every introspect. `cached_relationships` gained `origin` (`fk|inferred`), `confidence`, `cardinality`, `evidence`. +- **Endpoints:** `/connections/{id}/compilation/runs` (+ `/runs/{rid}`), `/compilation/findings` (+ `/{fid}/accept`, `/{fid}/dismiss`, `/bulk`). Frontend: `pages/CompilerPage.tsx` (run button, progress, findings grouped by kind with evidence + confidence, bulk accept/dismiss). +- **Eval:** `python backend/scripts/eval_compiler_ifrs9.py` scores recovery of the IFRS 9 seed metadata with FKs hidden (`ignore_declared_fks`). Baseline: relationships 5/5 @ 100% precision, dictionary 79%/89%, glossary table-coverage 10/10; metrics need views/logs (sampledb has neither — expected 0). +- **Heads-up:** `pg_stats` is empty until ANALYZE; `pg_stat_statements` needs the extension + read rights (`pg_read_all_stats`). Every collector degrades to empty and the run records `sources_available` so the UI explains reduced confidence. Collectors are Postgres-only for now — other connectors compile catalog-only. + ## Packaging & deployability (parallel track) Production deployment artifacts under `deploy/` (+ root prod compose), separate from the dev `docker-compose.yml` / `Dockerfile`s (which stay untouched for local work). The whole **Packaging & deployability** parallel track from `planfull.md` is complete: hardened images, prod compose, Helm chart, Terraform for AWS + GCP + Azure, CI/CD (build/push/deploy), and ops (backup/restore, DR runbook, config reference). The only deferred item is the **SaaS control plane** (provisioning/billing/fleet upgrades), which is additive and build-on-demand. Overview: `deploy/README.md`. diff --git a/backend/alembic/versions/013_semantic_compiler.py b/backend/alembic/versions/013_semantic_compiler.py new file mode 100644 index 0000000..52f3410 --- /dev/null +++ b/backend/alembic/versions/013_semantic_compiler.py @@ -0,0 +1,121 @@ +"""Semantic layer compiler (Slice 1) + +Revision ID: 013 +Revises: 012 +Create Date: 2026-06-10 + +Adds the compiler staging tables and inferred-relationship support: + +* ``cached_relationships`` gains ``origin`` ('fk' | 'inferred'), ``confidence``, + ``cardinality`` and ``evidence`` so join edges inferred by the compiler can + coexist with FK-derived ones. +* ``compilation_runs`` — one row per compiler execution against a connection. +* ``compilation_findings`` — proposed semantic objects with name-keyed payloads, + evidence, and confidence. Findings become real semantic objects only on + explicit accept; accepted findings are the durable source for rematerializing + inferred relationships and dictionary entries after re-introspection (which + wipes the schema cache). +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "013" +down_revision: str = "012" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "cached_relationships", + sa.Column("origin", sa.String(20), nullable=False, server_default="fk"), + ) + op.add_column("cached_relationships", sa.Column("confidence", sa.Float, nullable=True)) + op.add_column("cached_relationships", sa.Column("cardinality", sa.String(10), nullable=True)) + op.add_column("cached_relationships", sa.Column("evidence", JSONB, nullable=True)) + + op.create_table( + "compilation_runs", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("status", sa.String(20), nullable=False, server_default="queued"), + sa.Column("options", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column("stats", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column("error", sa.Text, nullable=True), + sa.Column( + "triggered_by_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + ) + op.create_index("ix_compilation_runs_connection_id", "compilation_runs", ["connection_id"]) + + op.create_table( + "compilation_findings", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "run_id", + UUID(as_uuid=True), + sa.ForeignKey("compilation_runs.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("kind", sa.String(40), nullable=False), + sa.Column("title", sa.String(255), nullable=False), + sa.Column("payload", JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")), + sa.Column("evidence", JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("confidence", sa.Float, nullable=False, server_default=sa.text("0")), + sa.Column("status", sa.String(20), nullable=False, server_default="proposed"), + sa.Column("created_entity_type", sa.String(40), nullable=True), + sa.Column("created_entity_id", UUID(as_uuid=True), nullable=True), + sa.Column( + "reviewed_by_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False + ), + ) + op.create_index( + "ix_compilation_findings_conn_status_kind", + "compilation_findings", + ["connection_id", "status", "kind"], + ) + op.create_index("ix_compilation_findings_run_id", "compilation_findings", ["run_id"]) + + +def downgrade() -> None: + op.drop_index("ix_compilation_findings_run_id", table_name="compilation_findings") + op.drop_index("ix_compilation_findings_conn_status_kind", table_name="compilation_findings") + op.drop_table("compilation_findings") + op.drop_index("ix_compilation_runs_connection_id", table_name="compilation_runs") + op.drop_table("compilation_runs") + op.drop_column("cached_relationships", "evidence") + op.drop_column("cached_relationships", "cardinality") + op.drop_column("cached_relationships", "confidence") + op.drop_column("cached_relationships", "origin") diff --git a/backend/app/api/v1/endpoints/compilation.py b/backend/app/api/v1/endpoints/compilation.py new file mode 100644 index 0000000..3fd41cc --- /dev/null +++ b/backend/app/api/v1/endpoints/compilation.py @@ -0,0 +1,140 @@ +"""Semantic layer compiler endpoints: runs + findings review.""" + +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.deps import require_connection_read, require_connection_write +from app.api.v1.schemas.compilation import ( + BulkReviewRequest, + BulkReviewResponse, + CompilationFindingResponse, + CompilationProgressResponse, + CompilationRunCreate, + CompilationRunResponse, +) +from app.core.auth import AuthContext +from app.db.session import get_db +from app.services import compilation_progress, compilation_service + +router = APIRouter(tags=["compilation"]) + + +def _with_progress(run) -> CompilationRunResponse: + response = CompilationRunResponse.model_validate(run) + p = compilation_progress.get_progress(str(run.connection_id)) + if p is not None and run.status in ("queued", "running"): + response.progress = CompilationProgressResponse( + total=p.total, + completed=p.completed, + stage=p.stage, + status=p.status, + error=p.error, + ) + return response + + +@router.post( + "/connections/{connection_id}/compilation/runs", + response_model=CompilationRunResponse, + status_code=202, +) +async def start_compilation( + connection_id: uuid.UUID, + body: CompilationRunCreate, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + try: + run = await compilation_service.start_run(db, connection_id, ctx, options=body.model_dump()) + except ValueError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + return _with_progress(run) + + +@router.get( + "/connections/{connection_id}/compilation/runs", + response_model=list[CompilationRunResponse], +) +async def list_compilation_runs( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + runs = await compilation_service.list_runs(db, connection_id, ctx) + return [_with_progress(run) for run in runs] + + +@router.get( + "/connections/{connection_id}/compilation/runs/{run_id}", + response_model=CompilationRunResponse, +) +async def get_compilation_run( + connection_id: uuid.UUID, + run_id: uuid.UUID, + ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + run = await compilation_service.get_run(db, run_id, ctx) + return _with_progress(run) + + +@router.get( + "/connections/{connection_id}/compilation/findings", + response_model=list[CompilationFindingResponse], +) +async def list_compilation_findings( + connection_id: uuid.UUID, + status: str | None = None, + kind: str | None = None, + ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + return await compilation_service.list_findings(db, connection_id, ctx, status, kind) + + +@router.post( + "/connections/{connection_id}/compilation/findings/{finding_id}/accept", + response_model=CompilationFindingResponse, +) +async def accept_finding( + connection_id: uuid.UUID, + finding_id: uuid.UUID, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + try: + return await compilation_service.accept_finding(db, finding_id, ctx) + except ValueError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + + +@router.post( + "/connections/{connection_id}/compilation/findings/{finding_id}/dismiss", + response_model=CompilationFindingResponse, +) +async def dismiss_finding( + connection_id: uuid.UUID, + finding_id: uuid.UUID, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + try: + return await compilation_service.dismiss_finding(db, finding_id, ctx) + except ValueError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + + +@router.post( + "/connections/{connection_id}/compilation/findings/bulk", + response_model=BulkReviewResponse, +) +async def bulk_review_findings( + connection_id: uuid.UUID, + body: BulkReviewRequest, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + result = await compilation_service.bulk_review(db, body.finding_ids, body.action, ctx) + return BulkReviewResponse(**result) diff --git a/backend/app/api/v1/router.py b/backend/app/api/v1/router.py index 7b8cf2d..28cf80c 100644 --- a/backend/app/api/v1/router.py +++ b/backend/app/api/v1/router.py @@ -7,6 +7,7 @@ audit, auth, catalog, + compilation, connections, dashboards, dictionary, @@ -43,6 +44,7 @@ api_router.include_router(query_history.router) api_router.include_router(knowledge.router) api_router.include_router(catalog.router) +api_router.include_router(compilation.router) api_router.include_router(audit.router) api_router.include_router(schedules.router) api_router.include_router(policies.router) diff --git a/backend/app/api/v1/schemas/compilation.py b/backend/app/api/v1/schemas/compilation.py new file mode 100644 index 0000000..c36fffe --- /dev/null +++ b/backend/app/api/v1/schemas/compilation.py @@ -0,0 +1,62 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + + +class CompilationRunCreate(BaseModel): + llm_enabled: bool = True + min_confidence: float = Field(default=0.5, ge=0.0, le=1.0) + # Eval mode: pretend declared FKs don't exist so join inference is exercised. + ignore_declared_fks: bool = False + + +class CompilationProgressResponse(BaseModel): + total: int + completed: int + stage: str + status: str + error: str | None = None + + +class CompilationRunResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + connection_id: uuid.UUID + status: str + options: dict + stats: dict + error: str | None + started_at: datetime | None + finished_at: datetime | None + created_at: datetime + progress: CompilationProgressResponse | None = None + + +class CompilationFindingResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + run_id: uuid.UUID + connection_id: uuid.UUID + kind: str + title: str + payload: dict + evidence: list + confidence: float + status: str + created_entity_type: str | None + created_entity_id: uuid.UUID | None + reviewed_at: datetime | None + created_at: datetime + + +class BulkReviewRequest(BaseModel): + finding_ids: list[uuid.UUID] = Field(min_length=1, max_length=500) + action: str = Field(pattern="^(accept|dismiss)$") + + +class BulkReviewResponse(BaseModel): + succeeded: int + failed: int diff --git a/backend/app/db/models/__init__.py b/backend/app/db/models/__init__.py index 01c5d1e..1cf0706 100644 --- a/backend/app/db/models/__init__.py +++ b/backend/app/db/models/__init__.py @@ -2,6 +2,7 @@ from app.db.models.artifact_dependency import ArtifactDependency from app.db.models.audit_event import AuditEvent from app.db.models.chart import Chart +from app.db.models.compilation import CompilationFinding, CompilationRun from app.db.models.connection import DatabaseConnection from app.db.models.cost_attribution import CostAttribution from app.db.models.dashboard import Dashboard @@ -51,4 +52,6 @@ "Schedule", "DataPolicy", "CostAttribution", + "CompilationRun", + "CompilationFinding", ] diff --git a/backend/app/db/models/compilation.py b/backend/app/db/models/compilation.py new file mode 100644 index 0000000..dba41c2 --- /dev/null +++ b/backend/app/db/models/compilation.py @@ -0,0 +1,89 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, Float, ForeignKey, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + + +class CompilationRun(Base): + """One semantic-layer-compiler execution against a connection.""" + + __tablename__ = "compilation_runs" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + connection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + status: Mapped[str] = mapped_column( + String(20), nullable=False, default="queued" + ) # queued | running | completed | failed + # Run options: llm_enabled, min_confidence, ignore_declared_fks, ... + options: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + # Per-kind finding counts + which evidence sources were available + # (pg_stats / views / query_logs), so the UI can explain reduced confidence. + stats: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + error: Mapped[str | None] = mapped_column(Text) + triggered_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + + findings: Mapped[list["CompilationFinding"]] = relationship( + back_populates="run", cascade="all, delete-orphan" + ) + + +class CompilationFinding(Base): + """A proposed semantic object awaiting human review. + + ``payload`` is **name-keyed** (schema/table/column names, never cache ids): + the schema cache is wiped on every re-introspect, so accepted findings are + the durable source from which inferred relationships and dictionary entries + are rematerialized. A finding becomes a real semantic object only when + accepted — keeping unreviewed output away from the query-pipeline context + builder (which retrieves draft metrics/glossary today). + """ + + __tablename__ = "compilation_findings" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + run_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("compilation_runs.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + connection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ) + # relationship | metric | dictionary | glossary | data_policy_row_filter | + # data_policy_masking | dead_table | fanout_warning + kind: Mapped[str] = mapped_column(String(40), nullable=False) + title: Mapped[str] = mapped_column(String(255), nullable=False) + payload: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + # List of {source, detail} facts, e.g. + # {"source": "value_overlap", "detail": "98% of 500 sampled orders.customer_id ..."} + evidence: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + status: Mapped[str] = mapped_column( + String(20), nullable=False, default="proposed" + ) # proposed | accepted | dismissed + created_entity_type: Mapped[str | None] = mapped_column(String(40)) + created_entity_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True)) + reviewed_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + + run: Mapped["CompilationRun"] = relationship(back_populates="findings") diff --git a/backend/app/db/models/schema_cache.py b/backend/app/db/models/schema_cache.py index 0247790..79e11db 100644 --- a/backend/app/db/models/schema_cache.py +++ b/backend/app/db/models/schema_cache.py @@ -2,8 +2,8 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship from app.config import settings @@ -15,7 +15,9 @@ class CachedTable(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) connection_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, ) schema_name: Mapped[str] = mapped_column(String(255), nullable=False) table_name: Mapped[str] = mapped_column(String(255), nullable=False) @@ -23,9 +25,7 @@ class CachedTable(Base): comment: Mapped[str | None] = mapped_column(Text) row_count_estimate: Mapped[int | None] = mapped_column(Integer) description_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) @@ -66,9 +66,7 @@ class CachedColumn(Base): comment: Mapped[str | None] = mapped_column(Text) ordinal_position: Mapped[int] = mapped_column(Integer, nullable=False) description_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) # Relationships table: Mapped["CachedTable"] = relationship(back_populates="columns") @@ -82,9 +80,17 @@ class CachedRelationship(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) connection_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, ) constraint_name: Mapped[str | None] = mapped_column(String(255)) + # 'fk' = declared foreign key (re-derived on every introspect); + # 'inferred' = compiler-accepted edge (rematerialized from accepted findings). + origin: Mapped[str] = mapped_column(String(20), nullable=False, default="fk") + confidence: Mapped[float | None] = mapped_column(Float) + cardinality: Mapped[str | None] = mapped_column(String(10)) # '1:1'|'N:1'|'1:N'|'N:N' + evidence = mapped_column(JSONB, nullable=True) source_table_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("cached_tables.id", ondelete="CASCADE"), nullable=False ) @@ -93,9 +99,7 @@ class CachedRelationship(Base): UUID(as_uuid=True), ForeignKey("cached_tables.id", ondelete="CASCADE"), nullable=False ) target_column: Mapped[str] = mapped_column(String(255), nullable=False) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) # Relationships source_table: Mapped["CachedTable"] = relationship( diff --git a/backend/app/jobs/tasks.py b/backend/app/jobs/tasks.py index ca08a55..e824c5f 100644 --- a/backend/app/jobs/tasks.py +++ b/backend/app/jobs/tasks.py @@ -6,8 +6,11 @@ module makes that registration explicit and order-independent. """ -# Registers "generate_embeddings". -import app.services.setup_service # noqa: F401 - # Registers "run_schedule". import app.jobs.scheduler # noqa: F401 + +# Registers "semantic_compilation". +import app.services.compilation_service # noqa: F401 + +# Registers "generate_embeddings". +import app.services.setup_service # noqa: F401 diff --git a/backend/app/llm/agents/semantic_annotator.py b/backend/app/llm/agents/semantic_annotator.py new file mode 100644 index 0000000..5ac9bf0 --- /dev/null +++ b/backend/app/llm/agents/semantic_annotator.py @@ -0,0 +1,105 @@ +"""Agent: Semantic Annotator — names/describes compiler findings. + +The semantic layer compiler produces findings deterministically; this agent's +only job is to replace machine-generated names and descriptions with +human-quality ones. Its output is merged onto the findings' *naming fields +only* — it cannot add or alter facts. Annotation is optional: any failure +leaves the deterministic fallback names in place. +""" + +import json +import logging +import re + +from app.llm.base_provider import BaseLLMProvider, LLMConfig, LLMMessage +from app.llm.prompts.annotator_prompts import KIND_FIELDS, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE +from app.llm.utils import repair_json + +logger = logging.getLogger(__name__) + +_BATCH_SIZE = 20 +_MAX_FIELD_LEN = 2000 +_IDENTIFIER_RE = re.compile(r"[^a-z0-9_]+") + + +def _sanitize_identifier(value: str) -> str: + cleaned = _IDENTIFIER_RE.sub("_", value.strip().lower()).strip("_") + return cleaned[:255] or "unnamed" + + +class SemanticAnnotatorAgent: + def __init__(self, provider: BaseLLMProvider, config: LLMConfig): + self.provider = provider + self.config = config + + async def annotate(self, kind: str, findings: list[dict]) -> dict[int, dict[str, str]]: + """Return {finding_index: {field: value}} for one kind of finding. + + ``findings`` are dicts with at least ``title``, ``payload``, ``evidence``. + Per-batch failures are swallowed — callers always get a (possibly + partial or empty) mapping. + """ + allowed_fields = KIND_FIELDS.get(kind) + if not allowed_fields or not findings: + return {} + + annotations: dict[int, dict[str, str]] = {} + for start in range(0, len(findings), _BATCH_SIZE): + batch = findings[start : start + _BATCH_SIZE] + try: + annotations.update( + await self._annotate_batch(kind, allowed_fields, batch, offset=start) + ) + except Exception as exc: + logger.warning("annotation batch failed for kind=%s: %s", kind, exc) + return annotations + + async def _annotate_batch( + self, kind: str, allowed_fields: list[str], batch: list[dict], offset: int + ) -> dict[int, dict[str, str]]: + findings_json = json.dumps( + [ + { + "index": offset + i, + "title": f.get("title"), + "payload": f.get("payload"), + "evidence": f.get("evidence"), + } + for i, f in enumerate(batch) + ], + default=str, + indent=2, + ) + fields_doc = ", ".join(f'"{name}": "..."' for name in allowed_fields) + messages = [ + LLMMessage(role="system", content=SYSTEM_PROMPT), + LLMMessage( + role="user", + content=USER_PROMPT_TEMPLATE.format( + kind=kind, fields_doc=fields_doc, findings_json=findings_json + ), + ), + ] + response = await self.provider.complete(messages, self.config) + parsed = json.loads(repair_json(response.content)) + + valid_indices = {offset + i for i in range(len(batch))} + result: dict[int, dict[str, str]] = {} + for item in parsed.get("annotations", []): + if not isinstance(item, dict): + continue + index = item.get("index") + if index not in valid_indices: + continue # the model referenced a finding we didn't send + fields: dict[str, str] = {} + for field in allowed_fields: + value = item.get(field) + if not isinstance(value, str) or not value.strip(): + continue + value = value.strip()[:_MAX_FIELD_LEN] + if field == "metric_name": + value = _sanitize_identifier(value) + fields[field] = value + if fields: + result[index] = fields + return result diff --git a/backend/app/llm/prompts/annotator_prompts.py b/backend/app/llm/prompts/annotator_prompts.py new file mode 100644 index 0000000..f7e20b6 --- /dev/null +++ b/backend/app/llm/prompts/annotator_prompts.py @@ -0,0 +1,34 @@ +"""Prompts for the semantic annotator agent (semantic layer compiler).""" + +SYSTEM_PROMPT = """You are a data analyst writing display names and descriptions for \ +semantic-layer objects that were derived from verified database evidence. + +STRICT RULES: +- You only NAME and DESCRIBE. Every fact (tables, columns, joins, values, SQL \ +expressions) was verified deterministically — you may not add, remove, or alter any of it. +- Do not invent tables, columns, joins, values, filters, or SQL. +- Descriptions must be grounded in the provided evidence only. If the evidence is \ +thin, write a short, plain description rather than speculating. +- Respond with JSON only, no markdown fences, matching the schema in the user message.""" + +USER_PROMPT_TEMPLATE = """Below are {kind} findings inferred from a database. For each, propose \ +better human-facing naming fields. Return JSON of the form: + +{{"annotations": [{{"index": , {fields_doc}}}]}} + +Only the listed fields are allowed. Only reference indices that appear below. + +Findings: +{findings_json}""" + +# Per-kind: which naming fields the LLM may produce. +KIND_FIELDS: dict[str, list[str]] = { + "metric": ["metric_name", "display_name", "description"], + "glossary": ["term", "definition"], + "relationship": ["description"], + "dictionary": ["description"], + "data_policy_row_filter": ["description"], + "data_policy_masking": ["description"], + "dead_table": ["description"], + "fanout_warning": ["description"], +} diff --git a/backend/app/semantic_compiler/__init__.py b/backend/app/semantic_compiler/__init__.py new file mode 100644 index 0000000..b68ce34 --- /dev/null +++ b/backend/app/semantic_compiler/__init__.py @@ -0,0 +1,25 @@ +"""Semantic layer compiler. + +Introspects an operational database (schema catalog, column statistics, view +definitions, query logs) and produces draft semantic-layer findings — inferred +join paths, dictionary entries, metric candidates, glossary entities, and +refusal boundaries (PII masking, tenant row filters, dead tables, fan-out +warnings) — each carrying evidence and a confidence score for human review. + +Design rules: +* Deterministic inference produces all facts; the LLM pass only names and + describes them (see ``app/llm/agents/semantic_annotator.py``). +* This package is self-contained: dataclasses + pure functions, no FastAPI or + ORM imports. Database access goes through the narrow ``Prober`` protocol so + a standalone CLI can be extracted later. +""" + +from app.semantic_compiler.engine import run_compiler +from app.semantic_compiler.types import ( + CompilerInput, + Finding, + Prober, + Thresholds, +) + +__all__ = ["CompilerInput", "Finding", "Prober", "Thresholds", "run_compiler"] diff --git a/backend/app/semantic_compiler/collectors/__init__.py b/backend/app/semantic_compiler/collectors/__init__.py new file mode 100644 index 0000000..3a3df17 --- /dev/null +++ b/backend/app/semantic_compiler/collectors/__init__.py @@ -0,0 +1,11 @@ +from app.semantic_compiler.collectors.catalog import build_table_profiles +from app.semantic_compiler.collectors.pg_stats import collect_pg_stats +from app.semantic_compiler.collectors.query_logs import collect_query_logs +from app.semantic_compiler.collectors.views import collect_view_definitions + +__all__ = [ + "build_table_profiles", + "collect_pg_stats", + "collect_query_logs", + "collect_view_definitions", +] diff --git a/backend/app/semantic_compiler/collectors/catalog.py b/backend/app/semantic_compiler/collectors/catalog.py new file mode 100644 index 0000000..6a760d3 --- /dev/null +++ b/backend/app/semantic_compiler/collectors/catalog.py @@ -0,0 +1,47 @@ +"""Catalog collector: maps connector introspection output to TableProfiles. + +Takes duck-typed objects shaped like ``app.connectors.base_connector.TableInfo`` +(schema_name, table_name, table_type, comment, row_count_estimate, columns[], +foreign_keys[]) so this package stays free of app imports. +""" + +from typing import Any + +from app.semantic_compiler.types import ColumnProfile, DeclaredFK, TableProfile + + +def build_table_profiles(table_infos: list[Any]) -> list[TableProfile]: + profiles: list[TableProfile] = [] + for info in table_infos: + columns = [ + ColumnProfile( + name=col.name, + data_type=col.data_type, + is_nullable=col.is_nullable, + is_primary_key=col.is_primary_key, + comment=col.comment, + ordinal_position=col.ordinal_position, + ) + for col in info.columns + ] + fks = [ + DeclaredFK( + source_column=fk.column_name, + target_schema=fk.referred_schema, + target_table=fk.referred_table, + target_column=fk.referred_column, + ) + for fk in info.foreign_keys + ] + profiles.append( + TableProfile( + schema_name=info.schema_name, + table_name=info.table_name, + table_type=info.table_type, + comment=info.comment, + row_count_estimate=info.row_count_estimate, + columns=columns, + declared_fks=fks, + ) + ) + return profiles diff --git a/backend/app/semantic_compiler/collectors/pg_stats.py b/backend/app/semantic_compiler/collectors/pg_stats.py new file mode 100644 index 0000000..f58cf82 --- /dev/null +++ b/backend/app/semantic_compiler/collectors/pg_stats.py @@ -0,0 +1,127 @@ +"""Statistics collector: enriches ColumnProfiles from Postgres catalogs. + +Reads ``pg_stats`` (null_frac, n_distinct, most_common_vals), CHECK ``IN``-list +constraints, enum types, and single-column unique indexes. Every sub-query is +best-effort: on permission errors the profiles simply stay un-enriched and the +collector reports the source as unavailable. + +NOTE: ``pg_stats`` is empty until the target DB has been ANALYZEd. +""" + +import logging +import re + +from app.semantic_compiler.types import Prober, TableProfile + +logger = logging.getLogger(__name__) + +_PG_STATS_SQL = """ +SELECT schemaname, tablename, attname, null_frac, n_distinct, + most_common_vals::text::text[] AS mcv, + most_common_freqs::float[] AS mcf +FROM pg_stats +WHERE schemaname NOT IN ('pg_catalog', 'information_schema') +""" + +_CHECK_SQL = """ +SELECT n.nspname AS schema, c.relname AS table, pg_get_constraintdef(con.oid) AS def +FROM pg_constraint con +JOIN pg_class c ON con.conrelid = c.oid +JOIN pg_namespace n ON c.relnamespace = n.oid +WHERE con.contype = 'c' AND n.nspname NOT IN ('pg_catalog', 'information_schema') +""" + +_ENUM_SQL = """ +SELECT n.nspname AS schema, c.relname AS table, a.attname AS column, + e.enumlabel AS label, e.enumsortorder AS sort +FROM pg_attribute a +JOIN pg_class c ON a.attrelid = c.oid +JOIN pg_namespace n ON c.relnamespace = n.oid +JOIN pg_type t ON a.atttypid = t.oid +JOIN pg_enum e ON e.enumtypid = t.oid +WHERE c.relkind IN ('r', 'p') AND n.nspname NOT IN ('pg_catalog', 'information_schema') +ORDER BY n.nspname, c.relname, a.attname, e.enumsortorder +""" + +_UNIQUE_SQL = """ +SELECT n.nspname AS schema, c.relname AS table, a.attname AS column +FROM pg_index i +JOIN pg_class c ON c.oid = i.indrelid +JOIN pg_namespace n ON c.relnamespace = n.oid +JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = i.indkey[0] +WHERE i.indisunique AND i.indnkeyatts = 1 + AND n.nspname NOT IN ('pg_catalog', 'information_schema') +""" + +# pg_get_constraintdef renders IN-lists as: +# CHECK ((stage = ANY (ARRAY[1, 2, 3]))) -- int +# CHECK (((segment)::text = ANY ((ARRAY['retail'::charactervarying, ...])::text[]))) +_CHECK_IN_RE = re.compile( + r"\(?\"?(\w+)\"?\)?(?:::[\w\s]+)?\s*=\s*ANY\s*\(+\s*ARRAY\[(.*?)\]", re.IGNORECASE +) + + +def _clean_literal(raw: str) -> str: + """'active'::text -> active ; 3 -> 3""" + value = raw.strip() + value = re.sub(r"::[\w\s\"]+$", "", value).strip() + if value.startswith("'") and value.endswith("'"): + value = value[1:-1] + return value + + +async def collect_pg_stats(prober: Prober, tables: list[TableProfile]) -> bool: + """Enrich `tables` in place. Returns True if pg_stats was readable.""" + by_name: dict[tuple[str, str], TableProfile] = { + (t.schema_name, t.table_name): t for t in tables + } + + available = False + try: + rows = await prober.query(_PG_STATS_SQL, max_rows=50000) + available = True + for row in rows: + table = by_name.get((row["schemaname"], row["tablename"])) + col = table.column(row["attname"]) if table else None + if col is None: + continue + col.null_frac = row["null_frac"] + col.n_distinct = row["n_distinct"] + col.most_common_vals = list(row["mcv"]) if row["mcv"] else None + col.most_common_freqs = list(row["mcf"]) if row["mcf"] else None + except Exception as exc: + logger.debug("pg_stats unavailable: %s", exc) + + try: + for row in await prober.query(_CHECK_SQL, max_rows=10000): + table = by_name.get((row["schema"], row["table"])) + if table is None: + continue + match = _CHECK_IN_RE.search(row["def"] or "") + if not match: + continue + col = table.column(match.group(1)) + if col is not None: + col.check_in_values = [_clean_literal(v) for v in match.group(2).split(",")] + except Exception as exc: + logger.debug("pg_constraint unavailable: %s", exc) + + try: + for row in await prober.query(_ENUM_SQL, max_rows=10000): + table = by_name.get((row["schema"], row["table"])) + col = table.column(row["column"]) if table else None + if col is not None: + col.enum_values = (col.enum_values or []) + [row["label"]] + except Exception as exc: + logger.debug("pg_enum unavailable: %s", exc) + + try: + for row in await prober.query(_UNIQUE_SQL, max_rows=10000): + table = by_name.get((row["schema"], row["table"])) + col = table.column(row["column"]) if table else None + if col is not None: + col.is_unique = True + except Exception as exc: + logger.debug("pg_index unavailable: %s", exc) + + return available diff --git a/backend/app/semantic_compiler/collectors/query_logs.py b/backend/app/semantic_compiler/collectors/query_logs.py new file mode 100644 index 0000000..be99bbd --- /dev/null +++ b/backend/app/semantic_compiler/collectors/query_logs.py @@ -0,0 +1,41 @@ +"""Query-log collector: pg_stat_statements, the revealed-preference semantic layer. + +Requires the extension to be installed in the target DB and the connecting role +to be allowed to read other sessions' statements (``pg_read_all_stats`` or +superuser); degrades to an empty list otherwise. +""" + +import logging + +from app.semantic_compiler.types import LoggedQuery, Prober + +logger = logging.getLogger(__name__) + +_LOGS_SQL = """ +SELECT query, calls, total_exec_time +FROM pg_stat_statements +WHERE dbid = (SELECT oid FROM pg_database WHERE datname = current_database()) + AND query NOT ILIKE '%pg_catalog%' + AND query NOT ILIKE '%pg_stat_statements%' + AND query NOT ILIKE '%information_schema%' +ORDER BY calls DESC +LIMIT 500 +""" + + +async def collect_query_logs(prober: Prober) -> tuple[list[LoggedQuery], bool]: + try: + rows = await prober.query(_LOGS_SQL, max_rows=500) + except Exception as exc: + logger.debug("pg_stat_statements unavailable: %s", exc) + return [], False + queries = [ + LoggedQuery( + sql=r["query"], + calls=int(r["calls"] or 1), + total_time_ms=float(r["total_exec_time"] or 0.0), + ) + for r in rows + if r.get("query") + ] + return queries, True diff --git a/backend/app/semantic_compiler/collectors/views.py b/backend/app/semantic_compiler/collectors/views.py new file mode 100644 index 0000000..c1c3938 --- /dev/null +++ b/backend/app/semantic_compiler/collectors/views.py @@ -0,0 +1,30 @@ +"""View-definition collector. Handwritten views are crystallized business logic.""" + +import logging + +from app.semantic_compiler.types import Prober, ViewDef + +logger = logging.getLogger(__name__) + +_VIEWS_SQL = """ +SELECT schemaname, viewname, definition +FROM pg_views +WHERE schemaname NOT IN ('pg_catalog', 'information_schema') + AND viewname NOT LIKE 'pg\\_%' +""" + + +async def collect_view_definitions(prober: Prober) -> tuple[list[ViewDef], bool]: + try: + rows = await prober.query(_VIEWS_SQL, max_rows=2000) + except Exception as exc: + logger.debug("pg_views unavailable: %s", exc) + return [], False + return ( + [ + ViewDef(schema_name=r["schemaname"], view_name=r["viewname"], sql=r["definition"]) + for r in rows + if r.get("definition") + ], + True, + ) diff --git a/backend/app/semantic_compiler/engine.py b/backend/app/semantic_compiler/engine.py new file mode 100644 index 0000000..0239671 --- /dev/null +++ b/backend/app/semantic_compiler/engine.py @@ -0,0 +1,102 @@ +"""Compiler engine: deterministic inference over collected evidence. + +The service layer builds a ``CompilerInput`` via the collectors, then calls +``run_compiler``. The LLM annotation pass happens *after* this returns — the +engine itself never calls a model. +""" + +import logging + +from app.semantic_compiler.inference import ( + infer_dead_tables, + infer_dictionaries, + infer_fanout_warnings, + infer_glossary_entities, + infer_joins, + infer_log_metrics, + infer_pii, + infer_tenant_scope, + infer_view_metrics, +) +from app.semantic_compiler.sqlmeta import SqlAnalysis, analyze +from app.semantic_compiler.types import ( + CompilerInput, + Finding, + Prober, + Thresholds, + ViewDef, +) + +logger = logging.getLogger(__name__) + + +def _apply_thresholds(findings: list[Finding], thresholds: Thresholds) -> list[Finding]: + kept: list[Finding] = [] + by_kind: dict[str, list[Finding]] = {} + for finding in findings: + if finding.confidence >= thresholds.min_confidence: + by_kind.setdefault(finding.kind, []).append(finding) + for kind, group in by_kind.items(): + group.sort(key=lambda f: f.confidence, reverse=True) + cap = thresholds.max_per_kind.get(kind) + kept.extend(group[:cap] if cap else group) + return kept + + +async def run_compiler( + inp: CompilerInput, prober: Prober, thresholds: Thresholds | None = None +) -> list[Finding]: + thresholds = thresholds or Thresholds() + ignore_declared_fks = bool(inp.options.get("ignore_declared_fks")) + + # Parse views and logged queries once; weight = call count for logs. + view_analyses: list[tuple[ViewDef, SqlAnalysis]] = [] + for view in inp.views: + analysis = analyze(view.sql, dialect=inp.dialect) + if analysis is not None: + view_analyses.append((view, analysis)) + + log_analyses: list[tuple[SqlAnalysis, int]] = [] + for logged in inp.logged_queries: + analysis = analyze(logged.sql, dialect=inp.dialect) + if analysis is not None: + log_analyses.append((analysis, logged.calls)) + + # Combined join evidence: views count once, logs by call weight. + combined = [(a, 1, f"view {v.view_name}") for v, a in view_analyses] + [ + (a, calls, "query log") for a, calls in log_analyses + ] + + relationships = await infer_joins( + inp.tables, combined, prober, thresholds, ignore_declared_fks=ignore_declared_fks + ) + + dictionaries = await infer_dictionaries(inp.tables, relationships, prober) + + view_names_used = { + t for a, _ in log_analyses for t in a.tables + } # views queried in the workload + metrics = infer_view_metrics(view_analyses, used_views=view_names_used) + metrics += infer_log_metrics(log_analyses, metrics) + + logs_available = inp.sources_available.get("query_logs", False) + referenced: set[str] = set() + for analysis, _ in log_analyses: + referenced.update(analysis.tables) + for _, analysis in view_analyses: + referenced.update(analysis.tables) + dead = infer_dead_tables(inp.tables, referenced, logs_available) + dead_names = {f.payload["table"].lower() for f in dead} + + tenant = infer_tenant_scope(inp.tables, log_analyses) + pii = await infer_pii(inp.tables, prober) + fanout = infer_fanout_warnings(inp.tables, relationships) + glossary = infer_glossary_entities(inp.tables, relationships, dead_names) + + all_findings = relationships + dictionaries + metrics + dead + tenant + pii + fanout + glossary + kept = _apply_thresholds(all_findings, thresholds) + logger.info("compiler produced %d findings (%d above threshold)", len(all_findings), len(kept)) + return kept + + +__all__ = ["run_compiler"] diff --git a/backend/app/semantic_compiler/inference/__init__.py b/backend/app/semantic_compiler/inference/__init__.py new file mode 100644 index 0000000..8eb23ad --- /dev/null +++ b/backend/app/semantic_compiler/inference/__init__.py @@ -0,0 +1,21 @@ +from app.semantic_compiler.inference.dead_tables import infer_dead_tables +from app.semantic_compiler.inference.dictionaries import infer_dictionaries +from app.semantic_compiler.inference.fanout import infer_fanout_warnings +from app.semantic_compiler.inference.glossary import infer_glossary_entities +from app.semantic_compiler.inference.joins import infer_joins +from app.semantic_compiler.inference.log_metrics import infer_log_metrics +from app.semantic_compiler.inference.pii import infer_pii +from app.semantic_compiler.inference.tenant_scope import infer_tenant_scope +from app.semantic_compiler.inference.view_metrics import infer_view_metrics + +__all__ = [ + "infer_dead_tables", + "infer_dictionaries", + "infer_fanout_warnings", + "infer_glossary_entities", + "infer_joins", + "infer_log_metrics", + "infer_pii", + "infer_tenant_scope", + "infer_view_metrics", +] diff --git a/backend/app/semantic_compiler/inference/dead_tables.py b/backend/app/semantic_compiler/inference/dead_tables.py new file mode 100644 index 0000000..7b98167 --- /dev/null +++ b/backend/app/semantic_compiler/inference/dead_tables.py @@ -0,0 +1,43 @@ +"""Dead-table detection: candidates for blocking / retrieval de-boosting.""" + +import re + +from app.semantic_compiler.types import KIND_DEAD_TABLE, Evidence, Finding, TableProfile + +_DEAD_SUFFIX_RE = re.compile(r"_(bak|backup|old|tmp|temp|archive|deprecated)\d*$", re.IGNORECASE) + + +def infer_dead_tables( + tables: list[TableProfile], + referenced_tables: set[str], + logs_available: bool, +) -> list[Finding]: + """`referenced_tables` = tables seen in logged queries or view definitions.""" + findings: list[Finding] = [] + for table in tables: + if table.table_type != "table": + continue + score = 0.0 + evidence: list[Evidence] = [] + if _DEAD_SUFFIX_RE.search(table.table_name): + score += 0.6 + evidence.append(Evidence("naming", "name suffix suggests a backup/old copy")) + if (table.row_count_estimate or 0) <= 0: + score += 0.35 + evidence.append(Evidence("pg_stats", "row count estimate is zero")) + if logs_available and table.table_name.lower() not in referenced_tables: + score += 0.2 + evidence.append(Evidence("query_logs", "never referenced in logged queries or views")) + + if score < 0.5 or not evidence: + continue + findings.append( + Finding( + kind=KIND_DEAD_TABLE, + title=f"Likely dead table: {table.table_name}", + payload={"schema": table.schema_name, "table": table.table_name}, + evidence=evidence, + confidence=min(score, 0.95), + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/dictionaries.py b/backend/app/semantic_compiler/inference/dictionaries.py new file mode 100644 index 0000000..cc0643e --- /dev/null +++ b/backend/app/semantic_compiler/inference/dictionaries.py @@ -0,0 +1,178 @@ +"""Dictionary inference: enumerable column values with display labels. + +Evidence, strongest first: enum types, CHECK IN-lists, lookup tables reached +through an inferred/declared join (labels probed from the lookup), and +``pg_stats.most_common_vals`` on low-cardinality columns. +""" + +import logging + +from app.semantic_compiler.types import ( + KIND_DICTIONARY, + KIND_RELATIONSHIP, + ColumnProfile, + Evidence, + Finding, + Prober, + TableProfile, +) + +logger = logging.getLogger(__name__) + +_TEXTY_TYPES = ("text", "character varying", "varchar", "character", "char") +_INTY_TYPES = ("integer", "bigint", "smallint", "int") +_MAX_CARDINALITY = 25 +_MAX_VALUE_LEN = 30 +_LOOKUP_MAX_ROWS = 100 +_LABEL_COLUMN_NAMES = ("label", "name", "description", "title", "display_name") + + +def _quote(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + +def _entry_list(values: list[str]) -> list[dict]: + return [ + {"raw_value": str(v), "display_value": str(v), "description": None, "sort_order": i + 1} + for i, v in enumerate(values) + ] + + +def _effective_n_distinct(col: ColumnProfile, row_count: int | None) -> float | None: + """pg_stats stores n_distinct as a NEGATIVE fraction of the row count when + the planner thinks distinct values scale with table size.""" + if col.n_distinct is None: + return None + if col.n_distinct >= 0: + return col.n_distinct + return -col.n_distinct * (row_count or 0) + + +def _is_enumerable_text(col: ColumnProfile, row_count: int | None) -> bool: + if not any(col.data_type.lower().startswith(t) for t in _TEXTY_TYPES): + return False + n_distinct = _effective_n_distinct(col, row_count) + if n_distinct is None or not (2 <= n_distinct <= _MAX_CARDINALITY): + return False + if not col.most_common_vals: + return False + return all(len(str(v)) <= _MAX_VALUE_LEN for v in col.most_common_vals) + + +def _is_coded_int(col: ColumnProfile, row_count: int | None) -> bool: + if not any(col.data_type.lower().startswith(t) for t in _INTY_TYPES): + return False + n_distinct = _effective_n_distinct(col, row_count) + return n_distinct is not None and 2 <= n_distinct <= _MAX_CARDINALITY + + +async def _lookup_entries( + prober: Prober, lookup: TableProfile, key_column: str +) -> list[dict] | None: + """Probe a small id/code/label table for raw->display mappings.""" + label_col = next((c for c in lookup.columns if c.name.lower() in _LABEL_COLUMN_NAMES), None) + code_col = next((c for c in lookup.columns if c.name.lower() == "code"), None) + if label_col is None and code_col is None: + return None + display = label_col or code_col + sql = ( + f"SELECT {_quote(key_column)} AS raw, {_quote(display.name)} AS display " + f"FROM {_quote(lookup.schema_name)}.{_quote(lookup.table_name)} " + f"ORDER BY {_quote(key_column)} LIMIT {_LOOKUP_MAX_ROWS}" + ) + rows = await prober.query(sql, max_rows=_LOOKUP_MAX_ROWS) + if not rows: + return None + return [ + { + "raw_value": str(r["raw"]), + "display_value": str(r["display"]), + "description": None, + "sort_order": i + 1, + } + for i, r in enumerate(rows) + ] + + +async def infer_dictionaries( + tables: list[TableProfile], + relationship_findings: list[Finding], + prober: Prober, +) -> list[Finding]: + # (source_table, source_column) -> (target table name, confidence) + lookup_edges: dict[tuple[str, str], tuple[str, float]] = {} + for f in relationship_findings: + if f.kind != KIND_RELATIONSHIP: + continue + key = (f.payload["source_table"], f.payload["source_column"]) + lookup_edges[key] = (f.payload["target_table"], f.confidence) + for table in tables: + for fk in table.declared_fks: + lookup_edges[(table.table_name, fk.source_column)] = (fk.target_table, 1.0) + + by_name = {t.table_name.lower(): t for t in tables} + findings: list[Finding] = [] + + for table in tables: + if table.table_type != "table": + continue + for col in table.columns: + if col.is_primary_key or col.name.lower().endswith("_id"): + continue + + entries: list[dict] | None = None + evidence: Evidence | None = None + confidence = 0.0 + + if col.enum_values: + entries = _entry_list(col.enum_values) + evidence = Evidence("constraint", f"enum type with {len(entries)} labels") + confidence = 0.9 + elif col.check_in_values: + entries = _entry_list(col.check_in_values) + evidence = Evidence("constraint", f"CHECK constraint allows {len(entries)} values") + confidence = 0.85 + elif ( + _is_coded_int(col, table.row_count_estimate) + and (table.table_name, col.name) in lookup_edges + ): + target_name, rel_conf = lookup_edges[(table.table_name, col.name)] + lookup = by_name.get(target_name.lower()) + if lookup is not None and (lookup.row_count_estimate or 0) <= _LOOKUP_MAX_ROWS: + try: + entries = await _lookup_entries(prober, lookup, "id") + except Exception as exc: + logger.debug("lookup probe failed for %s: %s", target_name, exc) + entries = None + if entries: + evidence = Evidence( + "value_overlap", + f"labels resolved from lookup table {target_name}", + ) + confidence = 0.85 if rel_conf >= 0.7 else 0.65 + elif _is_enumerable_text(col, table.row_count_estimate): + entries = _entry_list(col.most_common_vals or []) + evidence = Evidence( + "pg_stats", + f"n_distinct={_effective_n_distinct(col, table.row_count_estimate):.0f}, " + "values from most_common_vals", + ) + confidence = 0.6 + + if not entries or evidence is None: + continue + findings.append( + Finding( + kind=KIND_DICTIONARY, + title=f"Value dictionary: {table.table_name}.{col.name}", + payload={ + "schema": table.schema_name, + "table": table.table_name, + "column": col.name, + "entries": entries, + }, + evidence=[evidence], + confidence=confidence, + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/fanout.py b/backend/app/semantic_compiler/inference/fanout.py new file mode 100644 index 0000000..1528512 --- /dev/null +++ b/backend/app/semantic_compiler/inference/fanout.py @@ -0,0 +1,102 @@ +"""Fan-out trap detection. + +For every parent 1:N child join, summing a parent measure across the join +double-counts — the single most common class of silently-wrong SQL answers. +Emits a warning per risky edge listing the parent's numeric measure columns. +""" + +from app.semantic_compiler.types import ( + KIND_FANOUT, + KIND_RELATIONSHIP, + Evidence, + Finding, + TableProfile, +) + +_MEASURE_TYPES = ("numeric", "decimal", "double", "real", "money", "integer", "bigint") +_KEY_HINTS = ("id", "_id", "code", "status", "method", "stage", "type", "year", "month") + + +def _measure_columns(table: TableProfile) -> list[str]: + measures = [] + for col in table.columns: + if col.is_primary_key: + continue + name = col.name.lower() + if any(name == hint or name.endswith(hint) for hint in _KEY_HINTS): + continue + if any(col.data_type.lower().startswith(t) for t in _MEASURE_TYPES): + measures.append(col.name) + return measures + + +def infer_fanout_warnings( + tables: list[TableProfile], + relationship_findings: list[Finding], +) -> list[Finding]: + by_name = {t.table_name.lower(): t for t in tables} + + # (child, child_col, parent, parent_col, confidence) for every N:1 edge + edges: list[tuple[str, str, str, str, float]] = [] + for f in relationship_findings: + if f.kind != KIND_RELATIONSHIP or f.payload.get("cardinality") != "N:1": + continue + p = f.payload + edges.append( + ( + p["source_table"], + p["source_column"], + p["target_table"], + p["target_column"], + f.confidence, + ) + ) + for table in tables: + for fk in table.declared_fks: + target = by_name.get(fk.target_table.lower()) + target_col = target.column(fk.target_column) if target else None + if target_col is not None and (target_col.is_primary_key or target_col.is_unique): + edges.append( + (table.table_name, fk.source_column, fk.target_table, fk.target_column, 1.0) + ) + + findings: list[Finding] = [] + seen: set[tuple[str, str]] = set() + for child, child_col, parent, parent_col, rel_conf in edges: + if (child, parent) in seen: + continue + seen.add((child, parent)) + parent_table = by_name.get(parent.lower()) + if parent_table is None: + continue + measures = _measure_columns(parent_table) + if not measures: + continue + example = f"{parent}.{measures[0]}" + findings.append( + Finding( + kind=KIND_FANOUT, + title=f"Fan-out risk: {parent} ⋈ {child}", + payload={ + "parent_table": parent, + "child_table": child, + "join": {"child_column": child_col, "parent_column": parent_col}, + "risky_columns": measures, + "guidance": ( + f"Joining {parent} to {child} repeats each {parent} row once per " + f"matching {child} row. Aggregating {parent} measures (e.g. " + f"SUM({example})) across this join double-counts; aggregate " + f"before joining or use DISTINCT on {parent} keys." + ), + }, + evidence=[ + Evidence( + "heuristic", + f"{child}.{child_col} → {parent}.{parent_col} is N:1, so the " + f"reverse join direction fans out {parent} rows", + ) + ], + confidence=min(0.9, rel_conf), + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/glossary.py b/backend/app/semantic_compiler/inference/glossary.py new file mode 100644 index 0000000..dd6bc9b --- /dev/null +++ b/backend/app/semantic_compiler/inference/glossary.py @@ -0,0 +1,91 @@ +"""Glossary entity candidates: hub tables of the (inferred + declared) join graph. + +Deterministic evidence + fallback definitions; the LLM annotation pass writes +the human-quality names and descriptions. +""" + +from collections import defaultdict + +from app.semantic_compiler.inference.naming import singularize +from app.semantic_compiler.types import ( + KIND_GLOSSARY, + KIND_RELATIONSHIP, + Evidence, + Finding, + TableProfile, +) + +_MAX_ENTITIES = 8 +_LOOKUP_COLUMNS = {"id", "code", "label", "name", "description"} + + +def _is_lookup_table(table: TableProfile) -> bool: + names = {c.name.lower() for c in table.columns} + return names <= _LOOKUP_COLUMNS and (table.row_count_estimate or 0) <= 100 + + +def infer_glossary_entities( + tables: list[TableProfile], + relationship_findings: list[Finding], + dead_table_names: set[str], +) -> list[Finding]: + # target table -> list of "source.column" references pointing at it + inbound: dict[str, list[str]] = defaultdict(list) + for f in relationship_findings: + if f.kind != KIND_RELATIONSHIP: + continue + p = f.payload + inbound[p["target_table"].lower()].append(f"{p['source_table']}.{p['source_column']}") + for table in tables: + for fk in table.declared_fks: + inbound[fk.target_table.lower()].append(f"{table.table_name}.{fk.source_column}") + + candidates: list[tuple[float, TableProfile, list[str]]] = [] + for table in tables: + if table.table_type != "table": + continue + if table.table_name.lower() in dead_table_names or _is_lookup_table(table): + continue + refs = sorted(set(inbound.get(table.table_name.lower(), []))) + if not refs and (table.row_count_estimate or 0) < 1: + continue + score = 0.55 + 0.05 * min(len(refs), 4) + candidates.append((score, table, refs)) + + candidates.sort(key=lambda c: (c[0], c[1].row_count_estimate or 0), reverse=True) + + findings: list[Finding] = [] + for score, table, refs in candidates[:_MAX_ENTITIES]: + term = singularize(table.table_name).replace("_", " ").title() + referenced_by = f" Referenced by: {', '.join(refs)}." if refs else "" + definition = ( + f"Core entity stored in {table.qualified_name} " + f"(~{table.row_count_estimate or 0} rows).{referenced_by}" + ) + if table.comment: + definition = f"{table.comment} {definition}" + evidence = [ + Evidence( + "heuristic", + f"hub table with {len(refs)} inbound join reference(s)" + if refs + else "populated base table", + ) + ] + findings.append( + Finding( + kind=KIND_GLOSSARY, + title=f"Entity: {term}", + payload={ + "term": term, + "definition": definition, # LLM annotation improves this + "sql_expression": table.table_name, + "related_tables": [table.table_name] + sorted({r.split(".")[0] for r in refs}), + "related_columns": refs, + "examples": [], + }, + evidence=evidence, + confidence=min(score, 0.8), + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/joins.py b/backend/app/semantic_compiler/inference/joins.py new file mode 100644 index 0000000..b142972 --- /dev/null +++ b/backend/app/semantic_compiler/inference/joins.py @@ -0,0 +1,278 @@ +"""Join-path inference — the compiler's core trick. + +Operational DBs routinely drop FK constraints for write performance, so join +paths must be inferred. Three independent evidence sources combine into one +confidence score: + +1. naming convention — ``orders.customer_id`` -> ``customers.id`` (~0.45) +2. value-overlap probe — sampled LEFT JOIN, >=95% of source values + resolve in the target (+0.35) +3. log co-occurrence — the join appears in actual logged queries (+0.15) + +A failed probe (<50% overlap) kills the candidate outright: a name match with +disjoint values is a coincidence, not a join path. +""" + +import logging +from dataclasses import dataclass, field + +from app.semantic_compiler.inference.naming import plural_candidates, singularize +from app.semantic_compiler.sqlmeta import SqlAnalysis +from app.semantic_compiler.types import ( + KIND_RELATIONSHIP, + Evidence, + Finding, + Prober, + TableProfile, + Thresholds, +) + +logger = logging.getLogger(__name__) + +_OVERLAP_SQL = """ +SELECT count(*) FILTER (WHERE t.{tc} IS NOT NULL)::float / NULLIF(count(*), 0) AS overlap +FROM (SELECT {sc} AS v FROM {ss}.{st} WHERE {sc} IS NOT NULL LIMIT {limit}) s +LEFT JOIN {ts}.{tt} t ON t.{tc} = s.v +""" + + +def _quote(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + +@dataclass +class _Candidate: + source: TableProfile + source_column: str + target: TableProfile + target_column: str + score: float = 0.0 + evidence: list[Evidence] = field(default_factory=list) + dropped: bool = False + + @property + def key(self) -> tuple[str, str, str, str]: + return ( + self.source.table_name, + self.source_column, + self.target.table_name, + self.target_column, + ) + + +def _naming_candidates( + tables: list[TableProfile], by_name: dict[str, TableProfile] +) -> list[_Candidate]: + candidates: list[_Candidate] = [] + for table in tables: + if table.table_type != "table": + continue + singular = singularize(table.table_name.lower()) + for col in table.columns: + name = col.name.lower() + if col.is_primary_key or name == "id": + continue + + target_names: list[str] = [] + if name.endswith("_id") and len(name) > 3: + target_names = plural_candidates(name[:-3]) + else: + # lookup-table pattern: orders.status -> order_statuses / statuses + target_names = [f"{singular}_{p}" for p in plural_candidates(name)] + target_names += plural_candidates(name) + + for target_name in target_names: + target = by_name.get(target_name) + if target is None or target.table_name == table.table_name: + continue + target_col = target.column("id") or next( + (c for c in target.columns if c.is_primary_key), None + ) + if target_col is None: + continue + candidates.append( + _Candidate( + source=table, + source_column=col.name, + target=target, + target_column=target_col.name, + score=0.45, + evidence=[ + Evidence( + "naming", + f"{table.table_name}.{col.name} matches " + f"{target.table_name}.{target_col.name} by convention", + ) + ], + ) + ) + break # best naming match only + return candidates + + +def _merge_log_evidence( + candidates: list[_Candidate], + analyses: list[tuple[SqlAnalysis, int, str]], + by_name: dict[str, TableProfile], +) -> list[_Candidate]: + """Add co-occurrence evidence to naming candidates; create log-only candidates.""" + pair_weight: dict[tuple, tuple[int, str]] = {} + for analysis, calls, origin in analyses: + for pair in analysis.join_pairs: + count, _ = pair_weight.get(pair.key(), (0, origin)) + pair_weight[pair.key()] = (count + max(calls, 1), origin) + + by_key = {c.key: c for c in candidates} + for pair_key, (weight, origin) in pair_weight.items(): + (t1, c1), (t2, c2) = pair_key + matched = None + for st, sc, tt, tc in ( + (t1, c1, t2, c2), + (t2, c2, t1, c1), + ): + matched = by_key.get((st, sc, tt, tc)) + if matched: + break + if matched is not None: + matched.score += 0.15 + matched.evidence.append( + Evidence("query_logs", f"join observed in workload ({origin}, weight {weight})") + ) + continue + + # Log-only candidate: pick the direction whose right side is a key column. + left, right = by_name.get(t1), by_name.get(t2) + if left is None or right is None: + continue + for src, sc, tgt, tc in ((left, c1, right, c2), (right, c2, left, c1)): + tgt_col = tgt.column(tc) + if tgt_col is not None and (tgt_col.is_primary_key or tgt_col.is_unique): + cand = _Candidate( + source=src, + source_column=sc, + target=tgt, + target_column=tc, + score=0.35, + evidence=[ + Evidence( + "query_logs", + f"join observed in workload ({origin}, weight {weight}) " + "with no naming-convention match", + ) + ], + ) + by_key[cand.key] = cand + candidates.append(cand) + break + return candidates + + +async def _probe_overlap(prober: Prober, cand: _Candidate, sample_rows: int) -> float | None: + sql = _OVERLAP_SQL.format( + sc=_quote(cand.source_column), + ss=_quote(cand.source.schema_name), + st=_quote(cand.source.table_name), + ts=_quote(cand.target.schema_name), + tt=_quote(cand.target.table_name), + tc=_quote(cand.target_column), + limit=int(sample_rows), + ) + rows = await prober.query(sql, max_rows=1) + if not rows: + return None + value = rows[0].get("overlap") + return float(value) if value is not None else None + + +def _cardinality(cand: _Candidate) -> str | None: + target_col = cand.target.column(cand.target_column) + source_col = cand.source.column(cand.source_column) + target_unique = bool(target_col and (target_col.is_primary_key or target_col.is_unique)) + source_unique = bool(source_col and (source_col.is_primary_key or source_col.is_unique)) + if target_unique and source_unique: + return "1:1" + if target_unique: + return "N:1" + if source_unique: + return "1:N" + return None + + +async def infer_joins( + tables: list[TableProfile], + analyses: list[tuple[SqlAnalysis, int, str]], + prober: Prober, + thresholds: Thresholds, + ignore_declared_fks: bool = False, +) -> list[Finding]: + """`analyses` = (parsed SQL, call weight, origin label) from views + logs.""" + by_name = {t.table_name.lower(): t for t in tables} + + declared: set[tuple[str, str, str, str]] = set() + if not ignore_declared_fks: + for table in tables: + for fk in table.declared_fks: + declared.add( + (table.table_name, fk.source_column, fk.target_table, fk.target_column) + ) + + candidates = _naming_candidates(tables, by_name) + candidates = _merge_log_evidence(candidates, analyses, by_name) + candidates = [c for c in candidates if c.key not in declared] + candidates.sort(key=lambda c: c.score, reverse=True) + + probes_left = thresholds.probe_budget + for cand in candidates: + if probes_left <= 0: + break + if cand.source.table_type != "table" or cand.target.table_type != "table": + continue + probes_left -= 1 + try: + overlap = await _probe_overlap(prober, cand, thresholds.probe_sample_rows) + except Exception as exc: + logger.debug("overlap probe failed for %s: %s", cand.key, exc) + continue + if overlap is None: + continue + detail = ( + f"{overlap:.0%} of {thresholds.probe_sample_rows} sampled " + f"{cand.source.table_name}.{cand.source_column} values resolve in " + f"{cand.target.table_name}.{cand.target_column}" + ) + if overlap >= 0.95: + cand.score += 0.35 + cand.evidence.append(Evidence("value_overlap", detail)) + elif overlap >= 0.70: + cand.score += 0.15 + cand.evidence.append(Evidence("value_overlap", detail)) + elif overlap < 0.50: + cand.dropped = True + cand.evidence.append(Evidence("value_overlap", detail + " — candidate rejected")) + + findings: list[Finding] = [] + for cand in candidates: + if cand.dropped: + continue + cardinality = _cardinality(cand) + findings.append( + Finding( + kind=KIND_RELATIONSHIP, + title=( + f"{cand.source.table_name}.{cand.source_column} → " + f"{cand.target.table_name}.{cand.target_column}" + ), + payload={ + "source_schema": cand.source.schema_name, + "source_table": cand.source.table_name, + "source_column": cand.source_column, + "target_schema": cand.target.schema_name, + "target_table": cand.target.table_name, + "target_column": cand.target_column, + "cardinality": cardinality, + }, + evidence=cand.evidence, + confidence=min(cand.score, 0.98), + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/log_metrics.py b/backend/app/semantic_compiler/inference/log_metrics.py new file mode 100644 index 0000000..56e8349 --- /dev/null +++ b/backend/app/semantic_compiler/inference/log_metrics.py @@ -0,0 +1,76 @@ +"""Metric candidates from recurring aggregate shapes in query logs.""" + +from collections import defaultdict + +from app.semantic_compiler.sqlmeta import SqlAnalysis +from app.semantic_compiler.types import KIND_METRIC, Evidence, Finding + +_MIN_CALLS = 5 + + +def _bare(column: str) -> str: + return column.split(".")[-1] + + +def infer_log_metrics( + log_analyses: list[tuple[SqlAnalysis, int]], + existing_metric_findings: list[Finding], +) -> list[Finding]: + """Aggregate+GROUP BY shapes recurring across logged queries, weighted by calls. + + Skips expressions already proposed from views (views are stronger evidence). + """ + already_proposed = {f.payload["sql_expression"].lower() for f in existing_metric_findings} + + # (function, column) -> accumulated calls / dimensions / tables / example sql + shapes: dict[tuple[str, str | None], dict] = defaultdict( + lambda: {"calls": 0, "dimensions": set(), "tables": set(), "sql": None} + ) + for analysis, calls in log_analyses: + if not analysis.aggregates: + continue + for agg in analysis.aggregates: + if agg.function == "count" and agg.column is None and not analysis.group_by: + continue # bare COUNT(*) with no dims is noise + shape = shapes[(agg.function, agg.column)] + shape["calls"] += max(calls, 1) + shape["dimensions"].update(_bare(g) for g in analysis.group_by if "(" not in g) + shape["tables"].update(analysis.tables) + shape["sql"] = shape["sql"] or agg.sql + + findings: list[Finding] = [] + for (function, column), shape in shapes.items(): + if shape["calls"] < _MIN_CALLS or shape["sql"] is None: + continue + if shape["sql"].lower() in already_proposed: + continue + confidence = 0.5 + if shape["calls"] >= 20: + confidence += 0.1 + if shape["calls"] >= 100: + confidence += 0.1 + name_base = _bare(column) if column else "rows" + findings.append( + Finding( + kind=KIND_METRIC, + title=f"Recurring aggregate: {shape['sql']}", + payload={ + "metric_name": f"{function}_{name_base}".lower(), + "display_name": f"{function} {name_base}".replace("_", " ").title(), + "description": None, + "sql_expression": shape["sql"], + "aggregation_type": function, + "related_tables": sorted(shape["tables"]), + "dimensions": sorted(shape["dimensions"]), + "filters": {}, + }, + evidence=[ + Evidence( + "query_logs", + f"aggregate ran {shape['calls']} times across logged queries", + ) + ], + confidence=confidence, + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/naming.py b/backend/app/semantic_compiler/inference/naming.py new file mode 100644 index 0000000..9053149 --- /dev/null +++ b/backend/app/semantic_compiler/inference/naming.py @@ -0,0 +1,22 @@ +"""Tiny dependency-free English pluralizer for table-name matching.""" + + +def plural_candidates(word: str) -> list[str]: + """Plausible table names for a singular entity word, best-first.""" + candidates = [word + "s"] + if word.endswith("y") and len(word) > 1 and word[-2] not in "aeiou": + candidates.append(word[:-1] + "ies") + if word.endswith(("s", "x", "z", "ch", "sh")): + candidates.append(word + "es") + candidates.append(word) # already-plural or uncountable table names + return candidates + + +def singularize(word: str) -> str: + if word.endswith("ies") and len(word) > 3: + return word[:-3] + "y" + if word.endswith("ses") or word.endswith("xes") or word.endswith("zes"): + return word[:-2] + if word.endswith("s") and not word.endswith("ss"): + return word[:-1] + return word diff --git a/backend/app/semantic_compiler/inference/pii.py b/backend/app/semantic_compiler/inference/pii.py new file mode 100644 index 0000000..bbcd698 --- /dev/null +++ b/backend/app/semantic_compiler/inference/pii.py @@ -0,0 +1,110 @@ +"""PII detection -> draft masking policy. + +Two independent signals: column-name patterns and value-shape regexes over +sampled rows. Both together -> high confidence; name alone -> medium. +""" + +import logging +import re +from re import Pattern + +from app.semantic_compiler.types import ( + KIND_MASKING, + Evidence, + Finding, + Prober, + TableProfile, +) + +logger = logging.getLogger(__name__) + +# category -> column-name pattern +_NAME_PATTERNS: dict[str, Pattern[str]] = { + "email": re.compile(r"e?mail", re.IGNORECASE), + "phone": re.compile(r"phone|mobile|fax", re.IGNORECASE), + "national_id": re.compile(r"ssn|national_id|tax_id|passport|nino", re.IGNORECASE), + "date_of_birth": re.compile(r"birth|dob", re.IGNORECASE), + "address": re.compile(r"address|street|postcode|zip_?code", re.IGNORECASE), + "bank_account": re.compile(r"iban|account_number|routing", re.IGNORECASE), + "person_name": re.compile(r"^(full|first|last|middle|family|given)_?name$", re.IGNORECASE), +} + +# category -> value-shape validator (None = name+type is the only signal) +_VALUE_PATTERNS: dict[str, Pattern[str] | None] = { + "email": re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$"), + "phone": re.compile(r"^\+?[\d][\d\s().-]{6,}$"), + "national_id": re.compile(r"^(\d{7,10}|\d{3}-\d{2}-\d{4})$"), + "date_of_birth": None, + "address": None, + "bank_account": re.compile(r"^[A-Z]{2}\d{2}[A-Z0-9]{10,30}$|^\d{8,17}$"), + "person_name": None, +} + +_SAMPLE_BUDGET = 30 # max columns to sample per run + + +def _categorize(column_name: str) -> str | None: + for category, pattern in _NAME_PATTERNS.items(): + if pattern.search(column_name): + return category + return None + + +async def infer_pii(tables: list[TableProfile], prober: Prober) -> list[Finding]: + findings: list[Finding] = [] + samples_left = _SAMPLE_BUDGET + + for table in tables: + if table.table_type != "table": + continue + for col in table.columns: + category = _categorize(col.name) + if category is None: + continue + if category == "date_of_birth" and "date" not in col.data_type.lower(): + continue + + confidence = 0.55 + evidence = [Evidence("naming", f"column name matches the {category} pattern")] + + validator = _VALUE_PATTERNS.get(category) + if validator is not None and samples_left > 0: + samples_left -= 1 + try: + values = await prober.sample_values( + table.schema_name, table.table_name, col.name, limit=20 + ) + except Exception as exc: + logger.debug( + "PII sampling failed for %s.%s: %s", table.table_name, col.name, exc + ) + values = [] + non_null = [str(v) for v in values if v is not None] + if non_null: + matched = sum(1 for v in non_null if validator.match(v.strip())) + if matched / len(non_null) >= 0.6: + confidence = 0.85 + evidence.append( + Evidence( + "value_overlap", + f"{matched} of {len(non_null)} sampled values " + f"match the {category} shape", + ) + ) + + findings.append( + Finding( + kind=KIND_MASKING, + title=f"PII candidate: {table.table_name}.{col.name} ({category})", + payload={ + "schema": table.schema_name, + "table": table.table_name, + "column": col.name, + "category": category, + "masked_column": f"{table.table_name}.{col.name}", + }, + evidence=evidence, + confidence=confidence, + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/tenant_scope.py b/backend/app/semantic_compiler/inference/tenant_scope.py new file mode 100644 index 0000000..6526f6a --- /dev/null +++ b/backend/app/semantic_compiler/inference/tenant_scope.py @@ -0,0 +1,121 @@ +"""Tenant/scope-column detection -> draft row-filter data policy. + +A column that appears on several entity tables AND is filtered in most logged +queries is a multi-tenancy (or org/account scoping) key. Real schemas don't +carry the tenant column on every table (lookup tables never do; child tables +scope through their parent), so presence alone is a weak signal — the query +logs are the confirming evidence. Without log confirmation the score stays +below the default threshold, so single-tenant DBs aren't spammed. + +The draft policy is a template — an admin must substitute the real tenant +value before enabling it. +""" + +import re + +from app.semantic_compiler.sqlmeta import SqlAnalysis +from app.semantic_compiler.types import KIND_ROW_FILTER, Evidence, Finding, TableProfile + +_KNOWN_SCOPE_NAMES = { + "tenant_id", + "org_id", + "organization_id", + "company_id", + "account_id", + "workspace_id", + "client_id", +} +_MIN_TABLE_FRACTION = 0.3 +_MIN_LOG_FRACTION = 0.5 +_DEAD_SUFFIX_RE = re.compile(r"_(bak|backup|old|tmp|temp|archive|deprecated)\d*$", re.IGNORECASE) +_LOOKUP_COLUMNS = {"id", "code", "label", "name", "description"} + + +def _is_entity_table(table: TableProfile) -> bool: + """Base tables minus backups and id/code/label lookup tables.""" + if table.table_type != "table": + return False + if _DEAD_SUFFIX_RE.search(table.table_name): + return False + names = {c.name.lower() for c in table.columns} + if names <= _LOOKUP_COLUMNS and (table.row_count_estimate or 0) <= 100: + return False + return True + + +def infer_tenant_scope( + tables: list[TableProfile], + log_analyses: list[tuple[SqlAnalysis, int]], +) -> list[Finding]: + entity_tables = [t for t in tables if _is_entity_table(t)] + if len(entity_tables) < 3: + return [] + + # column name -> entity tables that carry it + carriers: dict[str, list[str]] = {} + for table in entity_tables: + for col in table.columns: + name = col.name.lower() + if name.endswith("_id") and not col.is_primary_key: + carriers.setdefault(name, []).append(table.table_name) + + findings: list[Finding] = [] + for column, carrying_tables in carriers.items(): + table_fraction = len(carrying_tables) / len(entity_tables) + if table_fraction < _MIN_TABLE_FRACTION or len(carrying_tables) < 2: + continue + + score = 0.35 + evidence = [ + Evidence( + "heuristic", + f"column {column} present on {len(carrying_tables)} of " + f"{len(entity_tables)} entity tables ({table_fraction:.0%})", + ) + ] + if column in _KNOWN_SCOPE_NAMES: + score += 0.1 + evidence.append(Evidence("naming", f"{column} is a conventional scoping column")) + + if log_analyses: + # Call-weighted, and only over queries touching carrier tables: + # a query against a reference table can't be expected to filter by + # tenant, and one-off statements (e.g. the compiler's own probes) + # shouldn't dilute a hot production query that ran 10k times. + carrier_set = {t.lower() for t in carrying_tables} + relevant_weight = 0 + filtered_weight = 0 + for analysis, calls in log_analyses: + if not (set(analysis.tables) & carrier_set): + continue + weight = max(calls, 1) + relevant_weight += weight + if any(ref.split(".")[-1] == column for ref in analysis.where_columns): + filtered_weight += weight + log_fraction = filtered_weight / relevant_weight if relevant_weight else 0.0 + if log_fraction >= _MIN_LOG_FRACTION: + score += 0.3 + evidence.append( + Evidence( + "query_logs", + f"filtered in {log_fraction:.0%} of logged query volume " + "against the tables that carry it", + ) + ) + + findings.append( + Finding( + kind=KIND_ROW_FILTER, + title=f"Scoping column detected: {column}", + payload={ + "column": column, + "tables": sorted(carrying_tables), + "row_filters": { + t: f"{t}.{column} = :tenant_id" for t in sorted(carrying_tables) + }, + }, + evidence=evidence, + confidence=min(score, 0.95), + ) + ) + return findings diff --git a/backend/app/semantic_compiler/inference/view_metrics.py b/backend/app/semantic_compiler/inference/view_metrics.py new file mode 100644 index 0000000..bf82e1f --- /dev/null +++ b/backend/app/semantic_compiler/inference/view_metrics.py @@ -0,0 +1,80 @@ +"""Metric extraction from view definitions. + +A handwritten view is business logic someone already wrote and tested: its +aggregates are metric definitions, its GROUP BY columns are dimensions, and +its WHERE clause is the canonical filter. +""" + +from app.semantic_compiler.sqlmeta import SqlAnalysis +from app.semantic_compiler.types import KIND_METRIC, Evidence, Finding, ViewDef + + +def _bare(column: str) -> str: + return column.split(".")[-1] + + +def _metric_name(view_name: str, alias: str | None, function: str, column: str | None) -> str: + if alias: + base = alias + elif column: + base = f"{function}_{_bare(column)}" + else: + base = function + prefix = view_name.removeprefix("v_").removeprefix("vw_") + return f"{prefix}_{base}".lower() + + +def infer_view_metrics( + view_analyses: list[tuple[ViewDef, SqlAnalysis]], + used_views: set[str] | None = None, +) -> list[Finding]: + """`used_views` = view names seen in query logs (small confidence boost).""" + used_views = used_views or set() + findings: list[Finding] = [] + seen_expressions: set[str] = set() + + for view, analysis in view_analyses: + if not analysis.aggregates: + continue + base_tables = [t for t in analysis.tables if t != view.view_name.lower()] + dimensions = sorted({_bare(g) for g in analysis.group_by if "(" not in g}) + for agg in analysis.aggregates: + expression_key = agg.sql.lower() + if expression_key in seen_expressions: + continue + seen_expressions.add(expression_key) + + confidence = 0.75 + evidence = [ + Evidence( + "view", + f"aggregate {agg.sql} defined in view {view.view_name}" + + (f" (grouped by {', '.join(dimensions)})" if dimensions else ""), + ) + ] + if view.view_name.lower() in used_views: + confidence += 0.05 + evidence.append(Evidence("query_logs", f"view {view.view_name} is queried")) + + findings.append( + Finding( + kind=KIND_METRIC, + title=f"Metric from {view.view_name}: {agg.sql}", + payload={ + "metric_name": _metric_name( + view.view_name, agg.alias, agg.function, agg.column + ), + "display_name": (agg.alias or agg.function).replace("_", " ").title(), + "description": None, # LLM annotation fills this + "sql_expression": agg.sql, + "aggregation_type": agg.function, + "related_tables": base_tables, + "dimensions": dimensions, + "filters": {"where": analysis.where_sql} if analysis.where_sql else {}, + "source_view": view.view_name, + }, + evidence=evidence, + confidence=confidence, + ) + ) + return findings diff --git a/backend/app/semantic_compiler/sqlmeta.py b/backend/app/semantic_compiler/sqlmeta.py new file mode 100644 index 0000000..c9cf60a --- /dev/null +++ b/backend/app/semantic_compiler/sqlmeta.py @@ -0,0 +1,152 @@ +"""sqlglot-based SQL analysis for the compiler. + +Richer than ``lineage_service.extract_refs`` (which only yields table/column +refs): extracts equi-join pairs, WHERE-clause columns, aggregate select items, +and GROUP BY dimensions. Like the lineage service, degrades gracefully — +returns ``None`` when sqlglot is unavailable or the statement doesn't parse. +""" + +import logging +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class JoinPair: + left_table: str + left_column: str + right_table: str + right_column: str + + def key(self) -> tuple[tuple[str, str], tuple[str, str]]: + """Direction-insensitive identity for co-occurrence matching.""" + sides = sorted([(self.left_table, self.left_column), (self.right_table, self.right_column)]) + return (sides[0], sides[1]) + + +@dataclass +class AggregateItem: + sql: str # rendered aggregate expression, e.g. "SUM(orders.total_amount)" + function: str # sum | count | avg | min | max | ... + column: str | None # "table.column" when resolvable + alias: str | None = None + + +@dataclass +class SqlAnalysis: + tables: list[str] = field(default_factory=list) # real table names, lowercase + join_pairs: list[JoinPair] = field(default_factory=list) + where_columns: list[str] = field(default_factory=list) # "table.column" or bare "column" + where_sql: str | None = None + aggregates: list[AggregateItem] = field(default_factory=list) + group_by: list[str] = field(default_factory=list) + + +def analyze(sql: str, dialect: str | None = None) -> SqlAnalysis | None: + """Parse one statement and extract compiler-relevant structure. + + Returns None if sqlglot is missing or parsing fails — callers treat the + statement as opaque rather than erroring. + """ + if not sql or not sql.strip(): + return None + try: + import sqlglot + from sqlglot import exp + except ImportError: + logger.debug("sqlglot not installed; skipping SQL analysis") + return None + + try: + tree = sqlglot.parse_one(sql, dialect=dialect) + except Exception as exc: + logger.debug("sql analysis parse failed: %s", exc) + return None + + # alias (or bare name) -> real table name, all lowercased + alias_to_table: dict[str, str] = {} + for table_node in tree.find_all(exp.Table): + name = table_node.name.lower() + alias_to_table[name] = name + alias = table_node.alias + if alias: + alias_to_table[alias.lower()] = name + tables = sorted(set(alias_to_table.values())) + only_table = tables[0] if len(tables) == 1 else None + + def resolve(col: exp.Column) -> tuple[str | None, str]: + qualifier = col.table.lower() if col.table else None + if qualifier: + return alias_to_table.get(qualifier, qualifier), col.name.lower() + return only_table, col.name.lower() + + def dotted(col: exp.Column) -> str: + table, name = resolve(col) + return f"{table}.{name}" if table else name + + # --- equi-join pairs: any column = column across two different tables --- + join_pairs: list[JoinPair] = [] + seen_pairs: set[tuple] = set() + for eq in tree.find_all(exp.EQ): + left, right = eq.this, eq.expression + if not (isinstance(left, exp.Column) and isinstance(right, exp.Column)): + continue + lt, lc = resolve(left) + rt, rc = resolve(right) + if not lt or not rt or lt == rt: + continue + pair = JoinPair(lt, lc, rt, rc) + if pair.key() not in seen_pairs: + seen_pairs.add(pair.key()) + join_pairs.append(pair) + + # --- WHERE columns + rendered WHERE text (outermost only) --- + where_columns: list[str] = [] + where_sql: str | None = None + select = tree if isinstance(tree, exp.Select) else tree.find(exp.Select) + where = select.args.get("where") if select is not None else None + if where is not None: + where_sql = where.this.sql(dialect=dialect) + seen_cols: set[str] = set() + for col in where.find_all(exp.Column): + ref = dotted(col) + if ref not in seen_cols: + seen_cols.add(ref) + where_columns.append(ref) + + # --- aggregates in the outermost projection --- + aggregates: list[AggregateItem] = [] + if select is not None: + for projection in select.expressions: + alias = projection.alias if isinstance(projection, exp.Alias) else None + for agg in projection.find_all(exp.AggFunc): + inner_col = agg.find(exp.Column) + aggregates.append( + AggregateItem( + sql=agg.sql(dialect=dialect), + function=agg.sql_name().lower(), + column=dotted(inner_col) if inner_col is not None else None, + alias=alias, + ) + ) + + # --- GROUP BY dimensions --- + group_by: list[str] = [] + if select is not None: + group = select.args.get("group") + if group is not None: + for g in group.expressions: + if isinstance(g, exp.Column): + group_by.append(dotted(g)) + else: + group_by.append(g.sql(dialect=dialect)) + + return SqlAnalysis( + tables=tables, + join_pairs=join_pairs, + where_columns=where_columns, + where_sql=where_sql, + aggregates=aggregates, + group_by=group_by, + ) diff --git a/backend/app/semantic_compiler/types.py b/backend/app/semantic_compiler/types.py new file mode 100644 index 0000000..d5f92e1 --- /dev/null +++ b/backend/app/semantic_compiler/types.py @@ -0,0 +1,157 @@ +"""Dataclasses and protocols shared across the compiler. No app imports.""" + +from dataclasses import dataclass, field +from typing import Any, Protocol + + +class Prober(Protocol): + """Narrow read-only database access used by collectors and probes. + + The in-app implementation adapts ``BaseConnector``; a standalone CLI can + implement it directly over asyncpg. + """ + + async def query(self, sql: str, max_rows: int = 1000) -> list[dict[str, Any]]: ... + + async def sample_values( + self, schema: str, table: str, column: str, limit: int = 20 + ) -> list[Any]: ... + + +@dataclass +class DeclaredFK: + source_column: str + target_schema: str + target_table: str + target_column: str + + +@dataclass +class ColumnProfile: + name: str + data_type: str + is_nullable: bool = True + is_primary_key: bool = False + comment: str | None = None + ordinal_position: int = 0 + # pg_stats enrichment (None = stats unavailable) + null_frac: float | None = None + n_distinct: float | None = None + most_common_vals: list[str] | None = None + most_common_freqs: list[float] | None = None + # constraint/index enrichment + is_unique: bool = False + check_in_values: list[str] | None = None + enum_values: list[str] | None = None + # sampled values (PII detection) + sample_values: list[Any] = field(default_factory=list) + + +@dataclass +class TableProfile: + schema_name: str + table_name: str + table_type: str = "table" # "table" | "view" + comment: str | None = None + row_count_estimate: int | None = None + columns: list[ColumnProfile] = field(default_factory=list) + declared_fks: list[DeclaredFK] = field(default_factory=list) + + @property + def qualified_name(self) -> str: + return f"{self.schema_name}.{self.table_name}" + + def column(self, name: str) -> ColumnProfile | None: + for col in self.columns: + if col.name == name: + return col + return None + + +@dataclass +class ViewDef: + schema_name: str + view_name: str + sql: str + + +@dataclass +class LoggedQuery: + sql: str + calls: int = 1 + total_time_ms: float = 0.0 + + +@dataclass +class Evidence: + source: str # naming | value_overlap | query_logs | pg_stats | view | constraint | heuristic + detail: str + + def as_dict(self) -> dict[str, str]: + return {"source": self.source, "detail": self.detail} + + +# Finding kinds +KIND_RELATIONSHIP = "relationship" +KIND_METRIC = "metric" +KIND_DICTIONARY = "dictionary" +KIND_GLOSSARY = "glossary" +KIND_ROW_FILTER = "data_policy_row_filter" +KIND_MASKING = "data_policy_masking" +KIND_DEAD_TABLE = "dead_table" +KIND_FANOUT = "fanout_warning" + +ALL_KINDS = ( + KIND_RELATIONSHIP, + KIND_METRIC, + KIND_DICTIONARY, + KIND_GLOSSARY, + KIND_ROW_FILTER, + KIND_MASKING, + KIND_DEAD_TABLE, + KIND_FANOUT, +) + + +@dataclass +class Finding: + kind: str + title: str + payload: dict[str, Any] + evidence: list[Evidence] = field(default_factory=list) + confidence: float = 0.0 + + +def _default_caps() -> dict[str, int]: + # Review fatigue kills draft-generation tools: cap output hard. + return { + KIND_RELATIONSHIP: 40, + KIND_METRIC: 30, + KIND_DICTIONARY: 60, + KIND_GLOSSARY: 15, + KIND_ROW_FILTER: 3, + KIND_MASKING: 25, + KIND_DEAD_TABLE: 20, + KIND_FANOUT: 20, + } + + +@dataclass +class Thresholds: + min_confidence: float = 0.5 + max_per_kind: dict[str, int] = field(default_factory=_default_caps) + # Value-overlap probes are real queries against the target DB — budget them. + probe_budget: int = 60 + probe_sample_rows: int = 500 + + +@dataclass +class CompilerInput: + dialect: str | None = None # sqlglot dialect, e.g. "postgres" + tables: list[TableProfile] = field(default_factory=list) + views: list[ViewDef] = field(default_factory=list) + logged_queries: list[LoggedQuery] = field(default_factory=list) + # Which evidence sources actually answered (for run stats / UI messaging). + sources_available: dict[str, bool] = field(default_factory=dict) + # ignore_declared_fks: treat declared FKs as absent (eval mode) + options: dict[str, Any] = field(default_factory=dict) diff --git a/backend/app/services/compilation_progress.py b/backend/app/services/compilation_progress.py new file mode 100644 index 0000000..5996552 --- /dev/null +++ b/backend/app/services/compilation_progress.py @@ -0,0 +1,68 @@ +"""In-memory progress tracker for semantic-compiler runs (mirrors embedding_progress).""" + +import asyncio +from dataclasses import dataclass +from datetime import UTC, datetime + + +@dataclass +class CompilationProgress: + connection_id: str + total: int = 0 + completed: int = 0 + stage: str = "" # human-readable current stage + status: str = "pending" # pending | running | completed | failed + error: str | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + + +_progress: dict[str, CompilationProgress] = {} +_tasks: dict[str, asyncio.Task] = {} + + +def start_tracking(connection_id: str, total: int) -> CompilationProgress: + p = CompilationProgress( + connection_id=connection_id, + total=total, + status="running", + started_at=datetime.now(UTC), + ) + _progress[connection_id] = p + return p + + +def advance(connection_id: str, stage: str) -> None: + if connection_id in _progress: + p = _progress[connection_id] + p.completed += 1 + p.stage = stage + + +def mark_completed(connection_id: str) -> None: + if connection_id in _progress: + p = _progress[connection_id] + p.status = "completed" + p.completed = p.total + p.finished_at = datetime.now(UTC) + + +def mark_failed(connection_id: str, error: str) -> None: + if connection_id in _progress: + p = _progress[connection_id] + p.status = "failed" + p.error = error + p.finished_at = datetime.now(UTC) + + +def get_progress(connection_id: str) -> CompilationProgress | None: + return _progress.get(connection_id) + + +def is_running(connection_id: str) -> bool: + return connection_id in _progress and _progress[connection_id].status == "running" + + +def register_task(connection_id: str, task: asyncio.Task) -> None: + """Store task reference to prevent garbage collection (in-process queue only).""" + _tasks[connection_id] = task diff --git a/backend/app/services/compilation_service.py b/backend/app/services/compilation_service.py new file mode 100644 index 0000000..4035c16 --- /dev/null +++ b/backend/app/services/compilation_service.py @@ -0,0 +1,694 @@ +"""Semantic layer compiler — service layer. + +Orchestrates the engine in ``app/semantic_compiler/`` as a background job, +persists findings for review, and dispatches accepted findings into the real +semantic objects through the existing creation paths (embedding + lineage). + +Findings never touch the semantic tables until accepted: draft metrics and +glossary terms are retrieved by the query-pipeline context builder, so +unreviewed compiler output must stay out of them. +""" + +import asyncio +import logging +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.connectors.connector_registry import get_or_create_connector +from app.core.auth import AuthContext +from app.core.exceptions import NotFoundError +from app.db.models.compilation import CompilationFinding, CompilationRun +from app.db.models.data_policy import DataPolicy +from app.db.models.dictionary import DictionaryEntry +from app.db.models.glossary import GlossaryTerm +from app.db.models.metric import MetricDefinition +from app.db.models.schema_cache import CachedColumn, CachedRelationship, CachedTable +from app.db.session import async_session_factory +from app.jobs import get_job_queue, register_job +from app.semantic_compiler import CompilerInput, Thresholds, run_compiler +from app.semantic_compiler.collectors import ( + build_table_profiles, + collect_pg_stats, + collect_query_logs, + collect_view_definitions, +) +from app.services import compilation_progress as progress +from app.services.connection_service import get_connection, get_decrypted_connection_string +from app.services.lineage_service import dialect_for + +logger = logging.getLogger(__name__) + +# Merged "refusal boundary" policies created on accept (one per connection). +PII_POLICY_NAME = "Compiler: PII masking" +DEAD_TABLE_POLICY_NAME = "Compiler: dead tables" + +_RUN_STAGES = 6 # introspect, statistics, views, query logs, inference, annotate/persist + + +class ConnectorProber: + """Adapts a BaseConnector to the engine's narrow Prober protocol.""" + + def __init__(self, connector: Any): + self._connector = connector + + async def query(self, sql: str, max_rows: int = 1000) -> list[dict[str, Any]]: + result = await self._connector.execute_query(sql, timeout_seconds=15, max_rows=max_rows) + return [dict(zip(result.columns, row, strict=False)) for row in result.rows] + + async def sample_values( + self, schema: str, table: str, column: str, limit: int = 20 + ) -> list[Any]: + return await self._connector.get_sample_values(schema, table, column, limit) + + +# --------------------------------------------------------------------------- +# Run lifecycle +# --------------------------------------------------------------------------- + + +async def start_run( + db: AsyncSession, + connection_id: uuid.UUID, + ctx: AuthContext, + options: dict[str, Any] | None = None, +) -> CompilationRun: + """Create a queued run and launch the background job.""" + await get_connection(db, connection_id, ctx, write=True) + if progress.is_running(str(connection_id)): + raise ValueError("A compilation is already running for this connection") + + run = CompilationRun( + connection_id=connection_id, + status="queued", + options=options or {}, + triggered_by_id=ctx.user_id, + ) + db.add(run) + await db.flush() + await db.commit() + + queue = get_job_queue() + task = queue.submit("semantic_compilation", run.id, name=f"compile-{connection_id}") + if queue.backend_name == "inprocess" and isinstance(task, asyncio.Task): + progress.register_task(str(connection_id), task) + return run + + +async def _run_compilation_job(run_id: uuid.UUID) -> None: + """Background job: collect evidence, run inference, annotate, persist findings.""" + async with async_session_factory() as db: + run = await db.get(CompilationRun, run_id) + if run is None: + logger.warning("compilation run %s vanished before starting", run_id) + return + cid = str(run.connection_id) + progress.start_tracking(cid, _RUN_STAGES) + run.status = "running" + run.started_at = datetime.now(UTC) + await db.commit() + try: + await _execute_run(db, run) + run.status = "completed" + run.finished_at = datetime.now(UTC) + await db.commit() + progress.mark_completed(cid) + except Exception as exc: + await db.rollback() + run = await db.get(CompilationRun, run_id) + if run is not None: + run.status = "failed" + run.error = str(exc)[:2000] + run.finished_at = datetime.now(UTC) + await db.commit() + progress.mark_failed(cid, str(exc)) + logger.exception("compilation run %s failed", run_id) + + +async def _execute_run(db: AsyncSession, run: CompilationRun) -> None: + from app.services.identity_service import system_context + + ctx = await system_context(db) + conn = await get_connection(db, run.connection_id, ctx) + connector = await get_or_create_connector( + str(run.connection_id), conn.connector_type, get_decrypted_connection_string(conn) + ) + prober = ConnectorProber(connector) + cid = str(run.connection_id) + options = run.options or {} + is_postgres = conn.connector_type.lower() == "postgresql" + + # --- collect --- + progress.advance(cid, "Introspecting schema") + table_infos = [] + for schema in await connector.introspect_schemas(): + table_infos.extend(await connector.introspect_tables(schema)) + tables = build_table_profiles(table_infos) + + progress.advance(cid, "Reading column statistics") + stats_available = await collect_pg_stats(prober, tables) if is_postgres else False + + progress.advance(cid, "Reading view definitions") + views, views_available = await collect_view_definitions(prober) if is_postgres else ([], False) + + progress.advance(cid, "Reading query logs") + logged, logs_available = await collect_query_logs(prober) if is_postgres else ([], False) + + sources = { + "pg_stats": stats_available, + "views": views_available, + "query_logs": logs_available, + } + + # --- infer --- + progress.advance(cid, "Running inference") + thresholds = Thresholds() + if "min_confidence" in options: + thresholds.min_confidence = float(options["min_confidence"]) + inp = CompilerInput( + dialect=dialect_for(conn.connector_type), + tables=tables, + views=views, + logged_queries=logged, + sources_available=sources, + options={"ignore_declared_fks": bool(options.get("ignore_declared_fks"))}, + ) + findings = await run_compiler(inp, prober, thresholds) + + # --- annotate (optional) + persist --- + progress.advance(cid, "Annotating and saving findings") + if options.get("llm_enabled", True): + findings = await _annotate(findings) + + # A new run supersedes prior un-reviewed proposals; accepted/dismissed + # findings are review history (and the rematerialization source) — untouched. + stale = await db.execute( + select(CompilationFinding).where( + CompilationFinding.connection_id == run.connection_id, + CompilationFinding.status == "proposed", + ) + ) + superseded = 0 + for old in stale.scalars(): + old.status = "dismissed" + old.reviewed_at = datetime.now(UTC) + superseded += 1 + + counts: dict[str, int] = {} + for finding in findings: + counts[finding.kind] = counts.get(finding.kind, 0) + 1 + db.add( + CompilationFinding( + run_id=run.id, + connection_id=run.connection_id, + kind=finding.kind, + title=finding.title[:255], + payload=finding.payload, + evidence=[e.as_dict() for e in finding.evidence], + confidence=round(finding.confidence, 3), + ) + ) + run.stats = { + "findings": counts, + "sources_available": sources, + "superseded_proposals": superseded, + "tables_examined": len(tables), + "views_examined": len(views), + "logged_queries_examined": len(logged), + } + + +async def _annotate(findings: list) -> list: + """LLM naming pass — merged onto naming fields only, never structure. + + Best-effort: returns the findings unchanged if the provider is unavailable. + """ + from app.llm.agents.semantic_annotator import SemanticAnnotatorAgent + from app.llm.base_provider import LLMConfig + from app.llm.prompts.annotator_prompts import KIND_FIELDS + from app.llm.provider_registry import get_provider + + try: + provider = get_provider(settings.default_llm_provider) + except Exception as exc: + logger.warning("annotation skipped — provider unavailable: %s", exc) + return findings + model = ( + settings.ollama_model + if settings.default_llm_provider == "ollama" + else settings.default_llm_model + ) + agent = SemanticAnnotatorAgent(provider, LLMConfig(model=model, max_tokens=4096)) + + by_kind: dict[str, list[int]] = {} + for i, finding in enumerate(findings): + if finding.kind in KIND_FIELDS: + by_kind.setdefault(finding.kind, []).append(i) + + for kind, indices in by_kind.items(): + payloads = [ + { + "title": findings[i].title, + "payload": findings[i].payload, + "evidence": [e.as_dict() for e in findings[i].evidence], + } + for i in indices + ] + annotations = await agent.annotate(kind, payloads) + for local_index, fields in annotations.items(): + if 0 <= local_index < len(indices): + findings[indices[local_index]].payload.update(fields) + return findings + + +register_job("semantic_compilation", _run_compilation_job) + + +# --------------------------------------------------------------------------- +# Queries +# --------------------------------------------------------------------------- + + +async def list_runs( + db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext +) -> list[CompilationRun]: + await get_connection(db, connection_id, ctx) + result = await db.execute( + select(CompilationRun) + .where(CompilationRun.connection_id == connection_id) + .order_by(CompilationRun.created_at.desc()) + .limit(20) + ) + return list(result.scalars().all()) + + +async def get_run(db: AsyncSession, run_id: uuid.UUID, ctx: AuthContext) -> CompilationRun: + run = await db.get(CompilationRun, run_id) + if run is None: + raise NotFoundError("CompilationRun", str(run_id)) + await get_connection(db, run.connection_id, ctx) + return run + + +async def list_findings( + db: AsyncSession, + connection_id: uuid.UUID, + ctx: AuthContext, + status: str | None = None, + kind: str | None = None, +) -> list[CompilationFinding]: + await get_connection(db, connection_id, ctx) + stmt = select(CompilationFinding).where(CompilationFinding.connection_id == connection_id) + if status: + stmt = stmt.where(CompilationFinding.status == status) + if kind: + stmt = stmt.where(CompilationFinding.kind == kind) + stmt = stmt.order_by(CompilationFinding.kind, CompilationFinding.confidence.desc()).limit(500) + result = await db.execute(stmt) + return list(result.scalars().all()) + + +# --------------------------------------------------------------------------- +# Review: accept / dismiss +# --------------------------------------------------------------------------- + + +async def _get_finding( + db: AsyncSession, finding_id: uuid.UUID, ctx: AuthContext +) -> CompilationFinding: + finding = await db.get(CompilationFinding, finding_id) + if finding is None: + raise NotFoundError("CompilationFinding", str(finding_id)) + await get_connection(db, finding.connection_id, ctx, write=True) + return finding + + +async def dismiss_finding( + db: AsyncSession, finding_id: uuid.UUID, ctx: AuthContext +) -> CompilationFinding: + finding = await _get_finding(db, finding_id, ctx) + if finding.status != "proposed": + raise ValueError(f"Finding already {finding.status}") + finding.status = "dismissed" + finding.reviewed_by_id = ctx.user_id + finding.reviewed_at = datetime.now(UTC) + await db.flush() + return finding + + +async def accept_finding( + db: AsyncSession, finding_id: uuid.UUID, ctx: AuthContext +) -> CompilationFinding: + """Materialize a finding as a real semantic object (status stays 'draft' there).""" + finding = await _get_finding(db, finding_id, ctx) + if finding.status != "proposed": + raise ValueError(f"Finding already {finding.status}") + + handler = _ACCEPT_HANDLERS.get(finding.kind) + if handler is None: + raise ValueError(f"Unknown finding kind: {finding.kind}") + entity_type, entity_id = await handler(db, finding, ctx) + + finding.status = "accepted" + finding.created_entity_type = entity_type + finding.created_entity_id = entity_id + finding.reviewed_by_id = ctx.user_id + finding.reviewed_at = datetime.now(UTC) + await db.flush() + return finding + + +async def bulk_review( + db: AsyncSession, finding_ids: list[uuid.UUID], action: str, ctx: AuthContext +) -> dict[str, int]: + succeeded, failed = 0, 0 + for finding_id in finding_ids: + try: + if action == "accept": + await accept_finding(db, finding_id, ctx) + else: + await dismiss_finding(db, finding_id, ctx) + succeeded += 1 + except Exception as exc: + logger.warning("bulk %s failed for finding %s: %s", action, finding_id, exc) + failed += 1 + return {"succeeded": succeeded, "failed": failed} + + +# --- accept dispatch, one handler per kind --------------------------------- + + +async def _resolve_table( + db: AsyncSession, connection_id: uuid.UUID, schema: str | None, table: str +) -> CachedTable | None: + stmt = select(CachedTable).where( + CachedTable.connection_id == connection_id, CachedTable.table_name == table + ) + if schema: + stmt = stmt.where(CachedTable.schema_name == schema) + return (await db.execute(stmt.limit(1))).scalar_one_or_none() + + +async def _resolve_column( + db: AsyncSession, connection_id: uuid.UUID, schema: str | None, table: str, column: str +) -> CachedColumn | None: + cached_table = await _resolve_table(db, connection_id, schema, table) + if cached_table is None: + return None + stmt = select(CachedColumn).where( + CachedColumn.table_id == cached_table.id, CachedColumn.column_name == column + ) + return (await db.execute(stmt.limit(1))).scalar_one_or_none() + + +async def _create_relationship( + db: AsyncSession, + connection_id: uuid.UUID, + payload: dict, + confidence: float, + evidence: list, +) -> CachedRelationship | None: + source = await _resolve_table( + db, connection_id, payload.get("source_schema"), payload["source_table"] + ) + target = await _resolve_table( + db, connection_id, payload.get("target_schema"), payload["target_table"] + ) + if source is None or target is None: + return None + existing = await db.execute( + select(CachedRelationship).where( + CachedRelationship.connection_id == connection_id, + CachedRelationship.source_table_id == source.id, + CachedRelationship.source_column == payload["source_column"], + CachedRelationship.target_table_id == target.id, + CachedRelationship.target_column == payload["target_column"], + ) + ) + found = existing.scalars().first() + if found is not None: + return found + rel = CachedRelationship( + connection_id=connection_id, + constraint_name=None, + origin="inferred", + confidence=confidence, + cardinality=payload.get("cardinality"), + evidence=evidence, + source_table_id=source.id, + source_column=payload["source_column"], + target_table_id=target.id, + target_column=payload["target_column"], + ) + db.add(rel) + await db.flush() + return rel + + +async def _accept_relationship( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + rel = await _create_relationship( + db, finding.connection_id, finding.payload, finding.confidence, finding.evidence + ) + if rel is None: + raise ValueError("Source or target table not found in schema cache — re-introspect first") + return "relationship", rel.id + + +async def _accept_metric( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + from app.services.embedding_service import embed_metric + from app.services.lineage_service import recompute_metric + + p = finding.payload + metric = MetricDefinition( + connection_id=finding.connection_id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, + metric_name=p["metric_name"], + display_name=p.get("display_name") or p["metric_name"].replace("_", " ").title(), + description=p.get("description"), + sql_expression=p["sql_expression"], + aggregation_type=p.get("aggregation_type"), + related_tables=p.get("related_tables") or [], + dimensions=p.get("dimensions") or [], + filters=p.get("filters") or {}, + ) + db.add(metric) + await db.flush() + try: + metric.metric_embedding = await embed_metric(metric) + except Exception as exc: + logger.warning("metric embedding deferred (provider unavailable): %s", exc) + try: + await recompute_metric(db, ctx, metric) + except Exception as exc: + logger.warning("metric lineage recompute failed: %s", exc) + return "metric", metric.id + + +async def _accept_glossary( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + from app.services.embedding_service import embed_glossary_term + + p = finding.payload + term = GlossaryTerm( + connection_id=finding.connection_id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, + term=p["term"], + definition=p["definition"], + sql_expression=p.get("sql_expression") or p["term"], + related_tables=p.get("related_tables") or [], + related_columns=p.get("related_columns") or [], + examples=p.get("examples") or [], + ) + db.add(term) + await db.flush() + try: + term.term_embedding = await embed_glossary_term(term) + except Exception as exc: + logger.warning("glossary embedding deferred (provider unavailable): %s", exc) + return "glossary", term.id + + +async def _create_dictionary_entries( + db: AsyncSession, connection_id: uuid.UUID, payload: dict +) -> CachedColumn | None: + column = await _resolve_column( + db, connection_id, payload.get("schema"), payload["table"], payload["column"] + ) + if column is None: + return None + existing = await db.execute( + select(DictionaryEntry).where(DictionaryEntry.column_id == column.id).limit(1) + ) + if existing.scalars().first() is not None: + return column # entries already present — don't duplicate + for entry in payload.get("entries", []): + db.add( + DictionaryEntry( + column_id=column.id, + raw_value=str(entry["raw_value"])[:255], + display_value=str(entry.get("display_value") or entry["raw_value"])[:255], + description=entry.get("description"), + sort_order=int(entry.get("sort_order") or 0), + ) + ) + await db.flush() + return column + + +async def _accept_dictionary( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + column = await _create_dictionary_entries(db, finding.connection_id, finding.payload) + if column is None: + raise ValueError("Column not found in schema cache — re-introspect first") + return "dictionary_column", column.id + + +async def _get_or_create_merged_policy( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext, name: str +) -> DataPolicy: + result = await db.execute( + select(DataPolicy).where( + DataPolicy.connection_id == finding.connection_id, DataPolicy.name == name + ) + ) + policy = result.scalars().first() + if policy is None: + # Created DISABLED: policies have no draft status and enforce live — + # the reviewer flips `enabled` after checking the contents. + policy = DataPolicy( + connection_id=finding.connection_id, + organization_id=ctx.organization_id, + name=name, + enabled=False, + ) + db.add(policy) + await db.flush() + return policy + + +async def _accept_masking( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + policy = await _get_or_create_merged_policy(db, finding, ctx, PII_POLICY_NAME) + masked = finding.payload["masked_column"] + if masked not in (policy.masked_columns or []): + policy.masked_columns = [*(policy.masked_columns or []), masked] + await db.flush() + return "data_policy", policy.id + + +async def _accept_dead_table( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + policy = await _get_or_create_merged_policy(db, finding, ctx, DEAD_TABLE_POLICY_NAME) + table = finding.payload["table"] + if table not in (policy.blocked_tables or []): + policy.blocked_tables = [*(policy.blocked_tables or []), table] + await db.flush() + return "data_policy", policy.id + + +async def _accept_row_filter( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + p = finding.payload + policy = DataPolicy( + connection_id=finding.connection_id, + organization_id=ctx.organization_id, + name=f"Compiler: row filter on {p['column']}", + enabled=False, # the :tenant_id placeholder must be edited first + row_filters=p.get("row_filters") or {}, + ) + db.add(policy) + await db.flush() + return "data_policy", policy.id + + +async def _accept_fanout( + db: AsyncSession, finding: CompilationFinding, ctx: AuthContext +) -> tuple[str, uuid.UUID]: + """Fan-out guidance lands as a knowledge document — the knowledge resolver + already injects relevant chunks into the SQL-generation prompt.""" + from app.services.knowledge_service import import_document + + p = finding.payload + body = p.get("description") or "" + content = ( + f"{p['guidance']}\n\n{body}\n\n" + f"Join: {p['child_table']}.{p['join']['child_column']} = " + f"{p['parent_table']}.{p['join']['parent_column']} (N:1). " + f"Measure columns at risk on {p['parent_table']}: " + f"{', '.join(p.get('risky_columns', []))}." + ).strip() + doc = await import_document( + db, + connection_id=finding.connection_id, + title=f"Join guidance: {p['parent_table']} joined to {p['child_table']}", + content=content, + organization_id=ctx.organization_id, + source_url=None, + ) + return "knowledge", doc.id + + +_ACCEPT_HANDLERS = { + "relationship": _accept_relationship, + "metric": _accept_metric, + "glossary": _accept_glossary, + "dictionary": _accept_dictionary, + "data_policy_masking": _accept_masking, + "data_policy_row_filter": _accept_row_filter, + "dead_table": _accept_dead_table, + "fanout_warning": _accept_fanout, +} + + +# --------------------------------------------------------------------------- +# Rematerialization after re-introspection +# --------------------------------------------------------------------------- + + +async def rematerialize_accepted(db: AsyncSession, connection_id: uuid.UUID) -> dict[str, int]: + """Re-create cache-anchored artifacts from accepted findings. + + ``introspect_and_cache`` wipes all cached tables (cascading to inferred + relationships and dictionary entries). Accepted findings are name-keyed, + so they can be resolved against the fresh cache and re-created. + """ + result = await db.execute( + select(CompilationFinding).where( + CompilationFinding.connection_id == connection_id, + CompilationFinding.status == "accepted", + CompilationFinding.kind.in_(["relationship", "dictionary"]), + ) + ) + relationships = 0 + dictionaries = 0 + for finding in result.scalars(): + try: + if finding.kind == "relationship": + rel = await _create_relationship( + db, connection_id, finding.payload, finding.confidence, finding.evidence + ) + if rel is not None: + relationships += 1 + else: + column = await _create_dictionary_entries(db, connection_id, finding.payload) + if column is not None: + dictionaries += 1 + except Exception as exc: + logger.warning("rematerialization failed for finding %s: %s", finding.id, exc) + await db.flush() + return {"relationships": relationships, "dictionary_columns": dictionaries} diff --git a/backend/app/services/schema_service.py b/backend/app/services/schema_service.py index 789f6a5..93c8e22 100644 --- a/backend/app/services/schema_service.py +++ b/backend/app/services/schema_service.py @@ -104,10 +104,18 @@ async def introspect_and_cache( conn.last_introspected_at = datetime.now(UTC) await db.flush() + # Re-create compiler-accepted artifacts that hang off the (just wiped) + # schema cache: inferred relationships and dictionary entries. Accepted + # compilation findings are name-keyed and survive re-introspection. + from app.services import compilation_service + + rematerialized = await compilation_service.rematerialize_accepted(db, connection_id) + return { "tables_found": total_tables, "columns_found": total_columns, "relationships_found": total_relationships, + "rematerialized": rematerialized, } diff --git a/backend/scripts/eval_compiler_ifrs9.py b/backend/scripts/eval_compiler_ifrs9.py new file mode 100644 index 0000000..d50af91 --- /dev/null +++ b/backend/scripts/eval_compiler_ifrs9.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +"""Evaluate the semantic layer compiler against the IFRS 9 sample DB. + +Ground truth is the hand-written seed metadata (``seed_ifrs9_metadata.py``): +10 glossary terms, 8 metrics, 43 dictionary entries — plus the 5 declared FK +edges (the run uses ``ignore_declared_fks`` so join inference is actually +exercised against a schema where the answer key exists). + +Usage (stack must be up, IFRS 9 connection introspected): + python backend/scripts/eval_compiler_ifrs9.py [--base-url http://localhost:8000] + [--llm] [--min-confidence 0.4] [--skip-analyze] + +By default the LLM naming pass is OFF so the eval is deterministic. +""" + +import argparse +import re +import sys +import time + +import httpx +from seed_ifrs9_metadata import DICTIONARY_ENTRIES, GLOSSARY_TERMS, METRICS + +API_PREFIX = "/api/v1" + +# Declared FK edges in tests/fixtures/sample_seed.sql — the join-inference answer key. +EXPECTED_RELATIONSHIPS = { + ("facilities", "counterparty_id", "counterparties", "id"), + ("exposures", "facility_id", "facilities", "id"), + ("ecl_provisions", "exposure_id", "exposures", "id"), + ("collateral", "facility_id", "facilities", "id"), + ("staging_history", "facility_id", "facilities", "id"), +} + + +def get_ifrs9_connection_id(client: httpx.Client, name: str) -> str: + """The IFRS 9 connection by name (not just the first connection — other + connections, e.g. the opsdb fixture, may exist).""" + response = client.get(f"{API_PREFIX}/connections") + response.raise_for_status() + for conn in response.json(): + if conn["name"] == name: + print(f" Using connection: {conn['name']} ({conn['id']})") + return conn["id"] + print(f"ERROR: no connection named {name!r}. Is AUTO_SETUP_SAMPLE_DB enabled?") + sys.exit(1) + + +def normalize_sql(sql: str) -> str: + """Whitespace/case-insensitive normalization; sqlglot when available.""" + try: + import sqlglot + + return sqlglot.parse_one(sql, dialect="postgres").sql(dialect="postgres").lower() + except Exception: + return re.sub(r"\s+", "", sql.lower()) + + +def aggregate_signature(sql: str) -> tuple[str, str] | None: + """(function, bare column) — fuzzy identity for metric matching.""" + match = re.search(r"(sum|count|avg|min|max)\s*\(\s*(?:\w+\.)?(\w+|\*)", sql.lower()) + return (match.group(1), match.group(2)) if match else None + + +def run_compiler(client: httpx.Client, connection_id: str, args) -> dict: + response = client.post( + f"{API_PREFIX}/connections/{connection_id}/compilation/runs", + json={ + "llm_enabled": args.llm, + "min_confidence": args.min_confidence, + "ignore_declared_fks": True, + }, + ) + response.raise_for_status() + run = response.json() + print(f"Run {run['id']} started; waiting...") + + deadline = time.time() + 600 + while time.time() < deadline: + time.sleep(2) + run = client.get( + f"{API_PREFIX}/connections/{connection_id}/compilation/runs/{run['id']}" + ).json() + if run["status"] in ("completed", "failed"): + break + progress = run.get("progress") or {} + print(f" ... {progress.get('stage', run['status'])}") + if run["status"] != "completed": + print(f"Run did not complete: {run['status']} — {run.get('error')}") + sys.exit(1) + print(f"Run completed. Stats: {run['stats']}") + return run + + +def fetch_findings(client: httpx.Client, connection_id: str) -> list[dict]: + response = client.get( + f"{API_PREFIX}/connections/{connection_id}/compilation/findings", + params={"status": "proposed"}, + ) + response.raise_for_status() + return response.json() + + +def eval_relationships(findings: list[dict]) -> tuple[str, list[bool], list[float]]: + rels = [f for f in findings if f["kind"] == "relationship"] + proposed = { + ( + f["payload"]["source_table"], + f["payload"]["source_column"], + f["payload"]["target_table"], + f["payload"]["target_column"], + ): f["confidence"] + for f in rels + } + matched = EXPECTED_RELATIONSHIPS & set(proposed) + recall = len(matched) / len(EXPECTED_RELATIONSHIPS) + precision = len(matched) / len(proposed) if proposed else 0.0 + correctness = [key in EXPECTED_RELATIONSHIPS for key in proposed] + confidences = list(proposed.values()) + missed = EXPECTED_RELATIONSHIPS - matched + line = f"relationships recall {recall:.0%} ({len(matched)}/5) precision {precision:.0%}" + if missed: + line += f"\n missed: {sorted(missed)}" + return line, correctness, confidences + + +def eval_dictionary(findings: list[dict]) -> tuple[str, list[bool], list[float]]: + truth: set[tuple[str, str, str]] = set() + for (table, column), entries in DICTIONARY_ENTRIES.items(): + for entry in entries: + truth.add((table, column, str(entry["raw_value"]))) + + proposed: set[tuple[str, str, str]] = set() + correctness: list[bool] = [] + confidences: list[float] = [] + for f in findings: + if f["kind"] != "dictionary": + continue + payload = f["payload"] + hit_any = False + for entry in payload.get("entries", []): + key = (payload["table"], payload["column"], str(entry["raw_value"])) + proposed.add(key) + hit_any = hit_any or key in truth + correctness.append(hit_any) + confidences.append(f["confidence"]) + + matched = truth & proposed + recall = len(matched) / len(truth) if truth else 0.0 + precision = len(matched) / len(proposed) if proposed else 0.0 + return ( + f"dictionary recall {recall:.0%} ({len(matched)}/{len(truth)}) " + f"precision {precision:.0%} ({len(proposed)} proposed values)", + correctness, + confidences, + ) + + +def eval_metrics(findings: list[dict]) -> str: + truth_sigs = {aggregate_signature(m["sql_expression"]): m["metric_name"] for m in METRICS} + truth_sigs.pop(None, None) + proposed = [f for f in findings if f["kind"] == "metric"] + proposed_sigs = {aggregate_signature(f["payload"]["sql_expression"]) for f in proposed} - {None} + matched = set(truth_sigs) & proposed_sigs + missed = [truth_sigs[s] for s in set(truth_sigs) - matched] + return ( + f"metrics fuzzy recall {len(matched)}/{len(truth_sigs)} " + f"(by aggregate+column) {len(proposed)} proposed\n" + f" missed: {sorted(missed)}" + ) + + +def eval_glossary(findings: list[dict]) -> str: + proposed = [f for f in findings if f["kind"] == "glossary"] + covered = 0 + for term in GLOSSARY_TERMS: + gt_tables = set(term.get("related_tables") or []) + if any(gt_tables & set(f["payload"].get("related_tables") or []) for f in proposed): + covered += 1 + return ( + f"glossary table-coverage {covered}/{len(GLOSSARY_TERMS)} " + f"(soft metric — entity naming can't recover domain terms like 'EAD' " + f"from schema alone) {len(proposed)} proposed" + ) + + +def calibration(correctness: list[bool], confidences: list[float]) -> str: + right = [c for ok, c in zip(correctness, confidences, strict=False) if ok] + wrong = [c for ok, c in zip(correctness, confidences, strict=False) if not ok] + mean = lambda xs: sum(xs) / len(xs) if xs else float("nan") # noqa: E731 + return ( + f"confidence calibration: correct findings avg {mean(right):.2f} " + f"({len(right)}), incorrect avg {mean(wrong):.2f} ({len(wrong)})" + ) + + +def maybe_analyze(args) -> None: + """pg_stats is empty without ANALYZE; run it directly against sample-db.""" + if args.skip_analyze: + return + try: + import asyncio + + import asyncpg + + async def _go(): + conn = await asyncpg.connect(args.sample_dsn) + try: + await conn.execute("ANALYZE") + finally: + await conn.close() + + asyncio.run(_go()) + print("ANALYZE on sampledb done.") + except Exception as exc: + print(f"WARNING: could not ANALYZE sampledb ({exc}) — pg_stats may be empty.") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--base-url", default="http://localhost:8000") + parser.add_argument( + "--sample-dsn", default="postgresql://sample:sample_dev@localhost:5433/sampledb" + ) + parser.add_argument("--llm", action="store_true", help="enable the LLM naming pass") + parser.add_argument("--min-confidence", type=float, default=0.4) + parser.add_argument("--skip-analyze", action="store_true") + parser.add_argument("--connection-name", default="IFRS 9 Sample DB") + args = parser.parse_args() + + maybe_analyze(args) + + with httpx.Client(base_url=args.base_url, timeout=60) as client: + connection_id = get_ifrs9_connection_id(client, args.connection_name) + run_compiler(client, connection_id, args) + findings = fetch_findings(client, connection_id) + + print(f"\n{len(findings)} proposed findings\n" + "=" * 60) + rel_line, rel_ok, rel_conf = eval_relationships(findings) + dict_line, dict_ok, dict_conf = eval_dictionary(findings) + print(rel_line) + print(dict_line) + print(eval_metrics(findings)) + print(eval_glossary(findings)) + print(calibration(rel_ok + dict_ok, rel_conf + dict_conf)) + + +if __name__ == "__main__": + main() diff --git a/backend/scripts/run_ops_workload.py b/backend/scripts/run_ops_workload.py new file mode 100644 index 0000000..24cf187 --- /dev/null +++ b/backend/scripts/run_ops_workload.py @@ -0,0 +1,109 @@ +"""Run a representative query workload against the opsdb fixture. + +Populates pg_stat_statements so the semantic layer compiler's query-log +collector has evidence to mine: join co-occurrence (orders <-> customers, +order_items <-> orders, status lookups), tenant-scoped WHERE clauses on +nearly every query, and recurring aggregate shapes. + +Usage (sample-db container must be up): + python backend/scripts/run_ops_workload.py [--dsn DSN] [--rounds N] +""" + +import argparse +import asyncio +import random + +import asyncpg + +DEFAULT_DSN = "postgresql://sample:sample_dev@localhost:5433/opsdb" + +# Each shape is run every round with fresh literals. pg_stat_statements +# normalizes constants, so repeated shapes accumulate `calls`. +QUERY_SHAPES: list[str] = [ + # --- tenant-scoped single-table lookups (tenant ubiquity signal) --- + "SELECT * FROM customers WHERE tenant_id = {t} AND deleted_at IS NULL LIMIT 50", + "SELECT * FROM customers WHERE tenant_id = {t} AND status = {cs}", + "SELECT * FROM orders WHERE tenant_id = {t} AND order_date >= DATE '2024-06-01'", + "SELECT * FROM orders WHERE tenant_id = {t} AND status = {os} AND deleted_at IS NULL", + "SELECT * FROM events WHERE tenant_id = {t} AND entity_type = 'order' LIMIT 100", + "SELECT * FROM customers WHERE tenant_id = {t} AND email = 'customer{i}@example.com'", + # --- joins (co-occurrence evidence; no FKs exist, logs are the proof) --- + """SELECT c.full_name, o.id, o.total_amount + FROM orders o JOIN customers c ON o.customer_id = c.id + WHERE o.tenant_id = {t} AND o.deleted_at IS NULL LIMIT 100""", + """SELECT o.id, SUM(oi.quantity * oi.unit_price) AS line_total + FROM order_items oi JOIN orders o ON oi.order_id = o.id + WHERE o.tenant_id = {t} GROUP BY o.id LIMIT 100""", + """SELECT p.name, SUM(oi.quantity) AS units + FROM order_items oi JOIN products p ON oi.product_id = p.id + GROUP BY p.name ORDER BY units DESC LIMIT 20""", + """SELECT o.id, pay.amount, pay.paid_at + FROM payments pay JOIN orders pay_o ON pay.order_id = pay_o.id + JOIN orders o ON o.id = pay_o.id WHERE o.tenant_id = {t} LIMIT 50""", + """SELECT c.full_name, cs.label + FROM customers c JOIN customer_statuses cs ON c.status = cs.id + WHERE c.tenant_id = {t} LIMIT 50""", + """SELECT o.id, os.label + FROM orders o JOIN order_statuses os ON o.status = os.id + WHERE o.tenant_id = {t} AND o.order_date >= DATE '2024-01-01' LIMIT 50""", + """SELECT c.id, c.full_name, COUNT(o.id) AS n + FROM customers c LEFT JOIN orders o ON o.customer_id = c.id + WHERE c.tenant_id = {t} AND c.deleted_at IS NULL + GROUP BY c.id, c.full_name ORDER BY n DESC LIMIT 25""", + # --- recurring aggregates (metric candidates) --- + """SELECT SUM(total_amount) FROM orders + WHERE tenant_id = {t} AND status = 3 AND deleted_at IS NULL""", + """SELECT date_trunc('month', order_date) AS m, SUM(total_amount) + FROM orders WHERE tenant_id = {t} AND deleted_at IS NULL AND status = 3 + GROUP BY m ORDER BY m""", + """SELECT status, COUNT(*) FROM orders + WHERE tenant_id = {t} AND deleted_at IS NULL GROUP BY status""", + """SELECT category, COUNT(*) FROM products GROUP BY category""", + """SELECT AVG(total_amount) FROM orders + WHERE tenant_id = {t} AND status = 3 AND deleted_at IS NULL""", + """SELECT COUNT(*) FROM customers + WHERE tenant_id = {t} AND status = 1 AND deleted_at IS NULL""", + """SELECT method, SUM(amount) FROM payments GROUP BY method""", + # --- view usage --- + "SELECT * FROM v_monthly_revenue WHERE tenant_id = {t} ORDER BY month DESC LIMIT 12", + "SELECT COUNT(*) FROM v_active_customers WHERE tenant_id = {t}", + """SELECT * FROM v_customer_order_counts WHERE tenant_id = {t} + ORDER BY lifetime_value DESC LIMIT 10""", +] + + +async def run(dsn: str, rounds: int) -> None: + conn = await asyncpg.connect(dsn) + rng = random.Random(42) + executed = 0 + try: + for _ in range(rounds): + for shape in QUERY_SHAPES: + sql = shape.format( + t=rng.randint(1, 5), + cs=rng.randint(1, 4), + os=rng.randint(1, 4), + i=rng.randint(1, 200), + ) + await conn.fetch(sql) + executed += 1 + await conn.execute("ANALYZE") + tracked = await conn.fetchval( + "SELECT count(*) FROM pg_stat_statements WHERE dbid = " + "(SELECT oid FROM pg_database WHERE datname = current_database())" + ) + print(f"Executed {executed} queries; pg_stat_statements now tracks {tracked} statements.") + finally: + await conn.close() + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dsn", default=DEFAULT_DSN) + parser.add_argument("--rounds", type=int, default=20) + args = parser.parse_args() + asyncio.run(run(args.dsn, args.rounds)) + + +if __name__ == "__main__": + main() diff --git a/backend/tests/fixtures/ops_extensions.sql b/backend/tests/fixtures/ops_extensions.sql new file mode 100644 index 0000000..7132fc2 --- /dev/null +++ b/backend/tests/fixtures/ops_extensions.sql @@ -0,0 +1,3 @@ +-- Runs first (mounted as 05_extensions.sql). Requires the container to start +-- postgres with -c shared_preload_libraries=pg_stat_statements (see docker-compose.yml). +CREATE EXTENSION IF NOT EXISTS pg_stat_statements; diff --git a/backend/tests/fixtures/ops_seed.sql b/backend/tests/fixtures/ops_seed.sql new file mode 100644 index 0000000..e123d41 --- /dev/null +++ b/backend/tests/fixtures/ops_seed.sql @@ -0,0 +1,246 @@ +-- ============================================================================ +-- "opsdb" — a deliberately hostile operational-style database. +-- +-- Exercises the semantic layer compiler against the pathologies that real +-- operational schemas exhibit and warehouses don't: +-- * NO foreign keys declared anywhere (joins must be inferred) +-- * tenant_id scoping column on most tables (row-filter policy signal) +-- * soft deletes via deleted_at (canonical-filter signal) +-- * int-coded status columns + id/code/label lookup tables (dictionary signal) +-- * PII columns with realistic value shapes (masking-policy signal) +-- * handwritten views encoding business logic (view -> metric signal) +-- * an append-only audit/event table and a dead *_bak table +-- +-- Runs in the sample-db container after the IFRS 9 seed (mounted as +-- 20_ops_seed.sql). Creates a separate database so sampledb stays pristine. +-- ============================================================================ + +CREATE DATABASE opsdb OWNER sample; +\connect opsdb + +CREATE EXTENSION IF NOT EXISTS pg_stat_statements; + +SELECT setseed(0.42); + +-- --------------------------------------------------------------------------- +-- Tables (no foreign keys, on purpose) +-- --------------------------------------------------------------------------- + +CREATE TABLE tenants ( + id BIGINT PRIMARY KEY, + name TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE customer_statuses ( + id INT PRIMARY KEY, + code TEXT NOT NULL, + label TEXT NOT NULL +); + +CREATE TABLE customers ( + id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL, + full_name TEXT NOT NULL, + email TEXT NOT NULL, + phone TEXT, + date_of_birth DATE, + national_id TEXT, + status INT NOT NULL, + deleted_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE products ( + id BIGINT PRIMARY KEY, + sku TEXT NOT NULL, + name TEXT NOT NULL, + category TEXT NOT NULL, + unit_price NUMERIC(10, 2) NOT NULL, + discontinued_at TIMESTAMPTZ +); + +CREATE TABLE order_statuses ( + id INT PRIMARY KEY, + code TEXT NOT NULL, + label TEXT NOT NULL +); + +CREATE TABLE orders ( + id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL, + customer_id BIGINT NOT NULL, + status INT NOT NULL, + order_date DATE NOT NULL, + total_amount NUMERIC(12, 2) NOT NULL, + deleted_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE order_items ( + id BIGINT PRIMARY KEY, + order_id BIGINT NOT NULL, + product_id BIGINT NOT NULL, + quantity INT NOT NULL, + unit_price NUMERIC(10, 2) NOT NULL +); + +CREATE TABLE payments ( + id BIGINT PRIMARY KEY, + order_id BIGINT NOT NULL, + amount NUMERIC(12, 2) NOT NULL, + method INT NOT NULL, + paid_at TIMESTAMPTZ NOT NULL +); + +-- Append-only audit table: high churn, polymorphic refs, never joined in views. +CREATE TABLE events ( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT NOT NULL, + entity_type TEXT NOT NULL, + entity_id BIGINT NOT NULL, + event_type TEXT NOT NULL, + payload JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Dead table: schema copy of customers, zero rows. +CREATE TABLE customers_bak ( + id BIGINT, + tenant_id BIGINT, + full_name TEXT, + email TEXT, + phone TEXT, + date_of_birth DATE, + national_id TEXT, + status INT, + deleted_at TIMESTAMPTZ, + created_at TIMESTAMPTZ +); + +-- --------------------------------------------------------------------------- +-- Seed data +-- --------------------------------------------------------------------------- + +INSERT INTO tenants (id, name) +SELECT i, 'Tenant ' || chr(64 + i::int) +FROM generate_series(1, 5) AS i; + +INSERT INTO customer_statuses (id, code, label) VALUES + (1, 'ACTIVE', 'Active'), + (2, 'INACTIVE', 'Inactive'), + (3, 'SUSPENDED', 'Suspended'), + (4, 'CLOSED', 'Closed'); + +INSERT INTO order_statuses (id, code, label) VALUES + (1, 'PENDING', 'Pending'), + (2, 'PAID', 'Paid'), + (3, 'COMPLETED', 'Completed'), + (4, 'CANCELLED', 'Cancelled'); + +INSERT INTO customers (id, tenant_id, full_name, email, phone, date_of_birth, + national_id, status, deleted_at, created_at) +SELECT + i, + 1 + (i % 5), + 'Customer ' || i, + 'customer' || i || '@example.com', + '+1-555-' || lpad((1000 + i)::text, 4, '0'), + DATE '1955-01-01' + (random() * 15000)::int, + lpad((100000000 + i * 37)::text, 9, '0'), + CASE WHEN random() < 0.70 THEN 1 + WHEN random() < 0.55 THEN 2 + WHEN random() < 0.50 THEN 3 + ELSE 4 END, + CASE WHEN random() < 0.10 THEN now() - (random() * 200 || ' days')::interval END, + now() - (random() * 900 || ' days')::interval +FROM generate_series(1, 200) AS i; + +INSERT INTO products (id, sku, name, category, unit_price, discontinued_at) +SELECT + i, + 'SKU-' || lpad(i::text, 5, '0'), + 'Product ' || i, + (ARRAY['electronics', 'apparel', 'home', 'sports', 'grocery'])[1 + (i % 5)], + round((5 + random() * 495)::numeric, 2), + CASE WHEN random() < 0.08 THEN now() - (random() * 400 || ' days')::interval END +FROM generate_series(1, 40) AS i; + +INSERT INTO orders (id, tenant_id, customer_id, status, order_date, total_amount, + deleted_at, created_at) +SELECT + i, + 1 + (i % 5), + 1 + (i * 7) % 200, + CASE WHEN random() < 0.10 THEN 1 + WHEN random() < 0.25 THEN 2 + WHEN random() < 0.85 THEN 3 + ELSE 4 END, + DATE '2024-01-01' + (random() * 520)::int, + round((10 + random() * 1990)::numeric, 2), + CASE WHEN random() < 0.03 THEN now() - (random() * 100 || ' days')::interval END, + now() - (random() * 500 || ' days')::interval +FROM generate_series(1, 2000) AS i; + +INSERT INTO order_items (id, order_id, product_id, quantity, unit_price) +SELECT + i, + 1 + (i % 2000), + 1 + (i * 13) % 40, + 1 + (random() * 4)::int, + round((5 + random() * 495)::numeric, 2) +FROM generate_series(1, 6000) AS i; + +INSERT INTO payments (id, order_id, amount, method, paid_at) +SELECT + o.id, + o.id, + o.total_amount, + CASE WHEN random() < 0.55 THEN 1 WHEN random() < 0.75 THEN 2 ELSE 3 END, + o.order_date::timestamptz + interval '1 day' +FROM orders o +WHERE o.status IN (2, 3); + +INSERT INTO events (tenant_id, entity_type, entity_id, event_type, payload, created_at) +SELECT + 1 + (i % 5), + CASE WHEN i % 3 = 0 THEN 'customer' ELSE 'order' END, + 1 + (i % 2000), + (ARRAY['created', 'updated', 'status_changed', 'deleted'])[1 + (i % 4)], + jsonb_build_object('source', 'ops', 'seq', i), + now() - (random() * 300 || ' days')::interval +FROM generate_series(1, 3000) AS i; + +-- --------------------------------------------------------------------------- +-- Views: crystallized business logic (the compiler's richest evidence) +-- --------------------------------------------------------------------------- + +CREATE VIEW v_active_customers AS +SELECT id, tenant_id, full_name, email, status, created_at +FROM customers +WHERE deleted_at IS NULL AND status = 1; + +CREATE VIEW v_monthly_revenue AS +SELECT + tenant_id, + date_trunc('month', order_date) AS month, + SUM(total_amount) AS revenue, + COUNT(*) AS order_count +FROM orders +WHERE deleted_at IS NULL AND status = 3 +GROUP BY tenant_id, date_trunc('month', order_date); + +CREATE VIEW v_customer_order_counts AS +SELECT + c.id AS customer_id, + c.tenant_id, + c.full_name, + COUNT(o.id) AS order_count, + SUM(o.total_amount) AS lifetime_value +FROM customers c +LEFT JOIN orders o ON o.customer_id = c.id AND o.deleted_at IS NULL +WHERE c.deleted_at IS NULL +GROUP BY c.id, c.tenant_id, c.full_name; + +-- Populate pg_stats (most_common_vals etc.) — collectors are blind without this. +ANALYZE; diff --git a/backend/tests/test_compilation_service.py b/backend/tests/test_compilation_service.py new file mode 100644 index 0000000..c88d0e2 --- /dev/null +++ b/backend/tests/test_compilation_service.py @@ -0,0 +1,98 @@ +"""Annotator agent + service-level annotation merge (no DB, no LLM).""" + +from types import SimpleNamespace + +from app.llm.agents.semantic_annotator import SemanticAnnotatorAgent +from app.llm.base_provider import LLMConfig +from app.semantic_compiler.types import Evidence, Finding +from app.services import compilation_service + + +class _FakeProvider: + def __init__(self, content: str): + self._content = content + self.calls = 0 + + async def complete(self, messages, config): + self.calls += 1 + return SimpleNamespace(content=self._content) + + +class _ExplodingProvider: + async def complete(self, messages, config): + raise RuntimeError("provider down") + + +def _agent(content: str) -> SemanticAnnotatorAgent: + return SemanticAnnotatorAgent(_FakeProvider(content), LLMConfig(model="fake")) + + +_FINDINGS = [ + {"title": "Metric from v_monthly_revenue", "payload": {"metric_name": "x"}, "evidence": []}, + {"title": "Recurring aggregate", "payload": {"metric_name": "y"}, "evidence": []}, +] + + +async def test_annotate_merges_allowed_fields(): + agent = _agent( + '{"annotations": [{"index": 0, "metric_name": "Monthly Revenue!", ' + '"display_name": "Monthly Revenue", "description": "Total completed-order revenue."}]}' + ) + result = await agent.annotate("metric", _FINDINGS) + assert result[0]["metric_name"] == "monthly_revenue" # sanitized to identifier + assert result[0]["display_name"] == "Monthly Revenue" + + +async def test_annotate_rejects_unknown_indices_and_fields(): + agent = _agent( + '{"annotations": [' + '{"index": 7, "description": "phantom finding"},' + '{"index": 1, "sql_expression": "SUM(invented)", "description": "ok"}]}' + ) + result = await agent.annotate("metric", _FINDINGS) + assert 7 not in result + assert "sql_expression" not in result[1] # structural field dropped + assert result[1]["description"] == "ok" + + +async def test_annotate_swallows_provider_failure(): + agent = SemanticAnnotatorAgent(_ExplodingProvider(), LLMConfig(model="fake")) + assert await agent.annotate("metric", _FINDINGS) == {} + + +async def test_annotate_unknown_kind_is_noop(): + agent = _agent('{"annotations": []}') + assert await agent.annotate("not_a_kind", _FINDINGS) == {} + + +async def test_service_annotation_merge(monkeypatch): + findings = [ + Finding( + kind="metric", + title="Metric from view", + payload={"metric_name": "raw_name", "sql_expression": "SUM(orders.total_amount)"}, + evidence=[Evidence("view", "from v_monthly_revenue")], + confidence=0.8, + ), + Finding(kind="fanout_warning", title="Fan-out", payload={"guidance": "g"}, confidence=0.7), + ] + provider = _FakeProvider( + '{"annotations": [{"index": 0, "metric_name": "monthly_revenue", ' + '"display_name": "Monthly Revenue", "description": "desc"}]}' + ) + monkeypatch.setattr("app.llm.provider_registry.get_provider", lambda *a, **k: provider) + result = await compilation_service._annotate(findings) + assert result[0].payload["metric_name"] == "monthly_revenue" + assert result[0].payload["display_name"] == "Monthly Revenue" + # structure untouched + assert result[0].payload["sql_expression"] == "SUM(orders.total_amount)" + + +async def test_service_annotation_survives_missing_provider(monkeypatch): + def _raise(*a, **k): + raise ValueError("no api key") + + monkeypatch.setattr("app.llm.provider_registry.get_provider", _raise) + findings = [Finding(kind="metric", title="t", payload={"metric_name": "m"}, confidence=0.6)] + result = await compilation_service._annotate(findings) + assert result[0].payload["metric_name"] == "m" # unchanged, no crash diff --git a/backend/tests/test_compiler_inference.py b/backend/tests/test_compiler_inference.py new file mode 100644 index 0000000..da29704 --- /dev/null +++ b/backend/tests/test_compiler_inference.py @@ -0,0 +1,352 @@ +"""Deterministic inference modules of the semantic layer compiler. + +Pure unit tests: hand-built TableProfiles + a fake prober, no DB, no LLM. +""" + +from app.semantic_compiler.inference import ( + infer_dead_tables, + infer_dictionaries, + infer_fanout_warnings, + infer_glossary_entities, + infer_joins, + infer_pii, + infer_tenant_scope, + infer_view_metrics, +) +from app.semantic_compiler.sqlmeta import JoinPair, SqlAnalysis +from app.semantic_compiler.types import ( + KIND_RELATIONSHIP, + ColumnProfile, + Finding, + TableProfile, + Thresholds, + ViewDef, +) + + +class FakeProber: + """Canned answers: overlap probes, lookup labels, sample values.""" + + def __init__(self, overlap: float = 1.0, lookup_rows=None, samples=None): + self.overlap = overlap + self.lookup_rows = lookup_rows or [] + self.samples = samples or {} + self.queries: list[str] = [] + + async def query(self, sql: str, max_rows: int = 1000): + self.queries.append(sql) + if "overlap" in sql: + return [{"overlap": self.overlap}] + return self.lookup_rows + + async def sample_values(self, schema, table, column, limit=20): + return self.samples.get((table, column), []) + + +def col(name, data_type="bigint", pk=False, unique=False, **kwargs): + return ColumnProfile( + name=name, data_type=data_type, is_primary_key=pk, is_unique=unique, **kwargs + ) + + +def make_tables() -> list[TableProfile]: + customers = TableProfile( + schema_name="public", + table_name="customers", + row_count_estimate=200, + columns=[ + col("id", pk=True), + col("tenant_id"), + col("email", "text"), + col("status", "integer", n_distinct=4.0), + ], + ) + orders = TableProfile( + schema_name="public", + table_name="orders", + row_count_estimate=2000, + columns=[ + col("id", pk=True), + col("tenant_id"), + col("customer_id"), + col("status", "integer", n_distinct=4.0), + col("total_amount", "numeric"), + ], + ) + order_statuses = TableProfile( + schema_name="public", + table_name="order_statuses", + row_count_estimate=4, + columns=[col("id", "integer", pk=True), col("code", "text"), col("label", "text")], + ) + customers_bak = TableProfile( + schema_name="public", + table_name="customers_bak", + row_count_estimate=0, + columns=[col("id"), col("email", "text")], + ) + return [customers, orders, order_statuses, customers_bak] + + +# --- joins ------------------------------------------------------------------ + + +async def test_join_inference_naming_plus_overlap(): + tables = make_tables() + findings = await infer_joins(tables, [], FakeProber(overlap=0.99), Thresholds()) + by_key = { + ( + f.payload["source_table"], + f.payload["source_column"], + f.payload["target_table"], + ): f + for f in findings + } + edge = by_key[("orders", "customer_id", "customers")] + assert edge.payload["target_column"] == "id" + assert edge.payload["cardinality"] == "N:1" + assert edge.confidence >= 0.75 # naming 0.45 + overlap 0.35 + sources = {e.source for e in edge.evidence} + assert {"naming", "value_overlap"} <= sources + + +async def test_join_inference_lookup_table_pattern(): + tables = make_tables() + findings = await infer_joins(tables, [], FakeProber(overlap=0.99), Thresholds()) + keys = { + (f.payload["source_table"], f.payload["source_column"], f.payload["target_table"]) + for f in findings + } + assert ("orders", "status", "order_statuses") in keys + + +async def test_join_inference_failed_probe_kills_candidate(): + tables = make_tables() + findings = await infer_joins(tables, [], FakeProber(overlap=0.1), Thresholds()) + assert findings == [] + + +async def test_join_inference_log_co_occurrence_boost(): + tables = make_tables() + analysis = SqlAnalysis( + tables=["orders", "customers"], + join_pairs=[JoinPair("orders", "customer_id", "customers", "id")], + ) + findings = await infer_joins( + tables, [(analysis, 50, "query log")], FakeProber(overlap=0.99), Thresholds() + ) + edge = next(f for f in findings if f.payload["source_column"] == "customer_id") + assert edge.confidence >= 0.9 # naming + overlap + logs + assert any(e.source == "query_logs" for e in edge.evidence) + + +async def test_join_inference_skips_declared_fks(): + tables = make_tables() + from app.semantic_compiler.types import DeclaredFK + + tables[1].declared_fks.append(DeclaredFK("customer_id", "public", "customers", "id")) + findings = await infer_joins(tables, [], FakeProber(overlap=0.99), Thresholds()) + keys = {(f.payload["source_table"], f.payload["source_column"]) for f in findings} + assert ("orders", "customer_id") not in keys + + +# --- dictionaries ----------------------------------------------------------- + + +async def test_dictionary_from_check_constraint(): + tables = make_tables() + tables[0].column("status").check_in_values = ["1", "2", "3", "4"] + findings = await infer_dictionaries(tables, [], FakeProber()) + finding = next( + f for f in findings if f.payload["table"] == "customers" and f.payload["column"] == "status" + ) + assert [e["raw_value"] for e in finding.payload["entries"]] == ["1", "2", "3", "4"] + assert finding.confidence >= 0.8 + + +async def test_dictionary_from_lookup_table_labels(): + tables = make_tables() + rel = Finding( + kind=KIND_RELATIONSHIP, + title="orders.status -> order_statuses.id", + payload={ + "source_table": "orders", + "source_column": "status", + "target_table": "order_statuses", + "target_column": "id", + }, + confidence=0.8, + ) + prober = FakeProber( + lookup_rows=[ + {"raw": 1, "display": "Pending"}, + {"raw": 2, "display": "Paid"}, + ] + ) + findings = await infer_dictionaries(tables, [rel], prober) + finding = next( + f for f in findings if f.payload["table"] == "orders" and f.payload["column"] == "status" + ) + assert finding.payload["entries"][0]["display_value"] == "Pending" + assert finding.confidence >= 0.8 + + +async def test_dictionary_from_most_common_vals(): + tables = make_tables() + tables[0].columns.append( + col( + "segment", + "character varying", + n_distinct=3.0, + most_common_vals=["retail", "corporate", "sme"], + ) + ) + findings = await infer_dictionaries(tables, [], FakeProber()) + finding = next(f for f in findings if f.payload["column"] == "segment") + assert {e["raw_value"] for e in finding.payload["entries"]} == {"retail", "corporate", "sme"} + + +# --- view metrics ----------------------------------------------------------- + + +def test_view_metrics_from_aggregates(): + view = ViewDef("public", "v_monthly_revenue", "unused") + analysis = SqlAnalysis( + tables=["orders"], + aggregates=[], + group_by=["orders.tenant_id"], + where_sql="orders.status = 3", + ) + from app.semantic_compiler.sqlmeta import AggregateItem + + analysis.aggregates = [ + AggregateItem( + sql="SUM(orders.total_amount)", + function="sum", + column="orders.total_amount", + alias="revenue", + ) + ] + findings = infer_view_metrics([(view, analysis)]) + assert len(findings) == 1 + payload = findings[0].payload + assert payload["metric_name"] == "monthly_revenue_revenue" + assert payload["sql_expression"] == "SUM(orders.total_amount)" + assert payload["aggregation_type"] == "sum" + assert payload["dimensions"] == ["tenant_id"] + assert payload["filters"] == {"where": "orders.status = 3"} + assert findings[0].confidence >= 0.75 + + +# --- refusal boundaries ----------------------------------------------------- + + +def test_dead_table_detection(): + findings = infer_dead_tables(make_tables(), {"orders", "customers"}, logs_available=True) + assert len(findings) == 1 + finding = findings[0] + assert finding.payload["table"] == "customers_bak" + assert finding.confidence >= 0.9 # suffix + zero rows + never queried + + +def test_tenant_scope_detection(): + analyses = [ + (SqlAnalysis(tables=["orders"], where_columns=["orders.tenant_id"]), 10), + (SqlAnalysis(tables=["customers"], where_columns=["customers.tenant_id"]), 5), + (SqlAnalysis(tables=["orders"], where_columns=["orders.status"]), 1), + ] + # entity tables = customers, orders, products (lookup + _bak excluded); + # tenant_id on 2 of 3 — log confirmation pushes it over the threshold. + tables = make_tables() + tables.append( + TableProfile( + schema_name="public", + table_name="products", + row_count_estimate=40, + columns=[col("id", pk=True), col("sku", "text"), col("unit_price", "numeric")], + ) + ) + findings = infer_tenant_scope(tables, analyses) + finding = next(f for f in findings if f.payload["column"] == "tenant_id") + assert "orders" in finding.payload["row_filters"] + assert finding.confidence >= 0.7 # presence + known name + log fraction + + +def test_tenant_scope_needs_log_confirmation(): + tables = make_tables() + tables.append( + TableProfile( + schema_name="public", + table_name="products", + row_count_estimate=40, + columns=[col("id", pk=True), col("sku", "text"), col("unit_price", "numeric")], + ) + ) + findings = infer_tenant_scope(tables, []) # no query logs + finding = next(f for f in findings if f.payload["column"] == "tenant_id") + assert finding.confidence < 0.5 # stays below the default emit threshold + + +async def test_pii_name_and_value_signals(): + tables = make_tables() + prober = FakeProber(samples={("customers", "email"): ["a@b.com", "c@d.org", None]}) + findings = await infer_pii(tables, prober) + email = next( + f for f in findings if f.payload["column"] == "email" and f.payload["table"] == "customers" + ) + assert email.payload["category"] == "email" + assert email.confidence >= 0.85 + + +def test_fanout_warning_from_n1_edge(): + tables = make_tables() + rel = Finding( + kind=KIND_RELATIONSHIP, + title="orders.customer_id -> customers.id", + payload={ + "source_table": "orders", + "source_column": "customer_id", + "target_table": "customers", + "target_column": "id", + "cardinality": "N:1", + }, + confidence=0.9, + ) + # parent customers has no measure columns → no warning; orders as parent does + rel2 = Finding( + kind=KIND_RELATIONSHIP, + title="order_items.order_id -> orders.id", + payload={ + "source_table": "order_items", + "source_column": "order_id", + "target_table": "orders", + "target_column": "id", + "cardinality": "N:1", + }, + confidence=0.9, + ) + findings = infer_fanout_warnings(tables, [rel, rel2]) + warning = next(f for f in findings if f.payload["parent_table"] == "orders") + assert "total_amount" in warning.payload["risky_columns"] + assert "double-count" in warning.payload["guidance"] + + +def test_glossary_hub_entities(): + tables = make_tables() + rel = Finding( + kind=KIND_RELATIONSHIP, + title="orders.customer_id -> customers.id", + payload={ + "source_table": "orders", + "source_column": "customer_id", + "target_table": "customers", + "target_column": "id", + }, + confidence=0.9, + ) + findings = infer_glossary_entities(tables, [rel], dead_table_names={"customers_bak"}) + terms = {f.payload["term"] for f in findings} + assert "Customer" in terms + assert all("customers_bak" not in f.payload["related_tables"] for f in findings) + customer = next(f for f in findings if f.payload["term"] == "Customer") + assert "orders.customer_id" in customer.payload["related_columns"] diff --git a/backend/tests/test_compiler_sqlmeta.py b/backend/tests/test_compiler_sqlmeta.py new file mode 100644 index 0000000..308be52 --- /dev/null +++ b/backend/tests/test_compiler_sqlmeta.py @@ -0,0 +1,62 @@ +"""sqlmeta: SQL analysis for the semantic layer compiler.""" + +import pytest + +pytest.importorskip("sqlglot") + +from app.semantic_compiler.sqlmeta import analyze # noqa: E402 + +VIEW_SQL = """ +SELECT + o.tenant_id, + date_trunc('month', o.order_date) AS month, + SUM(o.total_amount) AS revenue, + COUNT(*) AS order_count +FROM orders o +JOIN customers c ON o.customer_id = c.id +WHERE o.deleted_at IS NULL AND o.status = 3 +GROUP BY o.tenant_id, date_trunc('month', o.order_date) +""" + + +def test_analyze_extracts_aggregates_and_dimensions(): + analysis = analyze(VIEW_SQL, dialect="postgres") + assert analysis is not None + assert sorted(analysis.tables) == ["customers", "orders"] + + functions = {a.function for a in analysis.aggregates} + assert functions == {"sum", "count"} + sum_agg = next(a for a in analysis.aggregates if a.function == "sum") + assert sum_agg.column == "orders.total_amount" + assert sum_agg.alias == "revenue" + + assert "orders.tenant_id" in analysis.group_by + + +def test_analyze_extracts_join_pairs_resolving_aliases(): + analysis = analyze(VIEW_SQL, dialect="postgres") + assert analysis is not None + assert len(analysis.join_pairs) == 1 + pair = analysis.join_pairs[0] + assert pair.key() == (("customers", "id"), ("orders", "customer_id")) + + +def test_analyze_where_columns_and_text(): + analysis = analyze(VIEW_SQL, dialect="postgres") + assert analysis is not None + assert "orders.deleted_at" in analysis.where_columns + assert "orders.status" in analysis.where_columns + assert analysis.where_sql is not None + assert "deleted_at" in analysis.where_sql.lower() + + +def test_analyze_unqualified_columns_single_table(): + analysis = analyze("SELECT SUM(total_amount) FROM orders WHERE tenant_id = 1") + assert analysis is not None + assert analysis.aggregates[0].column == "orders.total_amount" + assert analysis.where_columns == ["orders.tenant_id"] + + +def test_analyze_degrades_on_garbage(): + assert analyze("") is None + assert analyze("THIS IS NOT ((( SQL") is None diff --git a/docker-compose.yml b/docker-compose.yml index 6edef29..af8ff9d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,9 +15,14 @@ services: timeout: 5s retries: 5 - # Sample target database for development/testing + # Sample target database for development/testing. + # Hosts two databases: sampledb (IFRS 9, curated) and opsdb (hostile + # operational-style schema for the semantic layer compiler). + # pg_stat_statements feeds the compiler's query-log collector. + # NOTE: init scripts only run on a fresh volume (docker compose down -v). sample-db: image: postgres:16 + command: postgres -c shared_preload_libraries=pg_stat_statements -c pg_stat_statements.track=all environment: POSTGRES_DB: sampledb POSTGRES_USER: sample @@ -26,7 +31,9 @@ services: - "5433:5432" volumes: - sample_db_data:/var/lib/postgresql/data - - ./backend/tests/fixtures/sample_seed.sql:/docker-entrypoint-initdb.d/seed.sql + - ./backend/tests/fixtures/ops_extensions.sql:/docker-entrypoint-initdb.d/05_extensions.sql + - ./backend/tests/fixtures/sample_seed.sql:/docker-entrypoint-initdb.d/10_seed.sql + - ./backend/tests/fixtures/ops_seed.sql:/docker-entrypoint-initdb.d/20_ops_seed.sql healthcheck: test: ["CMD-SHELL", "pg_isready -U sample -d sampledb"] interval: 5s diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 4989f4e..b46d6f1 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -13,6 +13,7 @@ import { DashboardDetailPage } from './pages/DashboardDetailPage'; import { DictionaryPage } from './pages/DictionaryPage'; import { KnowledgePage } from './pages/KnowledgePage'; import { CatalogPage } from './pages/CatalogPage'; +import { CompilerPage } from './pages/CompilerPage'; import { HistoryPage } from './pages/HistoryPage'; import { AuditPage } from './pages/AuditPage'; import { SchedulesPage } from './pages/SchedulesPage'; @@ -40,6 +41,7 @@ export default function App() { } /> } /> } /> + } /> } /> } /> } /> diff --git a/frontend/src/api/compilationApi.ts b/frontend/src/api/compilationApi.ts new file mode 100644 index 0000000..7be0896 --- /dev/null +++ b/frontend/src/api/compilationApi.ts @@ -0,0 +1,44 @@ +import { api } from './client'; +import type { CompilationFinding, CompilationRun } from '../types/api'; + +export interface StartRunOptions { + llm_enabled?: boolean; + min_confidence?: number; + ignore_declared_fks?: boolean; +} + +export const compilationApi = { + startRun: (connectionId: string, options: StartRunOptions = {}) => + api + .post(`/connections/${connectionId}/compilation/runs`, options) + .then(r => r.data), + listRuns: (connectionId: string) => + api.get(`/connections/${connectionId}/compilation/runs`).then(r => r.data), + getRun: (connectionId: string, runId: string) => + api + .get(`/connections/${connectionId}/compilation/runs/${runId}`) + .then(r => r.data), + listFindings: (connectionId: string, params?: { status?: string; kind?: string }) => + api + .get(`/connections/${connectionId}/compilation/findings`, { params }) + .then(r => r.data), + accept: (connectionId: string, findingId: string) => + api + .post( + `/connections/${connectionId}/compilation/findings/${findingId}/accept`, + ) + .then(r => r.data), + dismiss: (connectionId: string, findingId: string) => + api + .post( + `/connections/${connectionId}/compilation/findings/${findingId}/dismiss`, + ) + .then(r => r.data), + bulk: (connectionId: string, findingIds: string[], action: 'accept' | 'dismiss') => + api + .post<{ succeeded: number; failed: number }>( + `/connections/${connectionId}/compilation/findings/bulk`, + { finding_ids: findingIds, action }, + ) + .then(r => r.data), +}; diff --git a/frontend/src/components/compiler/FindingCard.tsx b/frontend/src/components/compiler/FindingCard.tsx new file mode 100644 index 0000000..be79ed4 --- /dev/null +++ b/frontend/src/components/compiler/FindingCard.tsx @@ -0,0 +1,233 @@ +import { useState } from 'react'; +import { + ActionIcon, + Badge, + Button, + Card, + Code, + Collapse, + Group, + Progress, + Stack, + Table, + Text, + Tooltip, +} from '@mantine/core'; +import { IconCheck, IconChevronDown, IconChevronRight, IconX } from '@tabler/icons-react'; +import type { CompilationFinding } from '../../types/api'; + +const SOURCE_LABEL: Record = { + naming: 'Naming', + value_overlap: 'Data probe', + query_logs: 'Query logs', + pg_stats: 'Statistics', + constraint: 'Constraint', + view: 'View', + heuristic: 'Heuristic', +}; + +function confidenceColor(confidence: number): string { + if (confidence >= 0.8) return 'green'; + if (confidence >= 0.6) return 'yellow'; + return 'orange'; +} + +interface DictionaryEntryPayload { + raw_value: string; + display_value: string; +} + +interface FindingPayload { + // metric + sql_expression?: string; + description?: string; + dimensions?: string[]; + filters?: { where?: string }; + // glossary + definition?: string; + // dictionary + entries?: DictionaryEntryPayload[]; + // relationship + source_table?: string; + source_column?: string; + target_table?: string; + target_column?: string; + cardinality?: string | null; + // policies / dead table / fanout + column?: string; + tables?: string[]; + masked_column?: string; + category?: string; + table?: string; + guidance?: string; +} + +function PayloadSummary({ finding }: { finding: CompilationFinding }) { + const p = finding.payload as FindingPayload; + switch (finding.kind) { + case 'metric': + return ( + + {String(p.sql_expression ?? '')} + {p.description ? {String(p.description)} : null} + + {p.dimensions?.length ? `Dimensions: ${p.dimensions.join(', ')}` : ''} + {p.filters?.where ? ` · Filter: ${String(p.filters.where)}` : ''} + + + ); + case 'glossary': + return {String(p.definition ?? '')}; + case 'dictionary': + return ( + + + {p.entries?.slice(0, 8).map(e => ( + + + {String(e.raw_value)} + + {String(e.display_value)} + + ))} + +
+ ); + case 'relationship': + return ( + + + {String(p.source_table)}.{String(p.source_column)} + {' '} + →{' '} + + {String(p.target_table)}.{String(p.target_column)} + + {p.cardinality ? ( + + {String(p.cardinality)} + + ) : null} + + ); + case 'data_policy_row_filter': + return ( + + + Scoping column {String(p.column)} on {p.tables?.length ?? 0} tables. + Accepting creates a disabled row-filter policy — edit the{' '} + :tenant_id placeholder before enabling. + + + ); + case 'data_policy_masking': + return ( + + Mask {String(p.masked_column)} ({String(p.category)}). Merged into the + “Compiler: PII masking” policy (created disabled). + + ); + case 'dead_table': + return ( + + Block {String(p.table)} from queries (merged into the “Compiler: dead + tables” policy, created disabled). + + ); + case 'fanout_warning': + return {String(p.guidance ?? '')}; + default: + return {JSON.stringify(p, null, 2)}; + } +} + +interface FindingCardProps { + finding: CompilationFinding; + onAccept: (id: string) => void; + onDismiss: (id: string) => void; + busy?: boolean; +} + +export function FindingCard({ finding, onAccept, onDismiss, busy }: FindingCardProps) { + const [showEvidence, setShowEvidence] = useState(false); + const reviewed = finding.status !== 'proposed'; + + return ( + + + + + + {finding.title} + + {reviewed && ( + + {finding.status} + + )} + + + + + + + + + + + {finding.evidence.map((e, i) => ( + + + {SOURCE_LABEL[e.source] ?? e.source} + + + {e.detail} + + + ))} + + + + {!reviewed && ( + + + onAccept(finding.id)} + > + + + + + onDismiss(finding.id)} + > + + + + + )} + + + ); +} diff --git a/frontend/src/components/layout/AppLayout.tsx b/frontend/src/components/layout/AppLayout.tsx index 82d21d3..f4b89e9 100644 --- a/frontend/src/components/layout/AppLayout.tsx +++ b/frontend/src/components/layout/AppLayout.tsx @@ -26,6 +26,7 @@ import { IconClockHour4, IconLockCog, IconChartHistogram, + IconSparkles, } from '@tabler/icons-react'; import { Outlet, useLocation, useNavigate } from 'react-router-dom'; import { EmbeddingStatusBanner } from '../common/EmbeddingStatusBanner'; @@ -42,6 +43,7 @@ const NAV_ITEMS = [ { label: 'Dictionary', path: '/dictionary', icon: IconVocabulary }, { label: 'Knowledge', path: '/knowledge', icon: IconFileText }, { label: 'Catalog', path: '/catalog', icon: IconBook2 }, + { label: 'Compiler', path: '/compiler', icon: IconSparkles }, { label: 'Schedules', path: '/schedules', icon: IconClockHour4 }, { label: 'History', path: '/history', icon: IconHistory }, { label: 'Usage & Cost', path: '/analytics', icon: IconChartHistogram, adminOnly: true }, diff --git a/frontend/src/hooks/useCompilation.ts b/frontend/src/hooks/useCompilation.ts new file mode 100644 index 0000000..a8610f1 --- /dev/null +++ b/frontend/src/hooks/useCompilation.ts @@ -0,0 +1,69 @@ +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { compilationApi, type StartRunOptions } from '../api/compilationApi'; + +export function useCompilationRuns(connectionId: string | undefined) { + return useQuery({ + queryKey: ['compilation-runs', connectionId], + queryFn: () => compilationApi.listRuns(connectionId!), + enabled: !!connectionId, + // Poll while a run is queued/running (mirrors useEmbeddingStatus). + refetchInterval: query => { + const data = query.state.data; + if (!data) return false; + const active = data.some(r => r.status === 'queued' || r.status === 'running'); + return active ? 2000 : false; + }, + }); +} + +export function useCompilationFindings( + connectionId: string | undefined, + filters: { status?: string; kind?: string } = {}, +) { + return useQuery({ + queryKey: ['compilation-findings', connectionId, filters], + queryFn: () => compilationApi.listFindings(connectionId!, filters), + enabled: !!connectionId, + }); +} + +function useInvalidate(connectionId: string) { + const qc = useQueryClient(); + return () => { + qc.invalidateQueries({ queryKey: ['compilation-findings', connectionId] }); + qc.invalidateQueries({ queryKey: ['compilation-runs', connectionId] }); + }; +} + +export function useStartCompilation(connectionId: string) { + const invalidate = useInvalidate(connectionId); + return useMutation({ + mutationFn: (options: StartRunOptions) => compilationApi.startRun(connectionId, options), + onSuccess: invalidate, + }); +} + +export function useAcceptFinding(connectionId: string) { + const invalidate = useInvalidate(connectionId); + return useMutation({ + mutationFn: (findingId: string) => compilationApi.accept(connectionId, findingId), + onSuccess: invalidate, + }); +} + +export function useDismissFinding(connectionId: string) { + const invalidate = useInvalidate(connectionId); + return useMutation({ + mutationFn: (findingId: string) => compilationApi.dismiss(connectionId, findingId), + onSuccess: invalidate, + }); +} + +export function useBulkReview(connectionId: string) { + const invalidate = useInvalidate(connectionId); + return useMutation({ + mutationFn: ({ ids, action }: { ids: string[]; action: 'accept' | 'dismiss' }) => + compilationApi.bulk(connectionId, ids, action), + onSuccess: invalidate, + }); +} diff --git a/frontend/src/pages/CompilerPage.tsx b/frontend/src/pages/CompilerPage.tsx new file mode 100644 index 0000000..d633d4b --- /dev/null +++ b/frontend/src/pages/CompilerPage.tsx @@ -0,0 +1,255 @@ +import { useMemo, useState } from 'react'; +import { + Accordion, + Alert, + Badge, + Button, + Checkbox, + Group, + Loader, + Progress, + Select, + Slider, + Stack, + Text, + Title, +} from '@mantine/core'; +import { IconSparkles } from '@tabler/icons-react'; +import { useConnections } from '../hooks/useConnections'; +import { + useAcceptFinding, + useBulkReview, + useCompilationFindings, + useCompilationRuns, + useDismissFinding, + useStartCompilation, +} from '../hooks/useCompilation'; +import { FindingCard } from '../components/compiler/FindingCard'; +import type { CompilationFinding, CompilationRun } from '../types/api'; + +const KIND_LABEL: Record = { + relationship: 'Inferred join paths', + metric: 'Metric candidates', + dictionary: 'Value dictionaries', + glossary: 'Glossary entities', + data_policy_row_filter: 'Row-filter policies (tenant scoping)', + data_policy_masking: 'PII masking', + dead_table: 'Dead tables', + fanout_warning: 'Fan-out warnings', +}; + +const KIND_ORDER = Object.keys(KIND_LABEL) as CompilationFinding['kind'][]; + +function RunBanner({ run }: { run: CompilationRun }) { + if (run.status === 'failed') { + return ( + + {run.error} + + ); + } + if (run.status === 'queued' || run.status === 'running') { + const p = run.progress; + return ( + + + {p?.stage || 'Starting…'} + + + + ); + } + const sources = run.stats.sources_available ?? {}; + const missing = Object.entries(sources) + .filter(([, available]) => !available) + .map(([name]) => name); + return ( + + + Examined {run.stats.tables_examined ?? 0} tables, {run.stats.views_examined ?? 0} views,{' '} + {run.stats.logged_queries_examined ?? 0} logged queries. + + {missing.length > 0 && ( + + Unavailable evidence sources: {missing.join(', ')} — confidence is reduced where these + would have helped. (pg_stats needs ANALYZE; query logs need the pg_stat_statements + extension.) + + )} + + ); +} + +export function CompilerPage() { + const { data: connections } = useConnections(); + const [connectionId, setConnectionId] = useState(null); + const connOptions = (connections ?? []).map(c => ({ value: c.id, label: c.name })); + if (!connectionId && connOptions.length > 0) setConnectionId(connOptions[0].value); + + const [llmEnabled, setLlmEnabled] = useState(true); + const [minConfidence, setMinConfidence] = useState(0.5); + const [statusFilter, setStatusFilter] = useState('proposed'); + + const { data: runs } = useCompilationRuns(connectionId ?? undefined); + const latestRun = runs?.[0]; + const runActive = latestRun?.status === 'queued' || latestRun?.status === 'running'; + + const { data: findings, isLoading } = useCompilationFindings(connectionId ?? undefined, { + status: statusFilter || undefined, + }); + + const start = useStartCompilation(connectionId ?? ''); + const accept = useAcceptFinding(connectionId ?? ''); + const dismiss = useDismissFinding(connectionId ?? ''); + const bulk = useBulkReview(connectionId ?? ''); + + const grouped = useMemo(() => { + const groups = new Map(); + for (const finding of findings ?? []) { + const list = groups.get(finding.kind) ?? []; + list.push(finding); + groups.set(finding.kind, list); + } + return groups; + }, [findings]); + + const busy = accept.isPending || dismiss.isPending || bulk.isPending; + + return ( + + + Semantic Layer Compiler + + + + + Introspects the connected database — schema, column statistics, view definitions, and + query logs — and proposes draft semantic-layer objects with evidence and confidence. + Nothing is created until you accept a finding; accepted objects land as{' '} + + draft + {' '} + for normal certification. + + + + setStatusFilter(v ?? 'proposed')} + /> + + + {latestRun && } + + {isLoading ? ( + + + + ) : !findings || findings.length === 0 ? ( + + No {statusFilter} findings on this connection. Run the compiler to generate proposals. + + ) : ( + + {KIND_ORDER.filter(kind => grouped.has(kind)).map(kind => { + const group = grouped.get(kind)!; + const proposed = group.filter(f => f.status === 'proposed'); + return ( + + + + {KIND_LABEL[kind]} + + {group.length} + + + + + + {proposed.length > 1 && ( + + + + + )} + {group.map(finding => ( + accept.mutate(id)} + onDismiss={id => dismiss.mutate(id)} + busy={busy} + /> + ))} + + + + ); + })} + + )} + + ); +} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index d77c541..b54b28c 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -496,3 +496,62 @@ export interface TableUsage { table: string; query_count: number; } + +// --- Semantic layer compiler --- + +export interface CompilationProgress { + total: number; + completed: number; + stage: string; + status: string; + error: string | null; +} + +export interface CompilationRun { + id: string; + connection_id: string; + status: 'queued' | 'running' | 'completed' | 'failed'; + options: Record; + stats: { + findings?: Record; + sources_available?: Record; + superseded_proposals?: number; + tables_examined?: number; + views_examined?: number; + logged_queries_examined?: number; + }; + error: string | null; + started_at: string | null; + finished_at: string | null; + created_at: string; + progress: CompilationProgress | null; +} + +export interface CompilationEvidence { + source: string; + detail: string; +} + +export interface CompilationFinding { + id: string; + run_id: string; + connection_id: string; + kind: + | 'relationship' + | 'metric' + | 'dictionary' + | 'glossary' + | 'data_policy_row_filter' + | 'data_policy_masking' + | 'dead_table' + | 'fanout_warning'; + title: string; + payload: Record; + evidence: CompilationEvidence[]; + confidence: number; + status: 'proposed' | 'accepted' | 'dismissed'; + created_entity_type: string | null; + created_entity_id: string | null; + reviewed_at: string | null; + created_at: string; +}