diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 397754b..b043e95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e ".[llm,dev,observability]" + run: pip install -e ".[llm,dev,observability,lineage]" # Gating: the test suite must pass before any Phase 0+ refactor lands. - name: Tests diff --git a/CHANGELOG.md b/CHANGELOG.md index 0562f46..1f6a69f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,32 @@ product surface; all optional dependencies degrade gracefully). - New optional dependency extra: `export` (`openpyxl`). Frontend adds `recharts` and `react-grid-layout`. +### Added (Phase 3 - Discovery, catalog & trust) +- **Certification & semantic versioning** (migration `007`) — metrics, glossary terms, sample + queries, and saved queries gain a governed lifecycle (`draft → in_review → certified → + deprecated`), an integer `version`, and certification stamps (`certified_by`/`certified_at`). + Editors submit for review / revert; admins certify / deprecate. Certifying validates the + entity's SQL (read-only blocklist + a sqlglot parse). +- **Version history & changelog** — every content edit and status transition appends a + `SemanticVersion` snapshot, exposed at `.../{entity}/{id}/versions` with a field-level diff + helper; surfaced in the UI as a per-entity history timeline. +- **Lifecycle logic** centralized in `versioning_service.py` so all four entity types behave + identically; status transitions go through a single governed endpoint + (`POST .../{entity}/{id}/status`). +- **Data catalog** (`catalog_service.py`, `GET /connections/{id}/catalog/search` + `/facets`) — a + unified hybrid search across tables, columns, metrics, glossary, sample/saved queries, and + knowledge, reusing the existing pgvector embeddings + keyword scorer (no new full-text infra). + Certified items are boosted in ranking; facets by type, status, schema, and owner. New + `frontend/src/pages/CatalogPage.tsx` with search, facet sidebar, and a detail/lineage drawer. +- **Lightweight lineage** (migration `008`, `lineage_service.py`) — saved-query and metric SQL is + parsed with sqlglot into `artifact_dependencies` edges on create/update (best-effort; degrades + to a no-op if sqlglot is absent). Powers the per-artifact "what this touches" view + (`.../{entity}/{id}/lineage`) and the impact view "what depends on this table" + (`GET .../catalog/lineage?table=`). +- New optional dependency extra: `lineage` (`sqlglot`); installed in the backend image and in CI + so the lineage tests run (they `importorskip` past `sqlglot` when the extra is absent). +- **Deferred to a later milestone:** column profiling (null rate / distinct counts / sample values). + ## [1.0.0] - 2026-06-04 First stable release: natural-language-to-SQL with a semantic metadata layer. diff --git a/CLAUDE.md b/CLAUDE.md index 2dec3e3..ee60f1f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,7 @@ For manual seeding (if auto-setup disabled): `python backend/scripts/seed_ifrs9_ Run from `backend/`: ```bash -pip install -e ".[llm,dev,bigquery,databricks]" # Install all deps +pip install -e ".[llm,dev,bigquery,databricks,lineage]" # Install all deps (add export,observability,jobs as needed) alembic upgrade head # Run migrations uvicorn app.main:app --reload # Dev server on :8000 pytest # Run tests @@ -233,7 +233,7 @@ dependencies degrade gracefully — the app boots without `structlog` / - **Jobs** (`app/jobs/`): `JobQueue` ABC with `InProcessJobQueue` (asyncio, default) and `ArqJobQueue` (Redis). Jobs are registered by name in `registry.py`; `launch_background_embeddings` submits `"generate_embeddings"` through `get_job_queue()`. For arq, run a worker: `JOB_BACKEND=arq arq app.jobs.worker.WorkerSettings` (embedding progress then lives in the worker process). - **Health** (`app/api/v1/endpoints/health.py`): `GET /health/live` (process) and `GET /health/ready` (DB + job queue + LLM provider, 503 on failure) for K8s probes. - **LLM endpoints:** Azure OpenAI provider (`azure_openai`) added so the pipeline can run inside a customer VPC; registered in `provider_registry`. -- **Tests/CI:** unit tests in `backend/tests/` (no DB/LLM needed); `.github/workflows/ci.yml` runs pytest (gating) + ruff/mypy/frontend build (advisory until pre-existing lint debt is cleared). Optional deps: `pip install -e ".[observability,jobs]"`. +- **Tests/CI:** unit tests in `backend/tests/` (no DB/LLM needed); `.github/workflows/ci.yml` installs `.[llm,dev,observability,lineage]` and runs pytest (gating) + ruff/mypy/frontend build (advisory until pre-existing lint debt is cleared). The lineage tests need `sqlglot` (the `[lineage]` extra) and `pytest.importorskip` past it otherwise. Optional deps: `pip install -e ".[observability,jobs]"`. ## Identity & auth (Phase 1) @@ -258,3 +258,14 @@ One-shot answers become saved, owned, re-runnable, shareable objects. Two milest - **Export:** client-side CSV/JSON in the frontend; backend CSV/JSON/XLSX for saved queries (XLSX needs the optional `export` extra → `openpyxl`). - **Endpoints:** `/connections/{id}/saved-queries` (+ `/run`, `/clone`, `/export`, `/charts`), `/dashboards` (+ `/tiles`, `/layout`, `/tiles/{id}/run`). - **Frontend:** Recharts (`components/charts/ChartView.tsx`) for viz; `react-grid-layout` for the dashboard grid; shared typed `components/common/ParamInputs.tsx` for params/filters. Charts are managed inside the saved-query view (no separate Charts page). Note: the frontend container's anonymous `node_modules` volume means new deps (recharts, react-grid-layout) need `docker compose exec frontend npm install` or an image rebuild. + +## Discovery, catalog & trust (Phase 3) + +Makes the semantic layer discoverable and trustworthy. Two milestones; migrations `007` (certification + versioning) and `008` (catalog lineage). **Column profiling is deferred** to a later milestone. + +- **Certification lifecycle** (`app/services/versioning_service.py`): metrics, glossary terms, sample queries, and saved queries carry `status` (`draft|in_review|certified|deprecated`), an integer `version`, and `certified_by_id`/`certified_at`. Transitions go through one governed endpoint per entity (`POST /connections/{id}/{entity}/{eid}/status`); the state machine (`_ALLOWED_TRANSITIONS`) and role gate (`_ROLE_FOR_TARGET`) live in the service — **editor** submits-for-review/reverts, **admin** certifies/deprecates. Certifying runs a lightweight SQL check (`check_sql_safety` + a sqlglot parse). One service handles all four entity types via `_SNAPSHOT_FIELDS` / `_SQL_FIELD` maps. +- **Versioning & changelog** (`SemanticVersion` model): every content edit (PUT → `record_edit`, bumps version) and status transition appends an append-only snapshot. Exposed at `GET .../{entity}/{eid}/versions` (+ `/{version}`); `versioning_service.diff` gives a field-level diff. UI: shared `frontend/src/components/common/{CertificationBadge,StatusActions,VersionHistory}.tsx`, wired into the Metrics/Glossary/SavedQueries pages. +- **Catalog search** (`app/services/catalog_service.py`, `app/api/v1/endpoints/catalog.py`): `GET /connections/{id}/catalog/search` runs a hybrid search across tables, columns, metrics, glossary, sample/saved queries, and knowledge — **reusing the existing pgvector embeddings + the keyword scorer** (`semantic/relevance_scorer.py`), no tsvector. Hits merge into a uniform `CatalogHit`; certified items are boosted (`rank_hits`). `GET .../catalog/facets` returns schemas/owners/tags/type+status counts. Connection-scoped via `require_connection_read`. Frontend: `pages/CatalogPage.tsx` (search + facet sidebar + detail/lineage drawer). +- **Lineage** (`app/services/lineage_service.py`, `ArtifactDependency` model): saved-query `pinned_sql` and metric `sql_expression` are parsed with **sqlglot** (optional `[lineage]` extra; lazy import, degrades to a no-op if absent) into table/column edges, recomputed on create/update (best-effort, never blocks the write). Per-artifact "what this touches" at `GET .../{saved-queries|metrics}/{id}/lineage`; impact view "what depends on this table" at `GET .../catalog/lineage?table=&column=`. Connector type → sqlglot dialect via `dialect_for`. +- **Endpoints:** `/connections/{id}/catalog/{search,facets,lineage}`, plus `/status`, `/versions`, `/versions/{v}`, and `/lineage` sub-resources on the metric/glossary/sample-query/saved-query routers. +- **Heads-up:** existing rows migrate to `status='draft'`, `version=1`. The saved-query PUT routes any `status` change through the governed lifecycle (no raw status writes). sqlglot is a new optional dep — install the `[lineage]` extra (or rebuild the backend image) for lineage to populate. diff --git a/README.md b/README.md index 271e0ad..332dc38 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,9 @@ A full-stack application that translates natural language questions into SQL que - **Saved queries** — name and pin a question + SQL with typed parameters (`{{region}}`); re-run, version, clone, and export (CSV/JSON/XLSX) - **Charts & result caching** — visualize a saved query (line/bar/area/pie/scatter via Recharts); results are snapshotted to a Postgres cache so re-runs don't re-hit the warehouse - **Dashboards** — compose saved queries into a shareable, draggable tile grid with dashboard-level filters that flow into every tile's SQL +- **Certification & versioning** — govern metrics, glossary, and saved queries through a `draft → in_review → certified → deprecated` lifecycle (editors submit, admins certify) with a per-entity version history and changelog +- **Data catalog** — hybrid search (embeddings + keyword) across tables, columns, metrics, glossary, and knowledge, with facets and certified-first ranking +- **Lineage** — sqlglot parses saved-query/metric SQL to show what each touches and what depends on a given table (impact view) - **Production hardening** — rate limiting, async job queue, OpenTelemetry tracing, structured logging, health probes @@ -354,8 +357,8 @@ cd backend python3.12 -m venv .venv source .venv/bin/activate -# Install dependencies -pip install -e ".[llm,dev]" +# Install dependencies (add `lineage` for sqlglot-based catalog lineage) +pip install -e ".[llm,dev,lineage]" # Start PostgreSQL with pgvector (must be running on localhost:5432) # Run migrations diff --git a/backend/Dockerfile b/backend/Dockerfile index 52bbc28..d61b076 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -7,7 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ rm -rf /var/lib/apt/lists/* COPY . . -RUN pip install --no-cache-dir -e ".[llm,dev,bigquery,databricks,observability,jobs]" +RUN pip install --no-cache-dir -e ".[llm,dev,bigquery,databricks,observability,jobs,lineage]" EXPOSE 8000 diff --git a/backend/alembic/versions/007_certification_and_versioning.py b/backend/alembic/versions/007_certification_and_versioning.py new file mode 100644 index 0000000..36cb303 --- /dev/null +++ b/backend/alembic/versions/007_certification_and_versioning.py @@ -0,0 +1,122 @@ +"""Certification + semantic versioning (Phase 3 — Milestone 1) + +Revision ID: 007 +Revises: 006 +Create Date: 2026-06-08 + +Adds a trust/lifecycle layer to the semantic objects. Metrics, glossary terms, +sample queries, and saved queries gain a ``status`` +(draft|in_review|certified|deprecated), an integer ``version``, and certification +stamps (``certified_by_id`` / ``certified_at``). The new ``semantic_versions`` +table is an append-only changelog: a snapshot of an entity at each version with +the reviewer and reason, written on every content edit and status transition. + +Existing rows default to status='draft', version=1 (saved_queries already carry +status + version from migration 005, so only the certification stamps are added +there). +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "007" +down_revision: str = "006" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +# Tables getting the full lifecycle column set (status + version + cert stamps). +_LIFECYCLE_TABLES = ("metric_definitions", "glossary_terms", "sample_queries") + + +def upgrade() -> None: + for table in _LIFECYCLE_TABLES: + op.add_column( + table, + sa.Column("status", sa.String(20), nullable=False, server_default=sa.text("'draft'")), + ) + op.add_column( + table, + sa.Column("version", sa.Integer, nullable=False, server_default=sa.text("1")), + ) + op.add_column( + table, + sa.Column( + "certified_by_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + ) + op.add_column( + table, + sa.Column("certified_at", sa.DateTime(timezone=True), nullable=True), + ) + + # saved_queries already has status + version (migration 005); add cert stamps. + op.add_column( + "saved_queries", + sa.Column( + "certified_by_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + ) + op.add_column( + "saved_queries", + sa.Column("certified_at", sa.DateTime(timezone=True), nullable=True), + ) + + op.create_table( + "semantic_versions", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("entity_type", sa.String(20), nullable=False), + sa.Column("entity_id", UUID(as_uuid=True), nullable=False), + sa.Column("version", sa.Integer, nullable=False), + sa.Column("status", sa.String(20), nullable=False), + sa.Column("snapshot", JSONB, nullable=False), + sa.Column("change_reason", sa.Text, nullable=True), + sa.Column( + "changed_by_id", + UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_index( + "ix_semantic_versions_entity", + "semantic_versions", + ["entity_type", "entity_id", "version"], + ) + + +def downgrade() -> None: + op.drop_index("ix_semantic_versions_entity", table_name="semantic_versions") + op.drop_table("semantic_versions") + + op.drop_column("saved_queries", "certified_at") + op.drop_column("saved_queries", "certified_by_id") + + for table in _LIFECYCLE_TABLES: + op.drop_column(table, "certified_at") + op.drop_column(table, "certified_by_id") + op.drop_column(table, "version") + op.drop_column(table, "status") diff --git a/backend/alembic/versions/008_catalog_lineage.py b/backend/alembic/versions/008_catalog_lineage.py new file mode 100644 index 0000000..55cd8e2 --- /dev/null +++ b/backend/alembic/versions/008_catalog_lineage.py @@ -0,0 +1,79 @@ +"""Catalog lineage (Phase 3 — Milestone 2) + +Revision ID: 008 +Revises: 007 +Create Date: 2026-06-08 + +Adds ``artifact_dependencies`` — lineage edges recording which tables/columns a +saved query or metric references, parsed from its SQL via sqlglot. Powers the +catalog impact view ("what depends on this table") and the per-artifact +"what this touches" view. Names are stored denormalized; table_id/column_id are +resolved best-effort against the schema cache. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "008" +down_revision: str = "007" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "artifact_dependencies", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "connection_id", + UUID(as_uuid=True), + sa.ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("artifact_type", sa.String(20), nullable=False), + sa.Column("artifact_id", UUID(as_uuid=True), nullable=False), + sa.Column("ref_kind", sa.String(10), nullable=False), + sa.Column("schema_name", sa.String(255), nullable=True), + sa.Column("table_name", sa.String(255), nullable=False), + sa.Column("column_name", sa.String(255), nullable=True), + sa.Column( + "table_id", + UUID(as_uuid=True), + sa.ForeignKey("cached_tables.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "column_id", + UUID(as_uuid=True), + sa.ForeignKey("cached_columns.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_index( + "ix_artifact_dependencies_artifact", + "artifact_dependencies", + ["artifact_type", "artifact_id"], + ) + op.create_index( + "ix_artifact_dependencies_table", + "artifact_dependencies", + ["connection_id", "table_name"], + ) + + +def downgrade() -> None: + op.drop_index("ix_artifact_dependencies_table", table_name="artifact_dependencies") + op.drop_index("ix_artifact_dependencies_artifact", table_name="artifact_dependencies") + op.drop_table("artifact_dependencies") diff --git a/backend/app/api/v1/endpoints/catalog.py b/backend/app/api/v1/endpoints/catalog.py new file mode 100644 index 0000000..8d7f821 --- /dev/null +++ b/backend/app/api/v1/endpoints/catalog.py @@ -0,0 +1,71 @@ +import uuid + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.deps import require_connection_read +from app.api.v1.schemas.catalog import ( + CatalogFacetsResponse, + CatalogHitResponse, + LineageRefResponse, +) +from app.core.auth import AuthContext +from app.db.session import get_db +from app.services import catalog_service, lineage_service + +router = APIRouter(tags=["catalog"]) + + +@router.get( + "/connections/{connection_id}/catalog/search", + response_model=list[CatalogHitResponse], +) +async def catalog_search( + connection_id: uuid.UUID, + q: str = Query("", description="Search text"), + types: str | None = Query(None, description="Comma-separated hit types to include"), + status: str | None = Query(None, description="Filter by certification status"), + owner: str | None = Query(None, description="Filter by owner/creator id"), + schema: str | None = Query(None, description="Filter tables/columns by schema"), + limit: int = Query(50, ge=1, le=200), + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + type_list = [t.strip() for t in types.split(",") if t.strip()] if types else None + return await catalog_service.search( + db, + connection_id, + q, + types=type_list, + status=status, + owner=owner, + schema=schema, + limit=limit, + ) + + +@router.get( + "/connections/{connection_id}/catalog/facets", + response_model=CatalogFacetsResponse, +) +async def catalog_facets( + connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + return await catalog_service.facets(db, connection_id) + + +@router.get( + "/connections/{connection_id}/catalog/lineage", + response_model=list[LineageRefResponse], +) +async def catalog_lineage_impact( + connection_id: uuid.UUID, + table: str = Query(..., description="Table name to find dependents of"), + column: str | None = Query(None, description="Optional column name"), + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + """Impact view: which saved queries / metrics depend on a table (or column).""" + return await lineage_service.dependents_of(db, connection_id, table, column) diff --git a/backend/app/api/v1/endpoints/glossary.py b/backend/app/api/v1/endpoints/glossary.py index 1187a84..69fa897 100644 --- a/backend/app/api/v1/endpoints/glossary.py +++ b/backend/app/api/v1/endpoints/glossary.py @@ -10,10 +10,13 @@ GlossaryTermResponse, GlossaryTermUpdate, ) +from app.api.v1.schemas.semantic_version import SemanticVersionResponse, StatusTransition from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.glossary import GlossaryTerm +from app.db.models.semantic_version import ENTITY_GLOSSARY from app.db.session import get_db +from app.services import versioning_service from app.services.embedding_service import embed_glossary_term router = APIRouter(tags=["glossary"]) @@ -91,7 +94,7 @@ async def update_glossary_term( connection_id: uuid.UUID, term_id: uuid.UUID, body: GlossaryTermUpdate, - _ctx: AuthContext = Depends(require_connection_write), + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): term = await db.get(GlossaryTerm, term_id) @@ -101,6 +104,7 @@ async def update_glossary_term( for key, value in body.model_dump(exclude_none=True).items(): setattr(term, key, value) + await versioning_service.record_edit(db, ctx, ENTITY_GLOSSARY, term) await db.flush() try: term.term_embedding = await embed_glossary_term(term) @@ -124,3 +128,63 @@ async def delete_glossary_term( raise NotFoundError("GlossaryTerm", str(term_id)) await db.delete(term) await db.flush() + + +# --------------------------------------------------------------------------- # +# Certification lifecycle + version history +# --------------------------------------------------------------------------- # +@router.post( + "/connections/{connection_id}/glossary/{term_id}/status", + response_model=GlossaryTermResponse, +) +async def transition_glossary_status( + connection_id: uuid.UUID, + term_id: uuid.UUID, + body: StatusTransition, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + term = await db.get(GlossaryTerm, term_id) + if not term or term.connection_id != connection_id: + raise NotFoundError("GlossaryTerm", str(term_id)) + await versioning_service.transition_status( + db, ctx, ENTITY_GLOSSARY, term, body.status, reason=body.reason + ) + await db.flush() + return term + + +@router.get( + "/connections/{connection_id}/glossary/{term_id}/versions", + response_model=list[SemanticVersionResponse], +) +async def list_glossary_versions( + connection_id: uuid.UUID, + term_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + term = await db.get(GlossaryTerm, term_id) + if not term or term.connection_id != connection_id: + raise NotFoundError("GlossaryTerm", str(term_id)) + return await versioning_service.list_versions(db, ENTITY_GLOSSARY, term_id) + + +@router.get( + "/connections/{connection_id}/glossary/{term_id}/versions/{version}", + response_model=SemanticVersionResponse, +) +async def get_glossary_version( + connection_id: uuid.UUID, + term_id: uuid.UUID, + version: int, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + term = await db.get(GlossaryTerm, term_id) + if not term or term.connection_id != connection_id: + raise NotFoundError("GlossaryTerm", str(term_id)) + snap = await versioning_service.get_version(db, ENTITY_GLOSSARY, term_id, version) + if snap is None: + raise NotFoundError("GlossaryVersion", f"{term_id}@{version}") + return snap diff --git a/backend/app/api/v1/endpoints/metrics.py b/backend/app/api/v1/endpoints/metrics.py index 3a77ef6..76124b7 100644 --- a/backend/app/api/v1/endpoints/metrics.py +++ b/backend/app/api/v1/endpoints/metrics.py @@ -6,10 +6,15 @@ from app.api.v1.deps import require_connection_read, require_connection_write from app.api.v1.schemas.metric import MetricCreate, MetricResponse, MetricUpdate +from app.api.v1.schemas.catalog import LineageRefResponse +from app.api.v1.schemas.semantic_version import SemanticVersionResponse, StatusTransition from app.core.auth import AuthContext from app.core.exceptions import NotFoundError +from app.db.models.artifact_dependency import ARTIFACT_METRIC from app.db.models.metric import MetricDefinition +from app.db.models.semantic_version import ENTITY_METRIC from app.db.session import get_db +from app.services import lineage_service, versioning_service from app.services.embedding_service import embed_metric router = APIRouter(tags=["metrics"]) @@ -55,6 +60,7 @@ async def create_metric( metric.metric_embedding = await embed_metric(metric) except Exception: pass + await lineage_service.recompute_metric(db, ctx, metric) return metric @@ -82,7 +88,7 @@ async def update_metric( connection_id: uuid.UUID, metric_id: uuid.UUID, body: MetricUpdate, - _ctx: AuthContext = Depends(require_connection_write), + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): metric = await db.get(MetricDefinition, metric_id) @@ -92,11 +98,13 @@ async def update_metric( for key, value in body.model_dump(exclude_none=True).items(): setattr(metric, key, value) + await versioning_service.record_edit(db, ctx, ENTITY_METRIC, metric) await db.flush() try: metric.metric_embedding = await embed_metric(metric) except Exception: pass + await lineage_service.recompute_metric(db, ctx, metric) return metric @@ -115,3 +123,80 @@ async def delete_metric( raise NotFoundError("Metric", str(metric_id)) await db.delete(metric) await db.flush() + + +# --------------------------------------------------------------------------- # +# Certification lifecycle + version history +# --------------------------------------------------------------------------- # +@router.post( + "/connections/{connection_id}/metrics/{metric_id}/status", + response_model=MetricResponse, +) +async def transition_metric_status( + connection_id: uuid.UUID, + metric_id: uuid.UUID, + body: StatusTransition, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + metric = await db.get(MetricDefinition, metric_id) + if not metric or metric.connection_id != connection_id: + raise NotFoundError("Metric", str(metric_id)) + await versioning_service.transition_status( + db, ctx, ENTITY_METRIC, metric, body.status, reason=body.reason + ) + await db.flush() + return metric + + +@router.get( + "/connections/{connection_id}/metrics/{metric_id}/versions", + response_model=list[SemanticVersionResponse], +) +async def list_metric_versions( + connection_id: uuid.UUID, + metric_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + metric = await db.get(MetricDefinition, metric_id) + if not metric or metric.connection_id != connection_id: + raise NotFoundError("Metric", str(metric_id)) + return await versioning_service.list_versions(db, ENTITY_METRIC, metric_id) + + +@router.get( + "/connections/{connection_id}/metrics/{metric_id}/versions/{version}", + response_model=SemanticVersionResponse, +) +async def get_metric_version( + connection_id: uuid.UUID, + metric_id: uuid.UUID, + version: int, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + metric = await db.get(MetricDefinition, metric_id) + if not metric or metric.connection_id != connection_id: + raise NotFoundError("Metric", str(metric_id)) + snap = await versioning_service.get_version(db, ENTITY_METRIC, metric_id, version) + if snap is None: + raise NotFoundError("MetricVersion", f"{metric_id}@{version}") + return snap + + +@router.get( + "/connections/{connection_id}/metrics/{metric_id}/lineage", + response_model=list[LineageRefResponse], +) +async def get_metric_lineage( + connection_id: uuid.UUID, + metric_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + """What tables/columns this metric's SQL touches.""" + metric = await db.get(MetricDefinition, metric_id) + if not metric or metric.connection_id != connection_id: + raise NotFoundError("Metric", str(metric_id)) + return await lineage_service.refs_for_artifact(db, ARTIFACT_METRIC, metric_id) diff --git a/backend/app/api/v1/endpoints/sample_queries.py b/backend/app/api/v1/endpoints/sample_queries.py index 69c4489..e5b33a7 100644 --- a/backend/app/api/v1/endpoints/sample_queries.py +++ b/backend/app/api/v1/endpoints/sample_queries.py @@ -6,10 +6,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.v1.deps import require_connection_read, require_connection_write +from app.api.v1.schemas.semantic_version import SemanticVersionResponse, StatusTransition from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.sample_query import SampleQuery +from app.db.models.semantic_version import ENTITY_SAMPLE_QUERY from app.db.session import get_db +from app.services import versioning_service from app.services.embedding_service import embed_sample_query router = APIRouter(tags=["sample_queries"]) @@ -39,6 +42,10 @@ class SampleQueryResponse(BaseModel): description: str | None tags: list[str] | None is_validated: bool + status: str + version: int + certified_by_id: uuid.UUID | None + certified_at: str | None created_at: str updated_at: str @@ -96,7 +103,7 @@ async def update_sample_query( connection_id: uuid.UUID, sq_id: uuid.UUID, body: SampleQueryUpdate, - _ctx: AuthContext = Depends(require_connection_write), + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): sq = await db.get(SampleQuery, sq_id) @@ -104,6 +111,7 @@ async def update_sample_query( raise NotFoundError("SampleQuery", str(sq_id)) for key, value in body.model_dump(exclude_none=True).items(): setattr(sq, key, value) + await versioning_service.record_edit(db, ctx, ENTITY_SAMPLE_QUERY, sq) await db.flush() try: sq.question_embedding = await embed_sample_query(sq) @@ -127,3 +135,63 @@ async def delete_sample_query( raise NotFoundError("SampleQuery", str(sq_id)) await db.delete(sq) await db.flush() + + +# --------------------------------------------------------------------------- # +# Certification lifecycle + version history +# --------------------------------------------------------------------------- # +@router.post( + "/connections/{connection_id}/sample-queries/{sq_id}/status", + response_model=SampleQueryResponse, +) +async def transition_sample_query_status( + connection_id: uuid.UUID, + sq_id: uuid.UUID, + body: StatusTransition, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + sq = await db.get(SampleQuery, sq_id) + if not sq or sq.connection_id != connection_id: + raise NotFoundError("SampleQuery", str(sq_id)) + await versioning_service.transition_status( + db, ctx, ENTITY_SAMPLE_QUERY, sq, body.status, reason=body.reason + ) + await db.flush() + return sq + + +@router.get( + "/connections/{connection_id}/sample-queries/{sq_id}/versions", + response_model=list[SemanticVersionResponse], +) +async def list_sample_query_versions( + connection_id: uuid.UUID, + sq_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + sq = await db.get(SampleQuery, sq_id) + if not sq or sq.connection_id != connection_id: + raise NotFoundError("SampleQuery", str(sq_id)) + return await versioning_service.list_versions(db, ENTITY_SAMPLE_QUERY, sq_id) + + +@router.get( + "/connections/{connection_id}/sample-queries/{sq_id}/versions/{version}", + response_model=SemanticVersionResponse, +) +async def get_sample_query_version( + connection_id: uuid.UUID, + sq_id: uuid.UUID, + version: int, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + sq = await db.get(SampleQuery, sq_id) + if not sq or sq.connection_id != connection_id: + raise NotFoundError("SampleQuery", str(sq_id)) + snap = await versioning_service.get_version(db, ENTITY_SAMPLE_QUERY, sq_id, version) + if snap is None: + raise NotFoundError("SampleQueryVersion", f"{sq_id}@{version}") + return snap diff --git a/backend/app/api/v1/endpoints/saved_queries.py b/backend/app/api/v1/endpoints/saved_queries.py index d04b955..7bb50cf 100644 --- a/backend/app/api/v1/endpoints/saved_queries.py +++ b/backend/app/api/v1/endpoints/saved_queries.py @@ -17,12 +17,16 @@ SavedQueryRunResponse, SavedQueryUpdate, ) +from app.api.v1.schemas.catalog import LineageRefResponse +from app.api.v1.schemas.semantic_version import SemanticVersionResponse, StatusTransition from app.core.auth import AuthContext from app.core.exceptions import AppError, NotFoundError +from app.db.models.artifact_dependency import ARTIFACT_SAVED_QUERY from app.db.models.chart import Chart from app.db.models.saved_query import SavedQuery +from app.db.models.semantic_version import ENTITY_SAVED_QUERY from app.db.session import get_db -from app.services import saved_query_service +from app.services import lineage_service, saved_query_service, versioning_service router = APIRouter(tags=["saved-queries"]) @@ -77,6 +81,7 @@ async def create_saved_query( ) db.add(saved) await db.flush() + await lineage_service.recompute_saved_query(db, ctx, saved) return saved @@ -101,19 +106,25 @@ async def update_saved_query( connection_id: uuid.UUID, saved_query_id: uuid.UUID, body: SavedQueryUpdate, - _ctx: AuthContext = Depends(require_connection_write), + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): saved = await _get_saved_query(db, connection_id, saved_query_id) updates = body.model_dump(exclude_unset=True) + # Status changes go through the governed lifecycle, not a raw field write. + new_status = updates.pop("status", None) if "params" in updates and body.params is not None: updates["params"] = [p.model_dump() for p in body.params] - # Bump version when the executable SQL changes. - if "pinned_sql" in updates and updates["pinned_sql"] != saved.pinned_sql: - saved.version += 1 for key, value in updates.items(): setattr(saved, key, value) + # A content edit bumps the version and appends a changelog snapshot. + if updates: + await versioning_service.record_edit(db, ctx, ENTITY_SAVED_QUERY, saved) + if new_status is not None and new_status != saved.status: + await versioning_service.transition_status(db, ctx, ENTITY_SAVED_QUERY, saved, new_status) await db.flush() + if "pinned_sql" in updates: + await lineage_service.recompute_saved_query(db, ctx, saved) return saved @@ -158,9 +169,79 @@ async def clone_saved_query( ) db.add(clone) await db.flush() + await lineage_service.recompute_saved_query(db, ctx, clone) return clone +# --------------------------------------------------------------------------- # +# Certification lifecycle + version history +# --------------------------------------------------------------------------- # +@router.post( + "/connections/{connection_id}/saved-queries/{saved_query_id}/status", + response_model=SavedQueryResponse, +) +async def transition_saved_query_status( + connection_id: uuid.UUID, + saved_query_id: uuid.UUID, + body: StatusTransition, + ctx: AuthContext = Depends(require_connection_write), + db: AsyncSession = Depends(get_db), +): + saved = await _get_saved_query(db, connection_id, saved_query_id) + await versioning_service.transition_status( + db, ctx, ENTITY_SAVED_QUERY, saved, body.status, reason=body.reason + ) + await db.flush() + return saved + + +@router.get( + "/connections/{connection_id}/saved-queries/{saved_query_id}/versions", + response_model=list[SemanticVersionResponse], +) +async def list_saved_query_versions( + connection_id: uuid.UUID, + saved_query_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + await _get_saved_query(db, connection_id, saved_query_id) + return await versioning_service.list_versions(db, ENTITY_SAVED_QUERY, saved_query_id) + + +@router.get( + "/connections/{connection_id}/saved-queries/{saved_query_id}/versions/{version}", + response_model=SemanticVersionResponse, +) +async def get_saved_query_version( + connection_id: uuid.UUID, + saved_query_id: uuid.UUID, + version: int, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + await _get_saved_query(db, connection_id, saved_query_id) + snap = await versioning_service.get_version(db, ENTITY_SAVED_QUERY, saved_query_id, version) + if snap is None: + raise NotFoundError("SavedQueryVersion", f"{saved_query_id}@{version}") + return snap + + +@router.get( + "/connections/{connection_id}/saved-queries/{saved_query_id}/lineage", + response_model=list[LineageRefResponse], +) +async def get_saved_query_lineage( + connection_id: uuid.UUID, + saved_query_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), + db: AsyncSession = Depends(get_db), +): + """What tables/columns this saved query touches.""" + await _get_saved_query(db, connection_id, saved_query_id) + return await lineage_service.refs_for_artifact(db, ARTIFACT_SAVED_QUERY, saved_query_id) + + # --------------------------------------------------------------------------- # # Run + export # --------------------------------------------------------------------------- # diff --git a/backend/app/api/v1/router.py b/backend/app/api/v1/router.py index 38cf094..730faa8 100644 --- a/backend/app/api/v1/router.py +++ b/backend/app/api/v1/router.py @@ -4,6 +4,7 @@ api_keys, assistant, auth, + catalog, connections, dashboards, dictionary, @@ -37,3 +38,4 @@ api_router.include_router(dashboards.router) api_router.include_router(query_history.router) api_router.include_router(knowledge.router) +api_router.include_router(catalog.router) diff --git a/backend/app/api/v1/schemas/catalog.py b/backend/app/api/v1/schemas/catalog.py new file mode 100644 index 0000000..c047ac6 --- /dev/null +++ b/backend/app/api/v1/schemas/catalog.py @@ -0,0 +1,41 @@ +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel + + +class CatalogHitResponse(BaseModel): + type: str + id: str + name: str + description: str | None = None + status: str | None = None + certified_at: datetime | None = None + owner_id: str | None = None + context: str | None = None + score: float + match_reason: str + + model_config = {"from_attributes": True} + + +class CatalogFacetsResponse(BaseModel): + schemas: list[str] + owners: list[str] + tags: list[str] + types: list[str] + status_counts: dict[str, int] + + +class LineageRefResponse(BaseModel): + id: UUID + artifact_type: str + artifact_id: UUID + ref_kind: str + schema_name: str | None + table_name: str + column_name: str | None + table_id: UUID | None + column_id: UUID | None + + model_config = {"from_attributes": True} diff --git a/backend/app/api/v1/schemas/glossary.py b/backend/app/api/v1/schemas/glossary.py index ac15d09..f8bcb10 100644 --- a/backend/app/api/v1/schemas/glossary.py +++ b/backend/app/api/v1/schemas/glossary.py @@ -31,6 +31,10 @@ class GlossaryTermResponse(BaseModel): related_tables: list[str] | None related_columns: list[str] | None examples: list[str] | None + status: str + version: int + certified_by_id: UUID | None + certified_at: datetime | None created_at: datetime updated_at: datetime diff --git a/backend/app/api/v1/schemas/metric.py b/backend/app/api/v1/schemas/metric.py index f131776..025bec4 100644 --- a/backend/app/api/v1/schemas/metric.py +++ b/backend/app/api/v1/schemas/metric.py @@ -37,6 +37,10 @@ class MetricResponse(BaseModel): related_tables: list[str] | None dimensions: list[str] | None filters: dict | None + status: str + version: int + certified_by_id: UUID | None + certified_at: datetime | None created_at: datetime updated_at: datetime diff --git a/backend/app/api/v1/schemas/saved_query.py b/backend/app/api/v1/schemas/saved_query.py index 99030fd..4c11209 100644 --- a/backend/app/api/v1/schemas/saved_query.py +++ b/backend/app/api/v1/schemas/saved_query.py @@ -45,6 +45,8 @@ class SavedQueryResponse(BaseModel): params: list[ParamDef] | None version: int status: str + certified_by_id: UUID | None + certified_at: datetime | None is_public: bool created_at: datetime updated_at: datetime diff --git a/backend/app/api/v1/schemas/semantic_version.py b/backend/app/api/v1/schemas/semantic_version.py new file mode 100644 index 0000000..127868a --- /dev/null +++ b/backend/app/api/v1/schemas/semantic_version.py @@ -0,0 +1,28 @@ +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field + +from app.services.versioning_service import STATUSES + + +class StatusTransition(BaseModel): + """Request body for a certification-lifecycle transition.""" + + status: str = Field(description=f"Target status; one of: {', '.join(STATUSES)}") + reason: str | None = None + + +class SemanticVersionResponse(BaseModel): + id: UUID + entity_type: str + entity_id: UUID + version: int + status: str + snapshot: dict[str, Any] + change_reason: str | None + changed_by_id: UUID | None + created_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/db/models/__init__.py b/backend/app/db/models/__init__.py index 0263631..39758da 100644 --- a/backend/app/db/models/__init__.py +++ b/backend/app/db/models/__init__.py @@ -1,4 +1,5 @@ from app.db.models.api_key import ApiKey +from app.db.models.artifact_dependency import ArtifactDependency from app.db.models.chart import Chart from app.db.models.connection import DatabaseConnection from app.db.models.dashboard import Dashboard @@ -14,6 +15,7 @@ from app.db.models.sample_query import SampleQuery from app.db.models.saved_query import SavedQuery from app.db.models.schema_cache import CachedColumn, CachedRelationship, CachedTable +from app.db.models.semantic_version import SemanticVersion from app.db.models.team import Team from app.db.models.user import User @@ -39,4 +41,6 @@ "ResultSnapshot", "Dashboard", "DashboardTile", + "SemanticVersion", + "ArtifactDependency", ] diff --git a/backend/app/db/models/artifact_dependency.py b/backend/app/db/models/artifact_dependency.py new file mode 100644 index 0000000..0562a8c --- /dev/null +++ b/backend/app/db/models/artifact_dependency.py @@ -0,0 +1,52 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, String, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + +# Artifact types whose SQL is parsed for table/column references. +ARTIFACT_SAVED_QUERY = "saved_query" +ARTIFACT_METRIC = "metric" +ARTIFACT_TYPES = (ARTIFACT_SAVED_QUERY, ARTIFACT_METRIC) + +REF_TABLE = "table" +REF_COLUMN = "column" + + +class ArtifactDependency(Base): + """A lineage edge: an artifact (saved query / metric) references a table/column. + + Recomputed from the artifact's SQL via ``lineage_service`` on create/update. + Powers the catalog impact view ("what depends on this table") and the + per-artifact "what this touches" view. Names are stored denormalized so the + edge survives even if the schema cache hasn't been (re-)introspected; + ``table_id`` / ``column_id`` are resolved best-effort when a match exists. + """ + + __tablename__ = "artifact_dependencies" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + connection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ) + artifact_type: Mapped[str] = mapped_column(String(20), nullable=False) + artifact_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + ref_kind: Mapped[str] = mapped_column(String(10), nullable=False) + schema_name: Mapped[str | None] = mapped_column(String(255)) + table_name: Mapped[str] = mapped_column(String(255), nullable=False) + column_name: Mapped[str | None] = mapped_column(String(255)) + table_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("cached_tables.id", ondelete="SET NULL") + ) + column_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("cached_columns.id", ondelete="SET NULL") + ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) diff --git a/backend/app/db/models/glossary.py b/backend/app/db/models/glossary.py index 14c47dd..15cfc4d 100644 --- a/backend/app/db/models/glossary.py +++ b/backend/app/db/models/glossary.py @@ -2,7 +2,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import DateTime, ForeignKey, String, Text, func +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -18,7 +18,9 @@ class GlossaryTerm(Base): UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False ) connection_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, ) term: Mapped[str] = mapped_column(String(255), nullable=False) definition: Mapped[str] = mapped_column(Text, nullable=False) @@ -27,12 +29,17 @@ class GlossaryTerm(Base): related_columns: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) examples: Mapped[dict | None] = mapped_column(JSONB, default=list) term_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by_id: Mapped[uuid.UUID | None] = mapped_column( + # Phase 3 trust/lifecycle: draft|in_review|certified|deprecated, versioned. + status: Mapped[str] = mapped_column(String(20), default="draft", nullable=False) + version: Mapped[int] = mapped_column(Integer, default=1, nullable=False) + certified_by_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() + certified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) diff --git a/backend/app/db/models/metric.py b/backend/app/db/models/metric.py index df77766..ad9802f 100644 --- a/backend/app/db/models/metric.py +++ b/backend/app/db/models/metric.py @@ -2,7 +2,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import DateTime, ForeignKey, String, Text, func +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -18,7 +18,9 @@ class MetricDefinition(Base): UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False ) connection_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, ) metric_name: Mapped[str] = mapped_column(String(255), nullable=False) display_name: Mapped[str] = mapped_column(String(255), nullable=False) @@ -29,12 +31,17 @@ class MetricDefinition(Base): dimensions: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) filters: Mapped[dict | None] = mapped_column(JSONB, default=dict) metric_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by_id: Mapped[uuid.UUID | None] = mapped_column( + # Phase 3 trust/lifecycle: draft|in_review|certified|deprecated, versioned. + status: Mapped[str] = mapped_column(String(20), default="draft", nullable=False) + version: Mapped[int] = mapped_column(Integer, default=1, nullable=False) + certified_by_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() + certified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) diff --git a/backend/app/db/models/sample_query.py b/backend/app/db/models/sample_query.py index 0060fe8..f5eba4c 100644 --- a/backend/app/db/models/sample_query.py +++ b/backend/app/db/models/sample_query.py @@ -2,7 +2,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import Boolean, DateTime, ForeignKey, Text, func +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func from sqlalchemy.dialects.postgresql import ARRAY, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -18,7 +18,9 @@ class SampleQuery(Base): UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False ) connection_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, ) natural_language: Mapped[str] = mapped_column(Text, nullable=False) sql_query: Mapped[str] = mapped_column(Text, nullable=False) @@ -26,12 +28,17 @@ class SampleQuery(Base): tags: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) is_validated: Mapped[bool] = mapped_column(Boolean, default=False) question_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by_id: Mapped[uuid.UUID | None] = mapped_column( + # Phase 3 trust/lifecycle: draft|in_review|certified|deprecated, versioned. + status: Mapped[str] = mapped_column(String(20), default="draft", nullable=False) + version: Mapped[int] = mapped_column(Integer, default=1, nullable=False) + certified_by_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now() + certified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) diff --git a/backend/app/db/models/saved_query.py b/backend/app/db/models/saved_query.py index 0237040..926f950 100644 --- a/backend/app/db/models/saved_query.py +++ b/backend/app/db/models/saved_query.py @@ -36,8 +36,12 @@ class SavedQuery(Base): # List of param defs: {name, type: string|number|date|boolean, label, default} params: Mapped[list | None] = mapped_column(JSONB, default=list) version: Mapped[int] = mapped_column(Integer, default=1, nullable=False) - # Forward-compat for Phase 3 certification (draft|certified|deprecated). + # Phase 3 trust/lifecycle: draft|in_review|certified|deprecated. status: Mapped[str] = mapped_column(String(20), default="draft", nullable=False) + certified_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + certified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) # Visible to the whole workspace vs. owner-only. is_public: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) diff --git a/backend/app/db/models/semantic_version.py b/backend/app/db/models/semantic_version.py new file mode 100644 index 0000000..63024a6 --- /dev/null +++ b/backend/app/db/models/semantic_version.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base + +# Semantic entity types that carry a certification lifecycle + version history. +ENTITY_METRIC = "metric" +ENTITY_GLOSSARY = "glossary" +ENTITY_SAMPLE_QUERY = "sample_query" +ENTITY_SAVED_QUERY = "saved_query" +ENTITY_TYPES = (ENTITY_METRIC, ENTITY_GLOSSARY, ENTITY_SAMPLE_QUERY, ENTITY_SAVED_QUERY) + + +class SemanticVersion(Base): + """An append-only snapshot of a semantic entity at a given version. + + Written on every content edit and status transition, giving each metric / + glossary term / sample query / saved query a changelog + diff history with the + reviewer and reason. Scoped by ``connection_id`` (the workspace cascade root) + like the entities it snapshots. + """ + + __tablename__ = "semantic_versions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + connection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("database_connections.id", ondelete="CASCADE"), + nullable=False, + ) + entity_type: Mapped[str] = mapped_column(String(20), nullable=False) + entity_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + version: Mapped[int] = mapped_column(Integer, nullable=False) + status: Mapped[str] = mapped_column(String(20), nullable=False) + # Full serialized entity fields at this version. + snapshot: Mapped[dict] = mapped_column(JSONB, nullable=False) + change_reason: Mapped[str | None] = mapped_column(Text) + changed_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) diff --git a/backend/app/services/catalog_service.py b/backend/app/services/catalog_service.py new file mode 100644 index 0000000..fc4a453 --- /dev/null +++ b/backend/app/services/catalog_service.py @@ -0,0 +1,550 @@ +"""Catalog search service (Phase 3 — Milestone 2). + +A unified, read-side discovery layer over the schema cache + semantic layer. +Reuses the existing pgvector embeddings and the keyword scorer from the NL→SQL +pipeline (``relevance_scorer``) — no new full-text infrastructure. Results across +tables, columns, metrics, glossary terms, sample queries, saved queries, and +knowledge documents are merged into a uniform :class:`CatalogHit`, with certified +items boosted in the ranking. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass +from datetime import datetime + +from sqlalchemy import func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models.glossary import GlossaryTerm +from app.db.models.knowledge import KnowledgeChunk, KnowledgeDocument +from app.db.models.metric import MetricDefinition +from app.db.models.sample_query import SampleQuery +from app.db.models.saved_query import SavedQuery +from app.db.models.schema_cache import CachedColumn, CachedTable +from app.semantic.relevance_scorer import extract_keywords, keyword_match_score +from app.services.embedding_service import embed_text + +logger = logging.getLogger(__name__) + +# Hit types. +TYPE_TABLE = "table" +TYPE_COLUMN = "column" +TYPE_METRIC = "metric" +TYPE_GLOSSARY = "glossary" +TYPE_SAMPLE_QUERY = "sample_query" +TYPE_SAVED_QUERY = "saved_query" +TYPE_KNOWLEDGE = "knowledge" +ALL_TYPES = ( + TYPE_TABLE, + TYPE_COLUMN, + TYPE_METRIC, + TYPE_GLOSSARY, + TYPE_SAMPLE_QUERY, + TYPE_SAVED_QUERY, + TYPE_KNOWLEDGE, +) + +_W_EMB = 0.6 +_W_KW = 0.4 +_CERT_BOOST = 0.15 +_VECTOR_LIMIT = 20 + + +@dataclass +class CatalogHit: + type: str + id: str + name: str + description: str | None = None + status: str | None = None + certified_at: datetime | None = None + owner_id: str | None = None + context: str | None = None # e.g. "schema.table" for a column, schema for a table + score: float = 0.0 + match_reason: str = "keyword" + + +def _combine(emb: float, kw: float, status: str | None) -> tuple[float, str]: + score = _W_EMB * emb + _W_KW * kw + if status == "certified": + score += _CERT_BOOST + return score, ("embedding" if emb >= kw else "keyword") + + +async def _vector_hits(db: AsyncSession, stmt) -> list: + """Execute a pgvector similarity query, degrading to [] on failure (e.g. dim mismatch).""" + try: + result = await db.execute(stmt) + return list(result.all()) + except Exception: + logger.warning("Catalog vector search failed; keyword-only for this type.", exc_info=True) + await db.rollback() + return [] + + +async def search( + db: AsyncSession, + connection_id: uuid.UUID, + query: str, + *, + types: list[str] | None = None, + status: str | None = None, + owner: str | None = None, + schema: str | None = None, + limit: int = 50, +) -> list[CatalogHit]: + """Hybrid search across the schema cache + semantic layer for one connection.""" + wanted = set(types) if types else set(ALL_TYPES) + keywords = extract_keywords(query) if query else [] + + qvec: list[float] | None = None + if query.strip(): + try: + qvec = await embed_text(query) + except Exception: + logger.debug("Catalog query embedding failed; keyword-only.", exc_info=True) + + hits: list[CatalogHit] = [] + + if TYPE_TABLE in wanted: + hits += await _search_tables(db, connection_id, qvec, keywords, schema) + if TYPE_COLUMN in wanted: + hits += await _search_columns(db, connection_id, qvec, keywords, schema) + if TYPE_GLOSSARY in wanted: + hits += await _search_glossary(db, connection_id, qvec, keywords) + if TYPE_METRIC in wanted: + hits += await _search_metrics(db, connection_id, qvec, keywords) + if TYPE_SAMPLE_QUERY in wanted: + hits += await _search_sample_queries(db, connection_id, qvec, keywords) + if TYPE_SAVED_QUERY in wanted: + hits += await _search_saved_queries(db, connection_id, keywords) + if TYPE_KNOWLEDGE in wanted: + hits += await _search_knowledge(db, connection_id, qvec, keywords) + + # Facet filters applied post-merge (status/owner only meaningful for some types). + if status: + hits = [h for h in hits if h.status == status] + if owner: + hits = [h for h in hits if h.owner_id == owner] + + return rank_hits(hits, limit) + + +def rank_hits(hits: list[CatalogHit], limit: int) -> list[CatalogHit]: + """Order hits certified-first, then by score; truncate to ``limit``.""" + hits.sort(key=lambda h: (h.status == "certified", h.score), reverse=True) + return hits[:limit] + + +# --------------------------------------------------------------------------- # +# Per-type searches +# --------------------------------------------------------------------------- # +async def _search_tables(db, connection_id, qvec, keywords, schema) -> list[CatalogHit]: + scores: dict[uuid.UUID, tuple[float, float]] = {} # id -> (emb, kw) + rows: dict[uuid.UUID, CachedTable] = {} + + if qvec is not None: + stmt = ( + select( + CachedTable, + (1 - CachedTable.description_embedding.cosine_distance(qvec)).label("sim"), + ) + .where( + CachedTable.connection_id == connection_id, + CachedTable.description_embedding.isnot(None), + ) + .order_by(CachedTable.description_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for tbl, sim in await _vector_hits(db, stmt): + rows[tbl.id] = tbl + scores[tbl.id] = (float(sim), 0.0) + + if keywords: + conds = [CachedTable.table_name.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(CachedTable).where(CachedTable.connection_id == connection_id, or_(*conds)) + ) + for tbl in result.scalars().all(): + rows[tbl.id] = tbl + emb = scores.get(tbl.id, (0.0, 0.0))[0] + scores[tbl.id] = (emb, keyword_match_score(tbl.table_name, keywords)) + + hits = [] + for tid, tbl in rows.items(): + if schema and tbl.schema_name != schema: + continue + emb, kw = scores[tid] + score, reason = _combine(emb, kw, None) + hits.append( + CatalogHit( + type=TYPE_TABLE, + id=str(tbl.id), + name=tbl.table_name, + description=tbl.comment, + context=tbl.schema_name, + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_columns(db, connection_id, qvec, keywords, schema) -> list[CatalogHit]: + scores: dict[uuid.UUID, tuple[float, float]] = {} + rows: dict[uuid.UUID, tuple[CachedColumn, CachedTable]] = {} + + if qvec is not None: + stmt = ( + select( + CachedColumn, + CachedTable, + (1 - CachedColumn.description_embedding.cosine_distance(qvec)).label("sim"), + ) + .join(CachedTable, CachedColumn.table_id == CachedTable.id) + .where( + CachedTable.connection_id == connection_id, + CachedColumn.description_embedding.isnot(None), + ) + .order_by(CachedColumn.description_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for col, tbl, sim in await _vector_hits(db, stmt): + rows[col.id] = (col, tbl) + scores[col.id] = (float(sim), 0.0) + + if keywords: + conds = [CachedColumn.column_name.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(CachedColumn, CachedTable) + .join(CachedTable, CachedColumn.table_id == CachedTable.id) + .where(CachedTable.connection_id == connection_id, or_(*conds)) + .limit(100) + ) + for col, tbl in result.all(): + rows[col.id] = (col, tbl) + emb = scores.get(col.id, (0.0, 0.0))[0] + scores[col.id] = (emb, keyword_match_score(col.column_name, keywords)) + + hits = [] + for cid, (col, tbl) in rows.items(): + if schema and tbl.schema_name != schema: + continue + emb, kw = scores[cid] + score, reason = _combine(emb, kw, None) + hits.append( + CatalogHit( + type=TYPE_COLUMN, + id=str(col.id), + name=col.column_name, + description=col.comment, + context=f"{tbl.schema_name}.{tbl.table_name}", + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_glossary(db, connection_id, qvec, keywords) -> list[CatalogHit]: + scores: dict[uuid.UUID, tuple[float, float]] = {} + rows: dict[uuid.UUID, GlossaryTerm] = {} + + if qvec is not None: + stmt = ( + select( + GlossaryTerm, + (1 - GlossaryTerm.term_embedding.cosine_distance(qvec)).label("sim"), + ) + .where( + GlossaryTerm.connection_id == connection_id, + GlossaryTerm.term_embedding.isnot(None), + ) + .order_by(GlossaryTerm.term_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for term, sim in await _vector_hits(db, stmt): + rows[term.id] = term + scores[term.id] = (float(sim), 0.0) + + if keywords: + conds = [GlossaryTerm.term.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(GlossaryTerm).where(GlossaryTerm.connection_id == connection_id, or_(*conds)) + ) + for term in result.scalars().all(): + rows[term.id] = term + emb = scores.get(term.id, (0.0, 0.0))[0] + scores[term.id] = (emb, keyword_match_score(term.term, keywords)) + + hits = [] + for tid, term in rows.items(): + emb, kw = scores[tid] + score, reason = _combine(emb, kw, term.status) + hits.append( + CatalogHit( + type=TYPE_GLOSSARY, + id=str(term.id), + name=term.term, + description=term.definition, + status=term.status, + certified_at=term.certified_at, + owner_id=str(term.created_by_id) if term.created_by_id else None, + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_metrics(db, connection_id, qvec, keywords) -> list[CatalogHit]: + scores: dict[uuid.UUID, tuple[float, float]] = {} + rows: dict[uuid.UUID, MetricDefinition] = {} + + if qvec is not None: + stmt = ( + select( + MetricDefinition, + (1 - MetricDefinition.metric_embedding.cosine_distance(qvec)).label("sim"), + ) + .where( + MetricDefinition.connection_id == connection_id, + MetricDefinition.metric_embedding.isnot(None), + ) + .order_by(MetricDefinition.metric_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for metric, sim in await _vector_hits(db, stmt): + rows[metric.id] = metric + scores[metric.id] = (float(sim), 0.0) + + if keywords: + conds = [MetricDefinition.metric_name.ilike(f"%{kw}%") for kw in keywords] + [ + MetricDefinition.display_name.ilike(f"%{kw}%") for kw in keywords + ] + result = await db.execute( + select(MetricDefinition).where( + MetricDefinition.connection_id == connection_id, or_(*conds) + ) + ) + for metric in result.scalars().all(): + rows[metric.id] = metric + emb = scores.get(metric.id, (0.0, 0.0))[0] + kw = max( + keyword_match_score(metric.metric_name, keywords), + keyword_match_score(metric.display_name, keywords), + ) + scores[metric.id] = (emb, kw) + + hits = [] + for mid, metric in rows.items(): + emb, kw = scores[mid] + score, reason = _combine(emb, kw, metric.status) + hits.append( + CatalogHit( + type=TYPE_METRIC, + id=str(metric.id), + name=metric.display_name or metric.metric_name, + description=metric.description, + status=metric.status, + certified_at=metric.certified_at, + owner_id=str(metric.created_by_id) if metric.created_by_id else None, + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_sample_queries(db, connection_id, qvec, keywords) -> list[CatalogHit]: + scores: dict[uuid.UUID, tuple[float, float]] = {} + rows: dict[uuid.UUID, SampleQuery] = {} + + if qvec is not None: + stmt = ( + select( + SampleQuery, + (1 - SampleQuery.question_embedding.cosine_distance(qvec)).label("sim"), + ) + .where( + SampleQuery.connection_id == connection_id, + SampleQuery.question_embedding.isnot(None), + ) + .order_by(SampleQuery.question_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for sq, sim in await _vector_hits(db, stmt): + rows[sq.id] = sq + scores[sq.id] = (float(sim), 0.0) + + if keywords: + conds = [SampleQuery.natural_language.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(SampleQuery).where(SampleQuery.connection_id == connection_id, or_(*conds)) + ) + for sq in result.scalars().all(): + rows[sq.id] = sq + emb = scores.get(sq.id, (0.0, 0.0))[0] + scores[sq.id] = (emb, keyword_match_score(sq.natural_language, keywords)) + + hits = [] + for sid, sq in rows.items(): + emb, kw = scores[sid] + score, reason = _combine(emb, kw, sq.status) + hits.append( + CatalogHit( + type=TYPE_SAMPLE_QUERY, + id=str(sq.id), + name=sq.natural_language[:120], + description=sq.description, + status=sq.status, + certified_at=sq.certified_at, + owner_id=str(sq.created_by_id) if sq.created_by_id else None, + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_saved_queries(db, connection_id, keywords) -> list[CatalogHit]: + # No embedding column on saved queries — keyword-only. + if not keywords: + return [] + conds = [SavedQuery.name.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(SavedQuery).where(SavedQuery.connection_id == connection_id, or_(*conds)) + ) + hits = [] + for sq in result.scalars().all(): + kw = keyword_match_score(sq.name, keywords) + score, reason = _combine(0.0, kw, sq.status) + hits.append( + CatalogHit( + type=TYPE_SAVED_QUERY, + id=str(sq.id), + name=sq.name, + description=sq.description, + status=sq.status, + certified_at=sq.certified_at, + owner_id=str(sq.owner_id) if sq.owner_id else None, + score=score, + match_reason=reason, + ) + ) + return hits + + +async def _search_knowledge(db, connection_id, qvec, keywords) -> list[CatalogHit]: + # Search chunks by vector, collapse to the best score per document. + best: dict[uuid.UUID, float] = {} + docs: dict[uuid.UUID, KnowledgeDocument] = {} + + if qvec is not None: + stmt = ( + select( + KnowledgeDocument, + (1 - KnowledgeChunk.chunk_embedding.cosine_distance(qvec)).label("sim"), + ) + .join(KnowledgeChunk, KnowledgeChunk.document_id == KnowledgeDocument.id) + .where( + KnowledgeDocument.connection_id == connection_id, + KnowledgeChunk.chunk_embedding.isnot(None), + ) + .order_by(KnowledgeChunk.chunk_embedding.cosine_distance(qvec)) + .limit(_VECTOR_LIMIT) + ) + for doc, sim in await _vector_hits(db, stmt): + docs[doc.id] = doc + best[doc.id] = max(best.get(doc.id, 0.0), float(sim)) + + if keywords: + conds = [KnowledgeDocument.title.ilike(f"%{kw}%") for kw in keywords] + result = await db.execute( + select(KnowledgeDocument).where( + KnowledgeDocument.connection_id == connection_id, or_(*conds) + ) + ) + for doc in result.scalars().all(): + docs[doc.id] = doc + + hits = [] + for did, doc in docs.items(): + emb = best.get(did, 0.0) + kw = keyword_match_score(doc.title, keywords) if keywords else 0.0 + score, reason = _combine(emb, kw, None) + hits.append( + CatalogHit( + type=TYPE_KNOWLEDGE, + id=str(doc.id), + name=doc.title, + description=doc.source_url, + score=score, + match_reason=reason, + ) + ) + return hits + + +# --------------------------------------------------------------------------- # +# Facets +# --------------------------------------------------------------------------- # +async def facets(db: AsyncSession, connection_id: uuid.UUID) -> dict: + """Return available facet values for the filter sidebar.""" + schemas = list( + ( + await db.execute( + select(CachedTable.schema_name) + .where(CachedTable.connection_id == connection_id) + .distinct() + .order_by(CachedTable.schema_name) + ) + ) + .scalars() + .all() + ) + + owners: set[str] = set() + for model, col in ( + (MetricDefinition, MetricDefinition.created_by_id), + (GlossaryTerm, GlossaryTerm.created_by_id), + (SampleQuery, SampleQuery.created_by_id), + (SavedQuery, SavedQuery.owner_id), + ): + result = await db.execute( + select(col).where(model.connection_id == connection_id, col.isnot(None)).distinct() + ) + owners.update(str(v) for v in result.scalars().all()) + + # Status counts across the lifecycle entities. + status_counts: dict[str, int] = {} + for model in (MetricDefinition, GlossaryTerm, SampleQuery, SavedQuery): + result = await db.execute( + select(model.status, func.count()) + .where(model.connection_id == connection_id) + .group_by(model.status) + ) + for st, cnt in result.all(): + status_counts[st] = status_counts.get(st, 0) + cnt + + # Tags from sample queries (ARRAY column). + tags: set[str] = set() + tag_rows = await db.execute( + select(SampleQuery.tags).where( + SampleQuery.connection_id == connection_id, SampleQuery.tags.isnot(None) + ) + ) + for tag_list in tag_rows.scalars().all(): + for tag in tag_list or []: + tags.add(str(tag)) + + return { + "schemas": schemas, + "owners": sorted(owners), + "tags": sorted(tags), + "types": list(ALL_TYPES), + "status_counts": status_counts, + } diff --git a/backend/app/services/lineage_service.py b/backend/app/services/lineage_service.py new file mode 100644 index 0000000..cd23f60 --- /dev/null +++ b/backend/app/services/lineage_service.py @@ -0,0 +1,260 @@ +"""Lineage extraction (Phase 3 — Milestone 2). + +Parses an artifact's SQL with sqlglot to record which tables/columns it touches, +stored as :class:`ArtifactDependency` edges. Powers the catalog impact view +("what depends on this table") and the per-artifact "what this touches" view. + +sqlglot is an optional dependency (the ``[lineage]`` extra). When it is absent, +extraction degrades to a no-op — the parent write (saving a query/metric) is never +blocked by lineage. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.auth import AuthContext +from app.db.models.artifact_dependency import ( + ARTIFACT_METRIC, + ARTIFACT_SAVED_QUERY, + REF_COLUMN, + REF_TABLE, + ArtifactDependency, +) +from app.db.models.connection import DatabaseConnection +from app.db.models.schema_cache import CachedColumn, CachedTable + +logger = logging.getLogger(__name__) + +# Map our connector types to sqlglot dialect names. +_DIALECTS: dict[str, str] = { + "postgresql": "postgres", + "mysql": "mysql", + "snowflake": "snowflake", + "bigquery": "bigquery", + "databricks": "databricks", +} + + +def dialect_for(connector_type: str | None) -> str | None: + if not connector_type: + return None + return _DIALECTS.get(connector_type.lower()) + + +@dataclass(frozen=True) +class Ref: + """A table or column reference parsed out of SQL.""" + + ref_kind: str # REF_TABLE | REF_COLUMN + schema_name: str | None + table_name: str + column_name: str | None = None + + +def extract_refs(sql: str, dialect: str | None = None) -> list[Ref]: + """Extract table + qualified-column references from a SQL string. + + Returns an empty list if sqlglot is unavailable or the SQL cannot be parsed — + lineage is best-effort and must never raise into the caller. + """ + if not sql or not sql.strip(): + return [] + try: + import sqlglot + from sqlglot import exp + except ImportError: + logger.debug("sqlglot not installed; skipping lineage extraction") + return [] + + try: + tree = sqlglot.parse_one(sql, dialect=dialect) + except Exception as exc: + logger.debug("lineage parse failed: %s", exc) + return [] + + refs: dict[tuple, Ref] = {} + + # Table references, plus an alias -> table map for resolving qualified columns. + alias_to_table: dict[str, tuple[str | None, str]] = {} + for table in tree.find_all(exp.Table): + table_name = table.name + if not table_name: + continue + schema_name = table.db or None + key: tuple = (REF_TABLE, schema_name, table_name, None) + refs[key] = Ref(REF_TABLE, schema_name, table_name) + alias_to_table[table_name] = (schema_name, table_name) + if table.alias: + alias_to_table[table.alias] = (schema_name, table_name) + + # Column references — only keep those we can attribute to a known table + # (qualified by an alias/table name, or unambiguous single-table queries). + single_table = ( + next(iter(alias_to_table.values())) if len(set(alias_to_table.values())) == 1 else None + ) + for col in tree.find_all(exp.Column): + col_name = col.name + if not col_name or col_name == "*": + continue + qualifier = col.table + if qualifier and qualifier in alias_to_table: + schema_name, table_name = alias_to_table[qualifier] + elif not qualifier and single_table is not None: + schema_name, table_name = single_table + else: + continue + key = (REF_COLUMN, schema_name, table_name, col_name) + refs[key] = Ref(REF_COLUMN, schema_name, table_name, col_name) + + return list(refs.values()) + + +async def _resolve_ids( + db: AsyncSession, connection_id: uuid.UUID, ref: Ref +) -> tuple[uuid.UUID | None, uuid.UUID | None]: + """Best-effort resolve a Ref's table_id / column_id against the schema cache.""" + table_stmt = select(CachedTable).where( + CachedTable.connection_id == connection_id, + CachedTable.table_name == ref.table_name, + ) + if ref.schema_name: + table_stmt = table_stmt.where(CachedTable.schema_name == ref.schema_name) + table = (await db.execute(table_stmt.limit(1))).scalar_one_or_none() + if table is None: + return None, None + if ref.ref_kind == REF_TABLE or not ref.column_name: + return table.id, None + column = ( + await db.execute( + select(CachedColumn) + .where( + CachedColumn.table_id == table.id, + CachedColumn.column_name == ref.column_name, + ) + .limit(1) + ) + ).scalar_one_or_none() + return table.id, (column.id if column else None) + + +async def recompute_for_artifact( + db: AsyncSession, + ctx: AuthContext, + artifact_type: str, + artifact_id: uuid.UUID, + sql: str, + connection_id: uuid.UUID, + *, + connector_type: str | None = None, +) -> int: + """Re-derive an artifact's lineage edges from its SQL. Best-effort; never raises. + + Returns the number of edges written (0 on parse failure / missing sqlglot). + """ + try: + refs = extract_refs(sql, dialect_for(connector_type)) + # Replace prior edges for this artifact. + existing = await db.execute( + select(ArtifactDependency).where( + ArtifactDependency.artifact_type == artifact_type, + ArtifactDependency.artifact_id == artifact_id, + ) + ) + for row in existing.scalars().all(): + await db.delete(row) + + for ref in refs: + table_id, column_id = await _resolve_ids(db, connection_id, ref) + db.add( + ArtifactDependency( + organization_id=ctx.organization_id, + connection_id=connection_id, + artifact_type=artifact_type, + artifact_id=artifact_id, + ref_kind=ref.ref_kind, + schema_name=ref.schema_name, + table_name=ref.table_name, + column_name=ref.column_name, + table_id=table_id, + column_id=column_id, + ) + ) + await db.flush() + return len(refs) + except Exception: + logger.exception("lineage recompute failed for %s %s", artifact_type, artifact_id) + return 0 + + +async def _connector_type(db: AsyncSession, connection_id: uuid.UUID) -> str | None: + return ( + await db.execute( + select(DatabaseConnection.connector_type).where(DatabaseConnection.id == connection_id) + ) + ).scalar_one_or_none() + + +async def recompute_saved_query(db: AsyncSession, ctx: AuthContext, saved) -> int: + """Re-derive lineage for a saved query from its pinned SQL (best-effort).""" + connector_type = await _connector_type(db, saved.connection_id) + return await recompute_for_artifact( + db, + ctx, + ARTIFACT_SAVED_QUERY, + saved.id, + saved.pinned_sql, + saved.connection_id, + connector_type=connector_type, + ) + + +async def recompute_metric(db: AsyncSession, ctx: AuthContext, metric) -> int: + """Re-derive lineage for a metric from its SQL expression (best-effort).""" + connector_type = await _connector_type(db, metric.connection_id) + return await recompute_for_artifact( + db, + ctx, + ARTIFACT_METRIC, + metric.id, + metric.sql_expression, + metric.connection_id, + connector_type=connector_type, + ) + + +async def refs_for_artifact( + db: AsyncSession, artifact_type: str, artifact_id: uuid.UUID +) -> list[ArtifactDependency]: + """What this artifact touches.""" + result = await db.execute( + select(ArtifactDependency) + .where( + ArtifactDependency.artifact_type == artifact_type, + ArtifactDependency.artifact_id == artifact_id, + ) + .order_by(ArtifactDependency.table_name, ArtifactDependency.column_name) + ) + return list(result.scalars().all()) + + +async def dependents_of( + db: AsyncSession, + connection_id: uuid.UUID, + table_name: str, + column_name: str | None = None, +) -> list[ArtifactDependency]: + """What artifacts depend on a given table (optionally a specific column).""" + stmt = select(ArtifactDependency).where( + ArtifactDependency.connection_id == connection_id, + ArtifactDependency.table_name == table_name, + ) + if column_name: + stmt = stmt.where(ArtifactDependency.column_name == column_name) + result = await db.execute(stmt.order_by(ArtifactDependency.artifact_type)) + return list(result.scalars().all()) diff --git a/backend/app/services/versioning_service.py b/backend/app/services/versioning_service.py new file mode 100644 index 0000000..ee8a28a --- /dev/null +++ b/backend/app/services/versioning_service.py @@ -0,0 +1,271 @@ +"""Semantic certification + versioning service (Phase 3 — Milestone 1). + +Single home for the trust lifecycle shared by metrics, glossary terms, sample +queries, and saved queries. Each entity carries ``status`` +(draft|in_review|certified|deprecated) and an integer ``version``; every content +edit and status transition appends a :class:`SemanticVersion` snapshot so the +entity has a changelog with the reviewer and reason. + +Role gate: ``in_review`` and revert-to-``draft`` require *editor*; ``certified`` +and ``deprecated`` require *admin* (the trust gate). Certifying validates the +entity's SQL (read-only blocklist + a best-effort sqlglot parse). +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.auth import AuthContext +from app.core.exceptions import AppError, ValidationError +from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR +from app.db.models.semantic_version import ( + ENTITY_GLOSSARY, + ENTITY_METRIC, + ENTITY_SAMPLE_QUERY, + ENTITY_SAVED_QUERY, + ENTITY_TYPES, + SemanticVersion, +) +from app.utils.sql_sanitizer import check_sql_safety + +# Lifecycle states. +STATUS_DRAFT = "draft" +STATUS_IN_REVIEW = "in_review" +STATUS_CERTIFIED = "certified" +STATUS_DEPRECATED = "deprecated" +STATUSES = (STATUS_DRAFT, STATUS_IN_REVIEW, STATUS_CERTIFIED, STATUS_DEPRECATED) + +# Allowed source -> target transitions. (draft -> certified is an admin fast-path.) +_ALLOWED_TRANSITIONS: dict[str, set[str]] = { + STATUS_DRAFT: {STATUS_IN_REVIEW, STATUS_CERTIFIED}, + STATUS_IN_REVIEW: {STATUS_DRAFT, STATUS_CERTIFIED}, + STATUS_CERTIFIED: {STATUS_DRAFT, STATUS_DEPRECATED}, + STATUS_DEPRECATED: {STATUS_DRAFT}, +} + +# Minimum role required to move *to* a given status. +_ROLE_FOR_TARGET: dict[str, str] = { + STATUS_DRAFT: ROLE_EDITOR, + STATUS_IN_REVIEW: ROLE_EDITOR, + STATUS_CERTIFIED: ROLE_ADMIN, + STATUS_DEPRECATED: ROLE_ADMIN, +} + +# Content fields captured in each entity's snapshot, per type. +_SNAPSHOT_FIELDS: dict[str, tuple[str, ...]] = { + ENTITY_METRIC: ( + "metric_name", + "display_name", + "description", + "sql_expression", + "aggregation_type", + "related_tables", + "dimensions", + "filters", + ), + ENTITY_GLOSSARY: ( + "term", + "definition", + "sql_expression", + "related_tables", + "related_columns", + "examples", + ), + ENTITY_SAMPLE_QUERY: ( + "natural_language", + "sql_query", + "description", + "tags", + "is_validated", + ), + ENTITY_SAVED_QUERY: ( + "name", + "description", + "nl_question", + "pinned_sql", + "params", + "is_public", + ), +} + +# The SQL-bearing field validated on certification, per type. +_SQL_FIELD: dict[str, str] = { + ENTITY_METRIC: "sql_expression", + ENTITY_GLOSSARY: "sql_expression", + ENTITY_SAMPLE_QUERY: "sql_query", + ENTITY_SAVED_QUERY: "pinned_sql", +} + + +def _require_known_type(entity_type: str) -> None: + if entity_type not in ENTITY_TYPES: + raise AppError(f"Unknown semantic entity type '{entity_type}'.", status_code=400) + + +def serialize(entity_type: str, entity: Any) -> dict[str, Any]: + """Snapshot an entity's content fields into a JSON-friendly dict.""" + _require_known_type(entity_type) + snap: dict[str, Any] = {} + for field in _SNAPSHOT_FIELDS[entity_type]: + value = getattr(entity, field, None) + # ARRAY columns may come back as None; normalize for stable diffs. + snap[field] = value + return snap + + +async def snapshot( + db: AsyncSession, + ctx: AuthContext, + entity_type: str, + entity: Any, + *, + reason: str | None = None, +) -> SemanticVersion: + """Append a :class:`SemanticVersion` row capturing the entity's current state.""" + _require_known_type(entity_type) + row = SemanticVersion( + organization_id=entity.organization_id, + connection_id=entity.connection_id, + entity_type=entity_type, + entity_id=entity.id, + version=entity.version, + status=entity.status, + snapshot=serialize(entity_type, entity), + change_reason=reason, + changed_by_id=ctx.user_id, + ) + db.add(row) + await db.flush() + return row + + +async def record_edit( + db: AsyncSession, + ctx: AuthContext, + entity_type: str, + entity: Any, + *, + reason: str | None = None, +) -> SemanticVersion: + """Bump the entity's version and snapshot the new state (call after a content edit).""" + entity.version = (entity.version or 1) + 1 + row = await snapshot(db, ctx, entity_type, entity, reason=reason or "edited") + # The UPDATE expires the server-side ``onupdate`` timestamp; refresh so the + # response serializer doesn't trigger a lazy load outside the async context. + await db.refresh(entity) + return row + + +def _validate_sql_for_certify(entity_type: str, entity: Any) -> None: + """Lightweight pre-certification SQL check: read-only blocklist + parse.""" + sql = getattr(entity, _SQL_FIELD[entity_type], None) + if not sql or not str(sql).strip(): + return + issues = check_sql_safety(str(sql)) + if issues: + raise ValidationError( + "Cannot certify: SQL failed the read-only safety check — " + "; ".join(issues) + ) + # Best-effort parse for full statements (fragments like glossary expressions + # are skipped). A parse error blocks certification with a clear message. + stripped = str(sql).lstrip().lower() + if entity_type in (ENTITY_SAVED_QUERY, ENTITY_SAMPLE_QUERY) or stripped.startswith( + ("select", "with") + ): + try: + import sqlglot + + sqlglot.parse_one(str(sql)) + except ImportError: + # sqlglot not installed — skip parse validation, keep the blocklist guard. + pass + except Exception as exc: # sqlglot.errors.ParseError and friends + raise ValidationError(f"Cannot certify: SQL did not parse — {exc}") from exc + + +async def transition_status( + db: AsyncSession, + ctx: AuthContext, + entity_type: str, + entity: Any, + new_status: str, + *, + reason: str | None = None, +) -> Any: + """Validate + apply a status transition, stamping certification and snapshotting.""" + _require_known_type(entity_type) + current = entity.status or STATUS_DRAFT + if new_status not in STATUSES: + raise ValidationError(f"Unknown status '{new_status}'.") + if new_status == current: + raise ValidationError(f"Entity is already '{current}'.") + if new_status not in _ALLOWED_TRANSITIONS.get(current, set()): + raise ValidationError(f"Cannot transition from '{current}' to '{new_status}'.") + + ctx.require_role(_ROLE_FOR_TARGET[new_status]) + + if new_status == STATUS_CERTIFIED: + _validate_sql_for_certify(entity_type, entity) + entity.certified_by_id = ctx.user_id + entity.certified_at = datetime.now(UTC) + elif new_status == STATUS_DRAFT: + # Re-opening clears the certification stamp. + entity.certified_by_id = None + entity.certified_at = None + + entity.status = new_status + await snapshot(db, ctx, entity_type, entity, reason=reason or f"status → {new_status}") + # Refresh server-side ``onupdate`` columns expired by the UPDATE (see record_edit). + await db.refresh(entity) + return entity + + +async def list_versions( + db: AsyncSession, entity_type: str, entity_id: uuid.UUID +) -> list[SemanticVersion]: + """Return the changelog for an entity, newest first.""" + _require_known_type(entity_type) + result = await db.execute( + select(SemanticVersion) + .where( + SemanticVersion.entity_type == entity_type, + SemanticVersion.entity_id == entity_id, + ) + .order_by(SemanticVersion.created_at.desc()) + ) + return list(result.scalars().all()) + + +async def get_version( + db: AsyncSession, entity_type: str, entity_id: uuid.UUID, version: int +) -> SemanticVersion | None: + """Return the most recent snapshot row at a given version number.""" + _require_known_type(entity_type) + result = await db.execute( + select(SemanticVersion) + .where( + SemanticVersion.entity_type == entity_type, + SemanticVersion.entity_id == entity_id, + SemanticVersion.version == version, + ) + .order_by(SemanticVersion.created_at.desc()) + .limit(1) + ) + return result.scalar_one_or_none() + + +def diff(before: dict[str, Any] | None, after: dict[str, Any] | None) -> dict[str, dict[str, Any]]: + """Field-level diff between two snapshots: {field: {before, after}} for changed fields.""" + before = before or {} + after = after or {} + changed: dict[str, dict[str, Any]] = {} + for key in set(before) | set(after): + b, a = before.get(key), after.get(key) + if b != a: + changed[key] = {"before": b, "after": a} + return changed diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d91af8e..6f6dc9b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -34,6 +34,9 @@ databricks = [ export = [ "openpyxl>=3.1", ] +lineage = [ + "sqlglot>=25.0", +] observability = [ "structlog>=24.0", "prometheus-client>=0.20", diff --git a/backend/tests/test_catalog_service.py b/backend/tests/test_catalog_service.py new file mode 100644 index 0000000..4556d39 --- /dev/null +++ b/backend/tests/test_catalog_service.py @@ -0,0 +1,35 @@ +"""Unit tests for catalog_service ranking + score combination (no DB).""" + +from app.services import catalog_service as svc + + +def test_combine_certified_boost(): + plain, _ = svc._combine(0.5, 0.5, None) + certified, _ = svc._combine(0.5, 0.5, "certified") + assert certified == plain + svc._CERT_BOOST + + +def test_combine_reason_picks_dominant_signal(): + _, reason = svc._combine(0.9, 0.1, None) + assert reason == "embedding" + _, reason = svc._combine(0.1, 0.9, None) + assert reason == "keyword" + + +def test_rank_certified_first_then_score(): + hits = [ + svc.CatalogHit(type="metric", id="1", name="low", status="draft", score=0.9), + svc.CatalogHit(type="metric", id="2", name="cert", status="certified", score=0.2), + svc.CatalogHit(type="metric", id="3", name="mid", status="draft", score=0.5), + ] + ranked = svc.rank_hits(hits, limit=10) + assert [h.id for h in ranked] == ["2", "1", "3"] # certified first, then by score desc + + +def test_rank_respects_limit(): + hits = [svc.CatalogHit(type="table", id=str(i), name=f"t{i}", score=i) for i in range(5)] + assert len(svc.rank_hits(hits, limit=2)) == 2 + + +def test_all_types_present(): + assert set(svc.ALL_TYPES) >= {"table", "column", "metric", "glossary", "saved_query"} diff --git a/backend/tests/test_lineage_service.py b/backend/tests/test_lineage_service.py new file mode 100644 index 0000000..002fa80 --- /dev/null +++ b/backend/tests/test_lineage_service.py @@ -0,0 +1,57 @@ +"""Unit tests for lineage_service.extract_refs (pure sqlglot parsing, no DB).""" + +import pytest + +from app.db.models.artifact_dependency import REF_COLUMN, REF_TABLE +from app.services import lineage_service as svc + +# extract_refs degrades to a no-op without sqlglot (the optional [lineage] extra); +# these tests assert the *populated* path, so skip the module when it's absent. +pytest.importorskip("sqlglot") + + +def _tables(refs): + return {r.table_name for r in refs if r.ref_kind == REF_TABLE} + + +def _cols(refs): + return {(r.table_name, r.column_name) for r in refs if r.ref_kind == REF_COLUMN} + + +def test_single_table_columns_attributed(): + refs = svc.extract_refs("SELECT a, b FROM exposures", "postgres") + assert _tables(refs) == {"exposures"} + assert _cols(refs) == {("exposures", "a"), ("exposures", "b")} + + +def test_join_with_aliases_resolves_qualifiers(): + sql = "SELECT e.id, c.name FROM exposures e JOIN counterparties c ON e.cp_id = c.id" + refs = svc.extract_refs(sql, "postgres") + assert _tables(refs) == {"exposures", "counterparties"} + cols = _cols(refs) + assert ("exposures", "id") in cols + assert ("counterparties", "name") in cols + assert ("counterparties", "id") in cols + + +def test_schema_qualified_table(): + refs = svc.extract_refs("SELECT * FROM public.facilities", "postgres") + table_refs = [r for r in refs if r.ref_kind == REF_TABLE] + assert table_refs[0].table_name == "facilities" + assert table_refs[0].schema_name == "public" + + +def test_unparseable_sql_returns_empty(): + assert svc.extract_refs("this is not sql ;;;(", "postgres") == [] + + +def test_empty_sql_returns_empty(): + assert svc.extract_refs("", "postgres") == [] + assert svc.extract_refs(" ", "postgres") == [] + + +def test_dialect_mapping(): + assert svc.dialect_for("postgresql") == "postgres" + assert svc.dialect_for("bigquery") == "bigquery" + assert svc.dialect_for(None) is None + assert svc.dialect_for("unknown-db") is None diff --git a/backend/tests/test_versioning_service.py b/backend/tests/test_versioning_service.py new file mode 100644 index 0000000..17d5547 --- /dev/null +++ b/backend/tests/test_versioning_service.py @@ -0,0 +1,152 @@ +"""Unit tests for versioning_service: transitions, role gating, cert stamping (no DB).""" + +import uuid +from types import SimpleNamespace + +import pytest + +from app.core.auth import AuthContext +from app.core.exceptions import AuthorizationError, ValidationError +from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR, ROLE_VIEWER +from app.db.models.semantic_version import ENTITY_GLOSSARY, ENTITY_METRIC +from app.services import versioning_service as svc + + +class FakeSession: + """Minimal stand-in for AsyncSession used by snapshot().""" + + def __init__(self) -> None: + self.added: list = [] + + def add(self, obj) -> None: + self.added.append(obj) + + async def flush(self) -> None: + pass + + async def refresh(self, obj, attrs=None) -> None: + pass + + +def _ctx(role=ROLE_EDITOR) -> AuthContext: + return AuthContext( + user=SimpleNamespace(id=uuid.uuid4()), + organization_id=uuid.uuid4(), + workspace_id=uuid.uuid4(), + role=role, + ) + + +def _metric(status="draft", version=1, sql="SELECT sum(amount) FROM exposures") -> SimpleNamespace: + return SimpleNamespace( + id=uuid.uuid4(), + organization_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + status=status, + version=version, + certified_by_id=None, + certified_at=None, + metric_name="ecl", + display_name="ECL", + description="expected credit loss", + sql_expression=sql, + aggregation_type="sum", + related_tables=["exposures"], + dimensions=[], + filters={}, + ) + + +# --------------------------------------------------------------------------- # +# Transition validity +# --------------------------------------------------------------------------- # +async def test_draft_to_in_review_ok_for_editor(): + db, ctx, m = FakeSession(), _ctx(ROLE_EDITOR), _metric() + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "in_review") + assert m.status == "in_review" + assert len(db.added) == 1 # one snapshot row + + +async def test_same_status_rejected(): + db, ctx, m = FakeSession(), _ctx(), _metric(status="draft") + with pytest.raises(ValidationError): + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "draft") + + +async def test_illegal_transition_rejected(): + db, ctx, m = FakeSession(), _ctx(ROLE_ADMIN), _metric(status="draft") + with pytest.raises(ValidationError): + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "deprecated") + + +# --------------------------------------------------------------------------- # +# Role gating +# --------------------------------------------------------------------------- # +async def test_certify_requires_admin(): + db, ctx, m = FakeSession(), _ctx(ROLE_EDITOR), _metric(status="in_review") + with pytest.raises(AuthorizationError): + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "certified") + + +async def test_in_review_requires_editor(): + db, ctx, m = FakeSession(), _ctx(ROLE_VIEWER), _metric(status="draft") + with pytest.raises(AuthorizationError): + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "in_review") + + +# --------------------------------------------------------------------------- # +# Certification stamping +# --------------------------------------------------------------------------- # +async def test_certify_stamps_owner_and_time(): + db, ctx, m = FakeSession(), _ctx(ROLE_ADMIN), _metric(status="in_review") + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "certified") + assert m.status == "certified" + assert m.certified_by_id == ctx.user_id + assert m.certified_at is not None + + +async def test_revert_to_draft_clears_certification(): + db, ctx, m = FakeSession(), _ctx(ROLE_ADMIN), _metric(status="certified") + m.certified_by_id = uuid.uuid4() + m.certified_at = object() + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "draft") + assert m.status == "draft" + assert m.certified_by_id is None + assert m.certified_at is None + + +async def test_certify_blocks_unsafe_sql(): + db, ctx = FakeSession(), _ctx(ROLE_ADMIN) + m = _metric(status="in_review", sql="SELECT 1; DROP TABLE exposures") + with pytest.raises(ValidationError): + await svc.transition_status(db, ctx, ENTITY_METRIC, m, "certified") + + +# --------------------------------------------------------------------------- # +# Snapshots, edits, diff +# --------------------------------------------------------------------------- # +async def test_record_edit_bumps_version_and_snapshots(): + db, ctx, m = FakeSession(), _ctx(), _metric(version=3) + await svc.record_edit(db, ctx, ENTITY_METRIC, m) + assert m.version == 4 + assert db.added[0].version == 4 + assert db.added[0].snapshot["metric_name"] == "ecl" + + +def test_serialize_only_content_fields(): + snap = svc.serialize( + ENTITY_GLOSSARY, _metric() + ) # reuse namespace; glossary fields read via getattr + assert set(snap) == { + "term", + "definition", + "sql_expression", + "related_tables", + "related_columns", + "examples", + } + + +def test_diff_reports_changed_fields(): + d = svc.diff({"a": 1, "b": 2}, {"a": 1, "b": 3, "c": 4}) + assert d == {"b": {"before": 2, "after": 3}, "c": {"before": None, "after": 4}} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index b97a496..dbff764 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -12,6 +12,7 @@ import { DashboardsPage } from './pages/DashboardsPage'; import { DashboardDetailPage } from './pages/DashboardDetailPage'; import { DictionaryPage } from './pages/DictionaryPage'; import { KnowledgePage } from './pages/KnowledgePage'; +import { CatalogPage } from './pages/CatalogPage'; import { HistoryPage } from './pages/HistoryPage'; export default function App() { @@ -34,6 +35,7 @@ export default function App() { } /> } /> } /> + } /> } /> diff --git a/frontend/src/api/catalogApi.ts b/frontend/src/api/catalogApi.ts new file mode 100644 index 0000000..9787e81 --- /dev/null +++ b/frontend/src/api/catalogApi.ts @@ -0,0 +1,35 @@ +import { api } from './client'; +import type { CatalogFacets, CatalogHit, LineageRef } from '../types/api'; + +export interface CatalogSearchParams { + q?: string; + types?: string[]; + status?: string; + owner?: string; + schema?: string; + limit?: number; +} + +export const catalogApi = { + search: (connectionId: string, params: CatalogSearchParams) => + api + .get(`/connections/${connectionId}/catalog/search`, { + params: { + q: params.q ?? '', + types: params.types?.length ? params.types.join(',') : undefined, + status: params.status || undefined, + owner: params.owner || undefined, + schema: params.schema || undefined, + limit: params.limit ?? 50, + }, + }) + .then((r) => r.data), + facets: (connectionId: string) => + api.get(`/connections/${connectionId}/catalog/facets`).then((r) => r.data), + lineageImpact: (connectionId: string, table: string, column?: string) => + api + .get(`/connections/${connectionId}/catalog/lineage`, { + params: { table, column: column || undefined }, + }) + .then((r) => r.data), +}; diff --git a/frontend/src/api/glossaryApi.ts b/frontend/src/api/glossaryApi.ts index ca909f4..78758b6 100644 --- a/frontend/src/api/glossaryApi.ts +++ b/frontend/src/api/glossaryApi.ts @@ -1,5 +1,11 @@ import { api } from './client'; -import type { GlossaryTerm, MetricDefinition, DictionaryEntry } from '../types/api'; +import type { + GlossaryTerm, + LineageRef, + MetricDefinition, + DictionaryEntry, + SemanticVersion, +} from '../types/api'; export const glossaryApi = { list: (connectionId: string) => @@ -10,6 +16,14 @@ export const glossaryApi = { api.put(`/connections/${connectionId}/glossary/${termId}`, data).then(r => r.data), delete: (connectionId: string, termId: string) => api.delete(`/connections/${connectionId}/glossary/${termId}`), + transitionStatus: (connectionId: string, termId: string, status: string, reason?: string) => + api + .post(`/connections/${connectionId}/glossary/${termId}/status`, { status, reason }) + .then(r => r.data), + versions: (connectionId: string, termId: string) => + api + .get(`/connections/${connectionId}/glossary/${termId}/versions`) + .then(r => r.data), }; export const metricsApi = { @@ -21,6 +35,18 @@ export const metricsApi = { api.put(`/connections/${connectionId}/metrics/${metricId}`, data).then(r => r.data), delete: (connectionId: string, metricId: string) => api.delete(`/connections/${connectionId}/metrics/${metricId}`), + transitionStatus: (connectionId: string, metricId: string, status: string, reason?: string) => + api + .post(`/connections/${connectionId}/metrics/${metricId}/status`, { status, reason }) + .then(r => r.data), + versions: (connectionId: string, metricId: string) => + api + .get(`/connections/${connectionId}/metrics/${metricId}/versions`) + .then(r => r.data), + lineage: (connectionId: string, metricId: string) => + api + .get(`/connections/${connectionId}/metrics/${metricId}/lineage`) + .then(r => r.data), }; export const dictionaryApi = { diff --git a/frontend/src/api/savedQueriesApi.ts b/frontend/src/api/savedQueriesApi.ts index 1f216df..093b70a 100644 --- a/frontend/src/api/savedQueriesApi.ts +++ b/frontend/src/api/savedQueriesApi.ts @@ -3,8 +3,10 @@ import type { Chart, ChartType, ChartConfig, + LineageRef, SavedQuery, SavedQueryRunResult, + SemanticVersion, } from '../types/api'; const base = (connectionId: string) => `/connections/${connectionId}/saved-queries`; @@ -33,6 +35,14 @@ export const savedQueriesApi = { .then((r) => r.data), exportUrl: (connectionId: string, id: string, format: 'csv' | 'json' | 'xlsx') => `${api.defaults.baseURL}${base(connectionId)}/${id}/export?format=${format}`, + transitionStatus: (connectionId: string, id: string, status: string, reason?: string) => + api + .post(`${base(connectionId)}/${id}/status`, { status, reason }) + .then((r) => r.data), + versions: (connectionId: string, id: string) => + api.get(`${base(connectionId)}/${id}/versions`).then((r) => r.data), + lineage: (connectionId: string, id: string) => + api.get(`${base(connectionId)}/${id}/lineage`).then((r) => r.data), }; export const chartsApi = { diff --git a/frontend/src/components/common/CertificationBadge.tsx b/frontend/src/components/common/CertificationBadge.tsx new file mode 100644 index 0000000..ac3323d --- /dev/null +++ b/frontend/src/components/common/CertificationBadge.tsx @@ -0,0 +1,23 @@ +import { Badge } from '@mantine/core'; + +const STATUS_COLOR: Record = { + draft: 'gray', + in_review: 'yellow', + certified: 'green', + deprecated: 'red', +}; + +const STATUS_LABEL: Record = { + draft: 'draft', + in_review: 'in review', + certified: 'certified', + deprecated: 'deprecated', +}; + +export function CertificationBadge({ status, size = 'sm' }: { status: string; size?: string }) { + return ( + + {STATUS_LABEL[status] ?? status} + + ); +} diff --git a/frontend/src/components/common/StatusActions.tsx b/frontend/src/components/common/StatusActions.tsx new file mode 100644 index 0000000..3f64a25 --- /dev/null +++ b/frontend/src/components/common/StatusActions.tsx @@ -0,0 +1,71 @@ +import { ActionIcon, Menu, Tooltip } from '@mantine/core'; +import { IconDotsVertical } from '@tabler/icons-react'; +import { useAuth } from '../../context/auth'; +import type { Role } from '../../types/auth'; + +// Mirror of versioning_service._ALLOWED_TRANSITIONS (source -> targets). +const ALLOWED: Record = { + draft: ['in_review', 'certified'], + in_review: ['draft', 'certified'], + certified: ['draft', 'deprecated'], + deprecated: ['draft'], +}; + +// Minimum role required to move *to* a status (mirror of _ROLE_FOR_TARGET). +const ROLE_FOR_TARGET: Record = { + draft: 'editor', + in_review: 'editor', + certified: 'admin', + deprecated: 'admin', +}; + +const ROLE_RANK: Record = { viewer: 1, editor: 2, admin: 3 }; + +const ACTION_LABEL: Record = { + draft: 'Revert to draft', + in_review: 'Submit for review', + certified: 'Certify', + deprecated: 'Deprecate', +}; + +export interface StatusActionsProps { + status: string; + onTransition: (target: string, reason?: string) => void; +} + +/** A role-aware menu offering the valid certification-lifecycle transitions. */ +export function StatusActions({ status, onTransition }: StatusActionsProps) { + const { role } = useAuth(); + const myRank = role ? ROLE_RANK[role] : 0; + const targets = (ALLOWED[status] ?? []).filter( + (t) => myRank >= ROLE_RANK[ROLE_FOR_TARGET[t]], + ); + + if (targets.length === 0) return null; + + return ( + + + + + + + + + + Certification + {targets.map((t) => ( + { + const reason = window.prompt(`${ACTION_LABEL[t]} — reason (optional):`) ?? undefined; + onTransition(t, reason || undefined); + }} + > + {ACTION_LABEL[t] ?? t} + + ))} + + + ); +} diff --git a/frontend/src/components/common/VersionHistory.tsx b/frontend/src/components/common/VersionHistory.tsx new file mode 100644 index 0000000..27e9a58 --- /dev/null +++ b/frontend/src/components/common/VersionHistory.tsx @@ -0,0 +1,65 @@ +import { Code, Modal, Text, Timeline, Loader, Group } from '@mantine/core'; +import { useQuery } from '@tanstack/react-query'; +import { CertificationBadge } from './CertificationBadge'; +import type { SemanticVersion } from '../../types/api'; + +export interface VersionHistoryProps { + opened: boolean; + onClose: () => void; + title: string; + /** Fetches the version changelog; only called while the modal is open. */ + queryKey: unknown[]; + fetchVersions: () => Promise; +} + +/** A modal rendering an entity's certification/version changelog as a timeline. */ +export function VersionHistory({ + opened, + onClose, + title, + queryKey, + fetchVersions, +}: VersionHistoryProps) { + const { data, isLoading } = useQuery({ + queryKey, + queryFn: fetchVersions, + enabled: opened, + }); + + return ( + + {isLoading && ( + + + + )} + {data && data.length === 0 && ( + + No version history yet. Edits and certification changes appear here. + + )} + {data && data.length > 0 && ( + + {data.map((v) => ( + + + + + {new Date(v.created_at).toLocaleString()} + + + {v.change_reason && ( + + {v.change_reason} + + )} + + {JSON.stringify(v.snapshot, null, 2)} + + + ))} + + )} + + ); +} diff --git a/frontend/src/components/layout/AppLayout.tsx b/frontend/src/components/layout/AppLayout.tsx index 7937e31..a95df3d 100644 --- a/frontend/src/components/layout/AppLayout.tsx +++ b/frontend/src/components/layout/AppLayout.tsx @@ -21,6 +21,7 @@ import { IconUserCircle, IconBookmark, IconLayoutDashboard, + IconBook2, } from '@tabler/icons-react'; import { Outlet, useLocation, useNavigate } from 'react-router-dom'; import { EmbeddingStatusBanner } from '../common/EmbeddingStatusBanner'; @@ -36,6 +37,7 @@ const NAV_ITEMS = [ { label: 'Metrics', path: '/metrics', icon: IconChartBar }, { label: 'Dictionary', path: '/dictionary', icon: IconVocabulary }, { label: 'Knowledge', path: '/knowledge', icon: IconFileText }, + { label: 'Catalog', path: '/catalog', icon: IconBook2 }, { label: 'History', path: '/history', icon: IconHistory }, ]; diff --git a/frontend/src/hooks/useCatalog.ts b/frontend/src/hooks/useCatalog.ts new file mode 100644 index 0000000..de50365 --- /dev/null +++ b/frontend/src/hooks/useCatalog.ts @@ -0,0 +1,33 @@ +import { useQuery } from '@tanstack/react-query'; +import { catalogApi, type CatalogSearchParams } from '../api/catalogApi'; + +export function useCatalogSearch( + connectionId: string | undefined, + params: CatalogSearchParams, +) { + return useQuery({ + queryKey: ['catalog', 'search', connectionId, params], + queryFn: () => catalogApi.search(connectionId!, params), + enabled: !!connectionId, + }); +} + +export function useCatalogFacets(connectionId: string | undefined) { + return useQuery({ + queryKey: ['catalog', 'facets', connectionId], + queryFn: () => catalogApi.facets(connectionId!), + enabled: !!connectionId, + }); +} + +export function useCatalogLineage( + connectionId: string | undefined, + table: string | undefined, + column?: string, +) { + return useQuery({ + queryKey: ['catalog', 'lineage', connectionId, table, column], + queryFn: () => catalogApi.lineageImpact(connectionId!, table!, column), + enabled: !!connectionId && !!table, + }); +} diff --git a/frontend/src/hooks/useSavedQueries.ts b/frontend/src/hooks/useSavedQueries.ts index 37b10f9..e93800f 100644 --- a/frontend/src/hooks/useSavedQueries.ts +++ b/frontend/src/hooks/useSavedQueries.ts @@ -43,6 +43,15 @@ export function useCloneSavedQuery(connectionId: string) { }); } +export function useTransitionSavedQueryStatus(connectionId: string) { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ id, status, reason }: { id: string; status: string; reason?: string }) => + savedQueriesApi.transitionStatus(connectionId, id, status, reason), + onSuccess: () => qc.invalidateQueries({ queryKey: ['savedQueries', connectionId] }), + }); +} + export function useRunSavedQuery(connectionId: string) { return useMutation({ mutationFn: ({ diff --git a/frontend/src/pages/CatalogPage.tsx b/frontend/src/pages/CatalogPage.tsx new file mode 100644 index 0000000..00e2e9a --- /dev/null +++ b/frontend/src/pages/CatalogPage.tsx @@ -0,0 +1,329 @@ +import { useEffect, useState } from 'react'; +import { + Alert, + Badge, + Card, + Checkbox, + Drawer, + Grid, + Group, + Loader, + Paper, + Select, + Stack, + Text, + TextInput, + Title, +} from '@mantine/core'; +import { useQuery } from '@tanstack/react-query'; +import { IconSearch } from '@tabler/icons-react'; +import { useConnections } from '../hooks/useConnections'; +import { useCatalogFacets, useCatalogSearch } from '../hooks/useCatalog'; +import { catalogApi } from '../api/catalogApi'; +import { metricsApi } from '../api/glossaryApi'; +import { savedQueriesApi } from '../api/savedQueriesApi'; +import { CertificationBadge } from '../components/common/CertificationBadge'; +import type { CatalogHit, LineageRef } from '../types/api'; + +const TYPE_COLOR: Record = { + table: 'blue', + column: 'cyan', + metric: 'grape', + glossary: 'teal', + sample_query: 'orange', + saved_query: 'indigo', + knowledge: 'gray', +}; + +const TYPE_LABEL: Record = { + table: 'Table', + column: 'Column', + metric: 'Metric', + glossary: 'Glossary', + sample_query: 'Sample query', + saved_query: 'Saved query', + knowledge: 'Knowledge', +}; + +const ALL_TYPES = Object.keys(TYPE_COLOR); + +export function CatalogPage() { + const [connectionId, setConnectionId] = useState(null); + const [search, setSearch] = useState(''); + const [debounced, setDebounced] = useState(''); + const [selectedTypes, setSelectedTypes] = useState([]); + const [status, setStatus] = useState(null); + const [schema, setSchema] = useState(null); + const [selected, setSelected] = useState(null); + + const { data: connections } = useConnections(); + const connOptions = connections?.map((c) => ({ value: c.id, label: c.name })) ?? []; + if (!connectionId && connOptions.length > 0) { + setConnectionId(connOptions[0].value); + } + + useEffect(() => { + const t = setTimeout(() => setDebounced(search), 300); + return () => clearTimeout(t); + }, [search]); + + const { data: facets } = useCatalogFacets(connectionId ?? undefined); + const { data: hits, isLoading } = useCatalogSearch(connectionId ?? undefined, { + q: debounced, + types: selectedTypes, + status: status ?? undefined, + schema: schema ?? undefined, + }); + + return ( + + Data Catalog + + + + +