From b5dade28b75b5b2902e29202df315f6b80e5cf9b Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Wed, 25 Feb 2026 15:20:55 -0800 Subject: [PATCH 01/12] feat: replace empty lexicons/ with atdata-lexicon submodule pinned to v0.2.1b1 Add forecast-bio/atdata-lexicon as a git submodule at lexicons/ using HTTPS URL, pinned to the v0.2.1b1 tag. Update all CI checkout steps to initialize submodules, and document the submodule in CLAUDE.md and README.md. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 8 ++++++++ .github/workflows/publish.yml | 2 ++ .gitmodules | 3 +++ CLAUDE.md | 10 ++++++++++ README.md | 13 +++++++++++++ lexicons | 1 + 6 files changed, 37 insertions(+) create mode 100644 .gitmodules create mode 160000 lexicons diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95cd225..4526e04 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true @@ -28,6 +30,8 @@ jobs: python-version: ["3.12", "3.13"] steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true @@ -56,6 +60,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Apply schema env: PGHOST: localhost @@ -105,6 +111,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4c9d7e5..48eb356 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,6 +12,8 @@ jobs: id-token: write steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true diff --git a/.gitmodules b/.gitmodules new file mode 100644 
index 0000000..d4faf98 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lexicons"] + path = lexicons + url = https://github.com/forecast-bio/atdata-lexicon.git diff --git a/CLAUDE.md b/CLAUDE.md index 6020760..c5b8285 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,16 @@ uv run ruff check src/ tests/ uv run uvicorn atdata_app.main:app --reload ``` +## Lexicon Submodule + +The `lexicons/` directory is a git submodule pointing to [forecast-bio/atdata-lexicon](https://github.com/forecast-bio/atdata-lexicon), which contains the authoritative `science.alt.dataset.*` lexicon definitions. The submodule is for reference and CI validation — the Python source code does not read lexicon files at runtime. + +After cloning, initialize the submodule: + +```bash +git submodule update --init +``` + ## Architecture ### Data Flow diff --git a/README.md b/README.md index 00d8bdf..33ac782 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ ATProto Network # Install dependencies uv sync --dev +# Initialize the lexicon submodule +git submodule update --init + # Set up PostgreSQL (schema auto-applies on startup) createdb atdata_app @@ -135,6 +138,16 @@ uv run ruff check src/ tests/ Tests mock all external dependencies (database, HTTP, identity resolution) using `unittest.mock.AsyncMock`. HTTP endpoint tests use httpx `ASGITransport` for in-process testing without a running server. +### Lexicon Definitions + +The `lexicons/` directory is a [git submodule](https://github.com/forecast-bio/atdata-lexicon) containing the authoritative `science.alt.dataset.*` lexicon schemas. Initialize it with: + +```bash +git submodule update --init +``` + +The lexicons are for reference and CI validation. The Python source code uses hardcoded NSID constants and does not read the lexicon JSON files at runtime. 
+ ## License MIT diff --git a/lexicons b/lexicons new file mode 160000 index 0000000..ece47ae --- /dev/null +++ b/lexicons @@ -0,0 +1 @@ +Subproject commit ece47ae2158c3226bdf2ac280695db32ea5ad336 From c9108b19b6c62830257b6bb94d06caf1fa059667 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 00:53:11 -0800 Subject: [PATCH 02/12] feat: add sendInteractions XRPC procedure for usage telemetry Add science.alt.dataset.sendInteractions POST endpoint that accepts batches of download, citation, and derivative interaction events. Validates AT-URIs, interaction types, and optional ISO 8601 timestamps, then fires analytics events via the existing fire-and-forget infrastructure. Also extends getEntryStats to surface per-entry interaction counts (downloads, citations, derivatives) alongside existing view/search metrics. Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/database.py | 8 +- src/atdata_app/models.py | 3 + src/atdata_app/xrpc/procedures.py | 79 +++++++++++ tests/test_analytics.py | 222 ++++++++++++++++++++++++++++++ 4 files changed, 311 insertions(+), 1 deletion(-) diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index b0ea925..748e2c3 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -720,7 +720,10 @@ async def query_entry_stats( """ SELECT COUNT(*) FILTER (WHERE event_type = 'view_entry') AS views, - COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances + COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances, + COUNT(*) FILTER (WHERE event_type = 'download') AS downloads, + COUNT(*) FILTER (WHERE event_type = 'citation') AS citations, + COUNT(*) FILTER (WHERE event_type = 'derivative') AS derivatives FROM analytics_events WHERE target_did = $1 AND target_rkey = $2 AND created_at >= NOW() - $3::interval @@ -732,6 +735,9 @@ async def query_entry_stats( return { "views": row["views"], "searchAppearances": 
row["search_appearances"], + "downloads": row["downloads"], + "citations": row["citations"], + "derivatives": row["derivatives"], "period": period, } diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 1cde2f2..7cec0d9 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -236,4 +236,7 @@ class GetAnalyticsResponse(BaseModel): class GetEntryStatsResponse(BaseModel): views: int searchAppearances: int + downloads: int = 0 + citations: int = 0 + derivatives: int = 0 period: str diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index c066987..e954dc3 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +from datetime import datetime from typing import Any import httpx @@ -15,6 +16,7 @@ from atdata_app import get_resolver from atdata_app.auth import verify_service_auth from atdata_app.database import ( + fire_analytics_event, query_get_entry, query_get_schema, query_record_exists, @@ -268,3 +270,80 @@ async def publish_lens(request: Request) -> dict[str, Any]: pds, pds_token, auth.iss, "science.alt.dataset.lens", record, rkey ) return {"uri": result.get("uri"), "cid": result.get("cid")} + + +# --------------------------------------------------------------------------- +# sendInteractions +# --------------------------------------------------------------------------- + +_VALID_INTERACTION_TYPES = frozenset({"download", "citation", "derivative"}) +_MAX_INTERACTIONS_BATCH = 100 + + +def _validate_iso8601(value: str) -> None: + """Raise ValueError if *value* is not a valid ISO 8601 datetime string.""" + # datetime.fromisoformat handles the common subset we accept + datetime.fromisoformat(value) + + +@router.post("/science.alt.dataset.sendInteractions") +async def send_interactions(request: Request) -> dict[str, Any]: + pool = request.app.state.db_pool + + body = await request.json() + interactions = 
body.get("interactions") + if not isinstance(interactions, list): + raise HTTPException(status_code=400, detail="interactions must be an array") + + if len(interactions) > _MAX_INTERACTIONS_BATCH: + raise HTTPException( + status_code=400, + detail=f"Batch size exceeds maximum of {_MAX_INTERACTIONS_BATCH}", + ) + + for i, item in enumerate(interactions): + if not isinstance(item, dict): + raise HTTPException(status_code=400, detail=f"interactions[{i}]: must be an object") + + itype = item.get("type") + if itype not in _VALID_INTERACTION_TYPES: + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: invalid type '{itype}', " + f"must be one of: {', '.join(sorted(_VALID_INTERACTION_TYPES))}", + ) + + dataset_uri = item.get("datasetUri") + if not isinstance(dataset_uri, str): + raise HTTPException( + status_code=400, detail=f"interactions[{i}]: datasetUri is required" + ) + try: + parse_at_uri(dataset_uri) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: invalid AT-URI: {dataset_uri}", + ) + + timestamp = item.get("timestamp") + if timestamp is not None: + if not isinstance(timestamp, str): + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: timestamp must be a string", + ) + try: + _validate_iso8601(timestamp) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: invalid ISO 8601 timestamp: {timestamp}", + ) + + # All valid — fire analytics events + for item in interactions: + did, _collection, rkey = parse_at_uri(item["datasetUri"]) + fire_analytics_event(pool, item["type"], target_did=did, target_rkey=rkey) + + return {} diff --git a/tests/test_analytics.py b/tests/test_analytics.py index 6be9707..be53149 100644 --- a/tests/test_analytics.py +++ b/tests/test_analytics.py @@ -374,3 +374,225 @@ def test_describe_service_response_without_analytics(): recordCount={"science.alt.dataset.entry": 10}, ) assert resp.analytics is None + + +def 
test_get_entry_stats_response_with_interactions(): + resp = GetEntryStatsResponse( + views=10, searchAppearances=3, downloads=5, citations=2, derivatives=1, period="week" + ) + assert resp.downloads == 5 + assert resp.citations == 2 + assert resp.derivatives == 1 + + +def test_get_entry_stats_response_defaults_interactions(): + resp = GetEntryStatsResponse(views=10, searchAppearances=3, period="week") + assert resp.downloads == 0 + assert resp.citations == 0 + assert resp.derivatives == 0 + + +# --------------------------------------------------------------------------- +# sendInteractions endpoint +# --------------------------------------------------------------------------- + +_PROC = "atdata_app.xrpc.procedures" + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_valid_batch(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + }, + { + "type": "citation", + "datasetUri": "at://did:plc:def/science.alt.dataset.entry/4abc", + "timestamp": "2025-06-01T12:00:00Z", + }, + ] + }, + ) + + assert resp.status_code == 200 + assert resp.json() == {} + assert mock_fire.call_count == 2 + mock_fire.assert_any_call(pool, "download", target_did="did:plc:abc", target_rkey="3xyz") + mock_fire.assert_any_call(pool, "citation", target_did="did:plc:def", target_rkey="4abc") + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_empty_array(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + 
"/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": []}, + ) + + assert resp.status_code == 200 + assert resp.json() == {} + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_invalid_uri(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + {"type": "download", "datasetUri": "https://not-an-at-uri"}, + ] + }, + ) + + assert resp.status_code == 400 + assert "invalid AT-URI" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_invalid_type(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "bookmark", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + }, + ] + }, + ) + + assert resp.status_code == 400 + assert "invalid type" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_batch_size_exceeded(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + interactions = [ + {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"} + for _ in range(101) + ] + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": interactions}, + ) + + assert resp.status_code == 400 + assert "Batch size exceeds 
maximum" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_invalid_timestamp(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + "timestamp": "not-a-date", + }, + ] + }, + ) + + assert resp.status_code == 400 + assert "invalid ISO 8601 timestamp" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_missing_dataset_uri(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + {"type": "download"}, + ] + }, + ) + + assert resp.status_code == 400 + assert "datasetUri is required" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_not_an_array(mock_fire, config, pool): + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": "not-an-array"}, + ) + + assert resp.status_code == 400 + assert "interactions must be an array" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_all_three_types(mock_fire, config, pool): + app = 
_mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:a/science.alt.dataset.entry/1", + }, + { + "type": "citation", + "datasetUri": "at://did:plc:b/science.alt.dataset.entry/2", + }, + { + "type": "derivative", + "datasetUri": "at://did:plc:c/science.alt.dataset.entry/3", + }, + ] + }, + ) + + assert resp.status_code == 200 + assert mock_fire.call_count == 3 From 92bb19aa64fe5f489e02129d71ee200f005b0a85 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:01:19 -0800 Subject: [PATCH 03/12] feat: add skeleton/hydration pattern for third-party dataset indexes Implement index provider registration and the getIndexSkeleton/getIndex query endpoints following Bluesky's feed generator pattern. Third parties can register curated dataset index endpoints; the AppView fetches URI skeletons from them and hydrates entries from the local database. 
- Add index_providers table to schema.sql - Add upsert/query functions in database.py with COLLECTION_TABLE_MAP entry - Add row_to_index_provider serializer and response models in models.py - Add getIndexSkeleton, getIndex, listIndexes query endpoints - Add publishIndex procedure with HTTPS URL validation - Route science.alt.dataset.index records through ingestion processor - Add 18 tests covering happy paths, error cases, and ingestion Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/database.py | 75 ++++ src/atdata_app/ingestion/processor.py | 2 + src/atdata_app/models.py | 30 ++ src/atdata_app/sql/schema.sql | 16 + src/atdata_app/xrpc/procedures.py | 39 ++ src/atdata_app/xrpc/queries.py | 140 +++++++ tests/test_index.py | 541 ++++++++++++++++++++++++++ 7 files changed, 843 insertions(+) create mode 100644 tests/test_index.py diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index b0ea925..eb69133 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -20,6 +20,7 @@ f"{LEXICON_NAMESPACE}.entry": "entries", f"{LEXICON_NAMESPACE}.label": "labels", f"{LEXICON_NAMESPACE}.lens": "lenses", + f"{LEXICON_NAMESPACE}.index": "index_providers", } @@ -203,6 +204,36 @@ async def upsert_lens( ) +async def upsert_index_provider( + pool: asyncpg.Pool, + did: str, + rkey: str, + cid: str | None, + record: dict[str, Any], +) -> None: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO index_providers (did, rkey, cid, name, description, endpoint_url, + created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (did, rkey) DO UPDATE SET + cid = EXCLUDED.cid, + name = EXCLUDED.name, + description = EXCLUDED.description, + endpoint_url = EXCLUDED.endpoint_url, + indexed_at = NOW() + """, + did, + rkey, + cid, + record["name"], + record.get("description"), + record["endpointUrl"], + record.get("createdAt", ""), + ) + + async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) -> None: if table 
not in COLLECTION_TABLE_MAP.values(): return @@ -219,6 +250,7 @@ async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) -> "entries": upsert_entry, "labels": upsert_label, "lenses": upsert_lens, + "index_providers": upsert_index_provider, } @@ -449,6 +481,49 @@ async def query_list_lenses( ) +async def query_get_index_provider( + pool: asyncpg.Pool, did: str, rkey: str +) -> asyncpg.Record | None: + async with pool.acquire() as conn: + return await conn.fetchrow( + "SELECT * FROM index_providers WHERE did = $1 AND rkey = $2", did, rkey + ) + + +async def query_list_index_providers( + pool: asyncpg.Pool, + repo: str | None = None, + limit: int = 50, + cursor_did: str | None = None, + cursor_rkey: str | None = None, + cursor_indexed_at: str | None = None, +) -> list[asyncpg.Record]: + conditions: list[str] = [] + params: list[Any] = [] + idx = 1 + + if repo: + conditions.append(f"did = ${idx}") + params.append(repo) + idx += 1 + + if cursor_indexed_at and cursor_did and cursor_rkey: + conditions.append( + f"(indexed_at, did, rkey) < (${idx}, ${idx + 1}, ${idx + 2})" + ) + params.extend([datetime.fromisoformat(cursor_indexed_at), cursor_did, cursor_rkey]) + idx += 3 + + where = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(limit) + + async with pool.acquire() as conn: + return await conn.fetch( + f"SELECT * FROM index_providers {where} ORDER BY indexed_at DESC, did DESC, rkey DESC LIMIT ${idx}", + *params, + ) + + async def query_search_datasets( pool: asyncpg.Pool, q: str, diff --git a/src/atdata_app/ingestion/processor.py b/src/atdata_app/ingestion/processor.py index fa34c96..78ad2d0 100644 --- a/src/atdata_app/ingestion/processor.py +++ b/src/atdata_app/ingestion/processor.py @@ -56,6 +56,8 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: await db.upsert_label(pool, did, rkey, cid, record) elif table == "lenses": await db.upsert_lens(pool, did, rkey, cid, record) + elif table == 
"index_providers": + await db.upsert_index_provider(pool, did, rkey, cid, record) logger.debug("Upserted %s %s/%s", collection, did, rkey) except Exception: logger.exception("Failed to upsert %s %s/%s", collection, did, rkey) diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 1cde2f2..6d123c9 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -138,6 +138,21 @@ def row_to_label(row) -> dict[str, Any]: return d +def row_to_index_provider(row) -> dict[str, Any]: + uri = make_at_uri(row["did"], "science.alt.dataset.index", row["rkey"]) + d: dict[str, Any] = { + "uri": uri, + "cid": row["cid"], + "did": row["did"], + "name": row["name"], + "endpointUrl": row["endpoint_url"], + "createdAt": row["created_at"], + } + if row["description"]: + d["description"] = row["description"] + return d + + def row_to_lens(row) -> dict[str, Any]: uri = make_at_uri(row["did"], "science.alt.dataset.lens", row["rkey"]) getter_code = row["getter_code"] @@ -237,3 +252,18 @@ class GetEntryStatsResponse(BaseModel): views: int searchAppearances: int period: str + + +class ListIndexesResponse(BaseModel): + indexes: list[dict[str, Any]] + cursor: str | None = None + + +class IndexSkeletonResponse(BaseModel): + items: list[dict[str, Any]] + cursor: str | None = None + + +class IndexResponse(BaseModel): + items: list[dict[str, Any]] + cursor: str | None = None diff --git a/src/atdata_app/sql/schema.sql b/src/atdata_app/sql/schema.sql index 17f1729..422b379 100644 --- a/src/atdata_app/sql/schema.sql +++ b/src/atdata_app/sql/schema.sql @@ -144,6 +144,22 @@ CREATE INDEX IF NOT EXISTS idx_analytics_events_type_created CREATE INDEX IF NOT EXISTS idx_analytics_events_target ON analytics_events (target_did, target_rkey, event_type); +-- Index providers (science.alt.dataset.index) +CREATE TABLE IF NOT EXISTS index_providers ( + did TEXT NOT NULL, + rkey TEXT NOT NULL, + cid TEXT, + name TEXT NOT NULL, + description TEXT, + endpoint_url TEXT NOT NULL, + created_at 
TEXT NOT NULL, + indexed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (did, rkey) +); + +CREATE INDEX IF NOT EXISTS idx_index_providers_did ON index_providers (did); +CREATE INDEX IF NOT EXISTS idx_index_providers_indexed_at ON index_providers (indexed_at DESC); + -- Pre-aggregated analytics counters (avoids expensive COUNT on events table) CREATE TABLE IF NOT EXISTS analytics_counters ( target_did TEXT NOT NULL, diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index c066987..6ae1151 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -19,6 +19,7 @@ query_get_schema, query_record_exists, ) +from urllib.parse import urlparse from atdata_app.models import parse_at_uri logger = logging.getLogger(__name__) @@ -268,3 +269,41 @@ async def publish_lens(request: Request) -> dict[str, Any]: pds, pds_token, auth.iss, "science.alt.dataset.lens", record, rkey ) return {"uri": result.get("uri"), "cid": result.get("cid")} + + +# --------------------------------------------------------------------------- +# publishIndex +# --------------------------------------------------------------------------- + + +@router.post("/science.alt.dataset.publishIndex") +async def publish_index(request: Request) -> dict[str, Any]: + auth = await verify_service_auth(request, "science.alt.dataset.publishIndex") + pds_token = _require_pds_token(request) + + body = await request.json() + record = body.get("record", {}) + rkey = body.get("rkey") + + record_type = record.get("$type", "") + if record_type and record_type != "science.alt.dataset.index": + raise HTTPException(status_code=400, detail="Invalid $type for index") + + for field in ("name", "endpointUrl", "createdAt"): + if field not in record: + raise HTTPException(status_code=400, detail=f"Missing required field: {field}") + + # Validate endpoint URL is HTTPS + parsed = urlparse(record["endpointUrl"]) + if parsed.scheme != "https" or not parsed.netloc: + raise 
HTTPException( + status_code=400, detail="endpointUrl must be a valid HTTPS URL" + ) + + record["$type"] = "science.alt.dataset.index" + + pds = await _resolve_pds(auth.iss) + result = await _proxy_create_record( + pds, pds_token, auth.iss, "science.alt.dataset.index", record, rkey + ) + return {"uri": result.get("uri"), "cid": result.get("cid")} diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py index df5d15c..21792b2 100644 --- a/src/atdata_app/xrpc/queries.py +++ b/src/atdata_app/xrpc/queries.py @@ -8,6 +8,8 @@ from fastapi import APIRouter, HTTPException, Query, Request from atdata_app import get_resolver +import httpx + from atdata_app.database import ( COLLECTION_TABLE_MAP, fire_analytics_event, @@ -16,8 +18,10 @@ query_entry_stats, query_get_entries, query_get_entry, + query_get_index_provider, query_get_schema, query_list_entries, + query_list_index_providers, query_list_lenses, query_list_schemas, query_record_counts, @@ -32,7 +36,10 @@ GetEntriesResponse, GetEntryResponse, GetEntryStatsResponse, + IndexResponse, + IndexSkeletonResponse, ListEntriesResponse, + ListIndexesResponse, ListLensesResponse, ListSchemasResponse, ResolveBlobsResponse, @@ -44,6 +51,7 @@ parse_at_uri, parse_cursor, row_to_entry, + row_to_index_provider, row_to_label, row_to_lens, row_to_schema, @@ -340,6 +348,138 @@ async def search_lenses( ) +# --------------------------------------------------------------------------- +# listIndexes +# --------------------------------------------------------------------------- + + +@router.get("/science.alt.dataset.listIndexes") +async def list_indexes( + request: Request, + repo: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), + cursor: str | None = Query(None), +) -> ListIndexesResponse: + pool = request.app.state.db_pool + c_at, c_did, c_rkey = parse_cursor(cursor) + rows = await query_list_index_providers(pool, repo, limit, c_did, c_rkey, c_at) + return ListIndexesResponse( + 
indexes=[row_to_index_provider(r) for r in rows], + cursor=maybe_cursor(rows, limit), + ) + + +# --------------------------------------------------------------------------- +# getIndexSkeleton +# --------------------------------------------------------------------------- + + +async def _fetch_skeleton( + endpoint_url: str, + cursor: str | None, + limit: int, +) -> dict[str, Any]: + """Fetch skeleton from an upstream index provider.""" + params: dict[str, Any] = {"limit": limit} + if cursor: + params["cursor"] = cursor + async with httpx.AsyncClient(timeout=10.0) as http: + try: + resp = await http.get(endpoint_url, params=params) + except httpx.HTTPError as e: + raise HTTPException( + status_code=502, detail=f"Index provider unreachable: {e}" + ) from e + if resp.status_code != 200: + raise HTTPException( + status_code=502, + detail=f"Index provider returned {resp.status_code}", + ) + try: + data = resp.json() + except (ValueError, KeyError) as e: + raise HTTPException( + status_code=502, detail=f"Invalid response from index provider: {e}" + ) from e + if not isinstance(data.get("items"), list): + raise HTTPException( + status_code=502, detail="Index provider response missing 'items' array" + ) + return data + + +@router.get("/science.alt.dataset.getIndexSkeleton") +async def get_index_skeleton( + request: Request, + index: str = Query(...), + cursor: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), +) -> IndexSkeletonResponse: + pool = request.app.state.db_pool + try: + did, _, rkey = parse_at_uri(index) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid AT-URI for index") + provider = await query_get_index_provider(pool, did, rkey) + if not provider: + raise HTTPException(status_code=404, detail="Index provider not found") + + data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit) + return IndexSkeletonResponse( + items=data["items"], + cursor=data.get("cursor"), + ) + + +# 
--------------------------------------------------------------------------- +# getIndex (hydrated) +# --------------------------------------------------------------------------- + + +@router.get("/science.alt.dataset.getIndex") +async def get_index( + request: Request, + index: str = Query(...), + cursor: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), +) -> IndexResponse: + pool = request.app.state.db_pool + try: + did, _, rkey = parse_at_uri(index) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid AT-URI for index") + provider = await query_get_index_provider(pool, did, rkey) + if not provider: + raise HTTPException(status_code=404, detail="Index provider not found") + + data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit) + + # Parse URIs from skeleton items and hydrate + keys: list[tuple[str, str]] = [] + for item in data["items"]: + uri = item.get("uri", "") + try: + entry_did, _, entry_rkey = parse_at_uri(uri) + keys.append((entry_did, entry_rkey)) + except ValueError: + continue # skip malformed URIs + + rows = await query_get_entries(pool, keys) + + # Build a lookup map to preserve skeleton ordering + row_map = {(r["did"], r["rkey"]): r for r in rows} + hydrated = [] + for entry_did, entry_rkey in keys: + row = row_map.get((entry_did, entry_rkey)) + if row: + hydrated.append(row_to_entry(row)) + + return IndexResponse( + items=hydrated, + cursor=data.get("cursor"), + ) + + # --------------------------------------------------------------------------- # describeService # --------------------------------------------------------------------------- diff --git a/tests/test_index.py b/tests/test_index.py new file mode 100644 index 0000000..8a11dbb --- /dev/null +++ b/tests/test_index.py @@ -0,0 +1,541 @@ +"""Tests for index provider endpoints (skeleton/hydration pattern).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from 
httpx import ASGITransport, AsyncClient + +from atdata_app.config import AppConfig +from atdata_app.ingestion.processor import process_commit +from atdata_app.main import create_app + +_DB = "atdata_app.database" +_QUERIES = "atdata_app.xrpc.queries" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_app() -> tuple: + config = AppConfig(dev_mode=True, hostname="localhost", port=8000) + app = create_app(config) + pool = AsyncMock() + app.state.db_pool = pool + return app, pool + + +def _index_provider_row( + did: str = "did:plc:provider1", + rkey: str = "3abc", + endpoint_url: str = "https://example.com/skeleton", + name: str = "Genomics Index", + description: str = "Curated genomics datasets", +) -> dict: + """Simulate an asyncpg Record as a dict-like object.""" + return MagicMock( + **{ + "__getitem__": lambda self, key: { + "did": did, + "rkey": rkey, + "cid": "bafyindex", + "name": name, + "description": description, + "endpoint_url": endpoint_url, + "created_at": "2025-01-01T00:00:00Z", + "indexed_at": "2025-01-01T00:00:00+00:00", + }[key], + } + ) + + +def _entry_row(did: str = "did:plc:author1", rkey: str = "3xyz") -> MagicMock: + return MagicMock( + **{ + "__getitem__": lambda self, key: { + "did": did, + "rkey": rkey, + "cid": "bafyentry", + "name": "test-dataset", + "schema_ref": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0", + "storage": '{"$type": "science.alt.dataset.storageHttp", "shards": []}', + "description": None, + "tags": None, + "license": None, + "size_samples": None, + "size_bytes": None, + "size_shards": None, + "created_at": "2025-01-01T00:00:00Z", + "indexed_at": "2025-01-01T00:00:00+00:00", + }[key], + } + ) + + +# --------------------------------------------------------------------------- +# getIndexSkeleton +# --------------------------------------------------------------------------- + + 
+@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_success(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + skeleton_response = { + "items": [{"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}], + "cursor": "next123", + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["uri"] == "at://did:plc:a/science.alt.dataset.entry/3xyz" + assert data["cursor"] == "next123" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_not_found(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = None + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_invalid_uri(mock_get_provider): + app, pool = _make_app() + 
+ transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "not-a-uri"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_upstream_error(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + mock_resp = MagicMock() + mock_resp.status_code = 500 + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_upstream_unreachable(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection refused") + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + 
params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_invalid_response(mock_get_provider): + """Upstream returns JSON without 'items' array.""" + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"bad": "data"} + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# getIndex (hydrated) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock) +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_hydrated(mock_get_provider, mock_get_entries): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")] + + skeleton_response = { + "items": [ + {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}, + ], + "cursor": "next456", + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with 
patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["name"] == "test-dataset" + assert data["cursor"] == "next456" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock) +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_omits_missing_entries(mock_get_provider, mock_get_entries): + """Entries not in the DB should be silently omitted.""" + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + # Return only one of two requested entries + mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")] + + skeleton_response = { + "items": [ + {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}, + {"uri": "at://did:plc:b/science.alt.dataset.entry/3deleted"}, + ], + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + 
"/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_not_found(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = None + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# listIndexes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes(mock_list): + app, pool = _make_app() + mock_list.return_value = [_index_provider_row()] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/xrpc/science.alt.dataset.listIndexes") + assert resp.status_code == 200 + data = resp.json() + assert len(data["indexes"]) == 1 + assert data["indexes"][0]["name"] == "Genomics Index" + assert data["indexes"][0]["endpointUrl"] == "https://example.com/skeleton" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes_empty(mock_list): + app, pool = _make_app() + mock_list.return_value = [] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/xrpc/science.alt.dataset.listIndexes") + assert resp.status_code == 200 + data = resp.json() + assert data["indexes"] 
== [] + assert data["cursor"] is None + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes_with_repo_filter(mock_list): + app, pool = _make_app() + mock_list.return_value = [] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.listIndexes", + params={"repo": "did:plc:provider1"}, + ) + assert resp.status_code == 200 + mock_list.assert_called_once_with(pool, "did:plc:provider1", 50, None, None, None) + + +# --------------------------------------------------------------------------- +# publishIndex +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +@patch("atdata_app.xrpc.procedures._resolve_pds", new_callable=AsyncMock) +@patch("atdata_app.xrpc.procedures._proxy_create_record", new_callable=AsyncMock) +async def test_publish_index(mock_proxy, mock_pds, mock_auth): + app, pool = _make_app() + + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + mock_pds.return_value = "https://pds.example.com" + mock_proxy.return_value = { + "uri": "at://did:plc:publisher1/science.alt.dataset.index/3abc", + "cid": "bafynew", + } + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["uri"] == "at://did:plc:publisher1/science.alt.dataset.index/3abc" + assert data["cid"] == "bafynew" + + 
+@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_missing_field(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Test", + "createdAt": "2025-01-01T00:00:00Z", + # missing endpointUrl + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 400 + assert "endpointUrl" in resp.json()["detail"] + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_http_url_rejected(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Bad Index", + "endpointUrl": "http://insecure.example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 400 + assert "HTTPS" in resp.json()["detail"] + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_invalid_type(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "$type": "science.alt.dataset.entry", 
+ "name": "Wrong Type", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 400 + assert "Invalid $type" in resp.json()["detail"] + + +# --------------------------------------------------------------------------- +# Ingestion: index provider records +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_DB}.upsert_index_provider", new_callable=AsyncMock) +async def test_process_commit_index_provider(mock_upsert): + pool = AsyncMock() + event = { + "did": "did:plc:provider1", + "time_us": 1725911162329308, + "kind": "commit", + "commit": { + "rev": "rev1", + "operation": "create", + "collection": "science.alt.dataset.index", + "rkey": "3abc", + "record": { + "$type": "science.alt.dataset.index", + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + "cid": "bafyindex", + }, + } + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:provider1", "3abc", "bafyindex", event["commit"]["record"] + ) + + +@pytest.mark.asyncio +@patch(f"{_DB}.delete_record", new_callable=AsyncMock) +async def test_process_commit_delete_index_provider(mock_delete): + pool = AsyncMock() + event = { + "did": "did:plc:provider1", + "time_us": 1725911162329308, + "kind": "commit", + "commit": { + "rev": "rev1", + "operation": "delete", + "collection": "science.alt.dataset.index", + "rkey": "3abc", + }, + } + await process_commit(pool, event) + mock_delete.assert_called_once_with(pool, "index_providers", "did:plc:provider1", "3abc") From 6e8867df7a24b9254c396277ed1dcc23027a7819 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:05:10 -0800 Subject: [PATCH 04/12] fix: add missing did field to 
label/lens serializers, deduplicate record counts, use UPSERT_FNS dispatch - row_to_label() and row_to_lens() now include `did` field, consistent with row_to_entry() and row_to_schema() - Extract _fetch_record_counts() helper so query_record_counts() and query_analytics_summary() share the same implementation - Replace if/elif upsert chain in processor.py with db.UPSERT_FNS dict lookup; update test_ingestion.py to patch the dict directly - Use parse_at_uri() in frontend dataset_detail() instead of manual string splitting - Remove duplicate config fixture from test_analytics.py (conftest provides it) - Add sendInteractions edge case tests: missing key, non-dict item, boundary at max batch size Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 2 + src/atdata_app/database.py | 20 ++-- src/atdata_app/frontend/routes.py | 10 +- src/atdata_app/ingestion/processor.py | 9 +- src/atdata_app/models.py | 2 + tests/test_analytics.py | 62 ++++++++-- tests/test_ingestion.py | 164 ++++++++++++++------------ tests/test_models.py | 2 + 8 files changed, 165 insertions(+), 106 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index edb123d..8b4e263 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/). 
- DID document service entry updated from `#atproto_appview` / `AtprotoAppView` to `#atdata_appview` / `AtdataAppView` ### Added +- Adversarial review: sendInteractions feature and surrounding code (round 3) (#39) +- Add sendInteractions XRPC procedure for usage telemetry (#35) - Dual-hostname DID document support — serve different `did:web` documents for `api.atdata.app` (appview identity) and `atdata.app` (atproto account identity) based on the `Host` header ([#19](https://github.com/forecast-bio/atdata-app/issues/19)) - Host-based route gating middleware — frontend HTML routes are only served on the frontend hostname; the API subdomain serves only XRPC, health, and DID endpoints diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index 748e2c3..c197be2 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -544,13 +544,17 @@ async def query_search_lenses( ) +async def _fetch_record_counts(conn: asyncpg.Connection) -> dict[str, int]: + counts = {} + for collection, table in COLLECTION_TABLE_MAP.items(): + row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608 + counts[collection] = row["cnt"] + return counts + + async def query_record_counts(pool: asyncpg.Pool) -> dict[str, int]: async with pool.acquire() as conn: - counts = {} - for collection, table in COLLECTION_TABLE_MAP.items(): - row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608 - counts[collection] = row["cnt"] - return counts + return await _fetch_record_counts(conn) async def query_labels_for_dataset( @@ -692,11 +696,7 @@ async def query_analytics_summary( {"term": r["term"], "count": r["count"]} for r in term_rows ] - # Record counts - counts = {} - for collection, table in COLLECTION_TABLE_MAP.items(): - c = await conn.fetchrow(f"SELECT COUNT(*) AS cnt FROM {table}") # noqa: S608 - counts[collection] = c["cnt"] + counts = await _fetch_record_counts(conn) return { "totalViews": total_views, diff --git 
a/src/atdata_app/frontend/routes.py b/src/atdata_app/frontend/routes.py index 9a815e4..03e1c83 100644 --- a/src/atdata_app/frontend/routes.py +++ b/src/atdata_app/frontend/routes.py @@ -22,6 +22,7 @@ ) from atdata_app.models import ( maybe_cursor, + parse_at_uri, parse_cursor, row_to_entry, row_to_label, @@ -100,10 +101,11 @@ async def dataset_detail(request: Request, did: str, rkey: str): schema_did = "" schema_rkey = "" schema_ref = entry.get("schemaRef", "") - if schema_ref.startswith("at://"): - parts = schema_ref[5:].split("/", 2) - if len(parts) == 3: - schema_did, _, schema_rkey = parts + if schema_ref: + try: + schema_did, _, schema_rkey = parse_at_uri(schema_ref) + except ValueError: + pass # Fetch labels pointing to this dataset dataset_uri = entry["uri"] diff --git a/src/atdata_app/ingestion/processor.py b/src/atdata_app/ingestion/processor.py index fa34c96..5715ad8 100644 --- a/src/atdata_app/ingestion/processor.py +++ b/src/atdata_app/ingestion/processor.py @@ -48,14 +48,7 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: record = commit["record"] cid = commit.get("cid") try: - if table == "schemas": - await db.upsert_schema(pool, did, rkey, cid, record) - elif table == "entries": - await db.upsert_entry(pool, did, rkey, cid, record) - elif table == "labels": - await db.upsert_label(pool, did, rkey, cid, record) - elif table == "lenses": - await db.upsert_lens(pool, did, rkey, cid, record) + await db.UPSERT_FNS[table](pool, did, rkey, cid, record) logger.debug("Upserted %s %s/%s", collection, did, rkey) except Exception: logger.exception("Failed to upsert %s %s/%s", collection, did, rkey) diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 7cec0d9..08ab2ad 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -127,6 +127,7 @@ def row_to_label(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + "did": row["did"], "name": row["name"], "datasetUri": 
row["dataset_uri"], "createdAt": row["created_at"], @@ -150,6 +151,7 @@ def row_to_lens(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + "did": row["did"], "name": row["name"], "sourceSchema": row["source_schema"], "targetSchema": row["target_schema"], diff --git a/tests/test_analytics.py b/tests/test_analytics.py index be53149..8df0a14 100644 --- a/tests/test_analytics.py +++ b/tests/test_analytics.py @@ -8,7 +8,6 @@ import pytest from httpx import ASGITransport, AsyncClient -from atdata_app.config import AppConfig from atdata_app.database import ( fire_analytics_event, record_analytics_event, @@ -28,17 +27,12 @@ # --------------------------------------------------------------------------- -@pytest.fixture -def config() -> AppConfig: - return AppConfig(dev_mode=True, hostname="localhost", port=8000) - - @pytest.fixture def pool() -> AsyncMock: return AsyncMock() -def _mock_app(config: AppConfig, pool: AsyncMock): +def _mock_app(config, pool): """Create a FastAPI app with mocked lifespan (no real DB).""" app = create_app(config) app.state.db_pool = pool @@ -596,3 +590,57 @@ async def test_send_interactions_all_three_types(mock_fire, config, pool): assert resp.status_code == 200 assert mock_fire.call_count == 3 + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_missing_key(mock_fire, config, pool): + """Body without 'interactions' key should return 400.""" + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"data": []}, + ) + + assert resp.status_code == 400 + assert "interactions must be an array" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_non_dict_item(mock_fire, config, pool): + """Non-object 
items in the interactions array should return 400.""" + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": ["not-a-dict"]}, + ) + + assert resp.status_code == 400 + assert "must be an object" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_boundary_at_max(mock_fire, config, pool): + """Exactly 100 interactions (the maximum) should succeed.""" + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + interactions = [ + {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"} + for _ in range(100) + ] + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": interactions}, + ) + + assert resp.status_code == 200 + assert mock_fire.call_count == 100 diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 2de2b49..490934f 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -8,7 +8,6 @@ from atdata_app.ingestion.processor import process_commit -# All patches target the `db` module reference used inside processor.py _DB = "atdata_app.database" @@ -44,15 +43,22 @@ def _make_event( } +def _patch_upsert(table: str): + """Patch a single entry in UPSERT_FNS by table name.""" + mock = AsyncMock() + return patch.dict(f"{_DB}.UPSERT_FNS", {table: mock}), mock + + @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_create(mock_upsert): - pool = AsyncMock() - event = _make_event(operation="create") - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", 
event["commit"]["record"] - ) +async def test_process_commit_create(): + patcher, mock_upsert = _patch_upsert("entries") + with patcher: + pool = AsyncMock() + event = _make_event(operation="create") + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio @@ -72,85 +78,89 @@ async def test_process_commit_delete(mock_delete): @pytest.mark.asyncio -@patch(f"{_DB}.upsert_schema", new_callable=AsyncMock) -async def test_process_commit_schema(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.schema", - record={ - "$type": "science.alt.dataset.schema", - "name": "TestSchema", - "version": "1.0.0", - "schemaType": "jsonSchema", - "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"}, - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_schema(): + patcher, mock_upsert = _patch_upsert("schemas") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.schema", + record={ + "$type": "science.alt.dataset.schema", + "name": "TestSchema", + "version": "1.0.0", + "schemaType": "jsonSchema", + "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"}, + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_label", new_callable=AsyncMock) -async def test_process_commit_label(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.label", - record={ - "$type": "science.alt.dataset.label", - "name": "mnist", - "datasetUri": 
"at://did:plc:test/science.alt.dataset.entry/3xyz", - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_label(): + patcher, mock_upsert = _patch_upsert("labels") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.label", + record={ + "$type": "science.alt.dataset.label", + "name": "mnist", + "datasetUri": "at://did:plc:test/science.alt.dataset.entry/3xyz", + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_lens", new_callable=AsyncMock) -async def test_process_commit_lens(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.lens", - record={ - "$type": "science.alt.dataset.lens", - "name": "test-lens", - "sourceSchema": "at://did:plc:test/science.alt.dataset.schema/a@1.0.0", - "targetSchema": "at://did:plc:test/science.alt.dataset.schema/b@1.0.0", - "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"}, - "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"}, - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_lens(): + patcher, mock_upsert = _patch_upsert("lenses") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.lens", + record={ + "$type": "science.alt.dataset.lens", + "name": "test-lens", + "sourceSchema": "at://did:plc:test/science.alt.dataset.schema/a@1.0.0", + "targetSchema": 
"at://did:plc:test/science.alt.dataset.schema/b@1.0.0", + "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"}, + "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"}, + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_update(mock_upsert): +async def test_process_commit_update(): """Update operations should route to the same upsert function as create.""" - pool = AsyncMock() - event = _make_event(operation="update") - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) + patcher, mock_upsert = _patch_upsert("entries") + with patcher: + pool = AsyncMock() + event = _make_event(operation="update") + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_upsert_error_is_caught(mock_upsert): +async def test_process_commit_upsert_error_is_caught(): """Upsert failures should be logged, not raised.""" - mock_upsert.side_effect = Exception("db error") - pool = AsyncMock() - event = _make_event(operation="create") - # Should not raise - await process_commit(pool, event) + mock = AsyncMock(side_effect=Exception("db error")) + with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock}): + pool = AsyncMock() + event = _make_event(operation="create") + # Should not raise + await process_commit(pool, event) diff --git a/tests/test_models.py b/tests/test_models.py index 23f6046..5ba7dcc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ 
-193,6 +193,7 @@ def test_row_to_schema_json_string_body(): def test_row_to_label(): d = row_to_label(_LABEL_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.label/3lbl" + assert d["did"] == "did:plc:abc" assert d["datasetUri"] == _LABEL_ROW["dataset_uri"] assert d["version"] == "1.0.0" assert d["description"] == "First version" @@ -227,6 +228,7 @@ def test_row_to_label_omits_optional_fields(): def test_row_to_lens(): d = row_to_lens(_LENS_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.lens/3lens" + assert d["did"] == "did:plc:abc" assert d["sourceSchema"] == _LENS_ROW["source_schema"] assert d["getterCode"] == _LENS_ROW["getter_code"] assert d["description"] == "Transforms A to B" From bb72c9cbb85296e36c8cfd29d96ef193d46d24fb Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:39:18 -0800 Subject: [PATCH 05/12] fix: add missing did field to row_to_label/row_to_lens, batch mechanical fixes - Add did field to row_to_label and row_to_lens serializers for consistency with row_to_entry, row_to_schema, row_to_index_provider - Add index_providers to query_active_publishers UNION query - Fix import ordering in queries.py and procedures.py (PEP 8) - Add row_to_index_provider unit tests to test_models.py - Add index_providers table/indexes to integration test expectations - Remove unnecessary mock patch from test_get_index_skeleton_invalid_uri Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/database.py | 2 ++ src/atdata_app/models.py | 2 ++ src/atdata_app/xrpc/procedures.py | 2 +- src/atdata_app/xrpc/queries.py | 2 +- tests/test_index.py | 3 +-- tests/test_integration.py | 3 +++ tests/test_models.py | 33 +++++++++++++++++++++++++++++++ 7 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index eb69133..36af53c 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -825,6 +825,8 @@ async 
def query_active_publishers(pool: asyncpg.Pool, days: int = 30) -> int: SELECT did FROM labels WHERE indexed_at >= NOW() - $1::interval UNION SELECT did FROM lenses WHERE indexed_at >= NOW() - $1::interval + UNION + SELECT did FROM index_providers WHERE indexed_at >= NOW() - $1::interval ) sub """, interval, diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 6d123c9..4bc4f7b 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -127,6 +127,7 @@ def row_to_label(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + "did": row["did"], "name": row["name"], "datasetUri": row["dataset_uri"], "createdAt": row["created_at"], @@ -165,6 +166,7 @@ def row_to_lens(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + "did": row["did"], "name": row["name"], "sourceSchema": row["source_schema"], "targetSchema": row["target_schema"], diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index 6ae1151..cd85042 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -8,6 +8,7 @@ import logging from typing import Any +from urllib.parse import urlparse import httpx from fastapi import APIRouter, HTTPException, Request @@ -19,7 +20,6 @@ query_get_schema, query_record_exists, ) -from urllib.parse import urlparse from atdata_app.models import parse_at_uri logger = logging.getLogger(__name__) diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py index 21792b2..218aa44 100644 --- a/src/atdata_app/xrpc/queries.py +++ b/src/atdata_app/xrpc/queries.py @@ -7,9 +7,9 @@ from fastapi import APIRouter, HTTPException, Query, Request -from atdata_app import get_resolver import httpx +from atdata_app import get_resolver from atdata_app.database import ( COLLECTION_TABLE_MAP, fire_analytics_event, diff --git a/tests/test_index.py b/tests/test_index.py index 8a11dbb..1af7a0a 100644 --- a/tests/test_index.py +++ 
b/tests/test_index.py @@ -132,8 +132,7 @@ async def test_get_index_skeleton_not_found(mock_get_provider): @pytest.mark.asyncio -@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) -async def test_get_index_skeleton_invalid_uri(mock_get_provider): +async def test_get_index_skeleton_invalid_uri(): app, pool = _make_app() transport = ASGITransport(app=app) diff --git a/tests/test_integration.py b/tests/test_integration.py index 485ee27..2cccd1f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -162,6 +162,7 @@ async def test_all_expected_tables_exist(self, db_pool): "entries", "labels", "lenses", + "index_providers", "cursor_state", "analytics_events", "analytics_counters", @@ -196,6 +197,8 @@ async def test_expected_indexes_exist(self, db_pool): "idx_lenses_did", "idx_analytics_events_type_created", "idx_analytics_events_target", + "idx_index_providers_did", + "idx_index_providers_indexed_at", } async with db_pool.acquire() as conn: rows = await conn.fetch( diff --git a/tests/test_models.py b/tests/test_models.py index 23f6046..08d250c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -10,6 +10,7 @@ make_at_uri, parse_at_uri, row_to_entry, + row_to_index_provider, row_to_label, row_to_lens, row_to_schema, @@ -193,6 +194,7 @@ def test_row_to_schema_json_string_body(): def test_row_to_label(): d = row_to_label(_LABEL_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.label/3lbl" + assert d["did"] == "did:plc:abc" assert d["datasetUri"] == _LABEL_ROW["dataset_uri"] assert d["version"] == "1.0.0" assert d["description"] == "First version" @@ -227,6 +229,7 @@ def test_row_to_label_omits_optional_fields(): def test_row_to_lens(): d = row_to_lens(_LENS_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.lens/3lens" + assert d["did"] == "did:plc:abc" assert d["sourceSchema"] == _LENS_ROW["source_schema"] assert d["getterCode"] == _LENS_ROW["getter_code"] assert d["description"] == "Transforms A to 
B" @@ -249,3 +252,33 @@ def test_row_to_lens_json_string_code(): d = row_to_lens(row) assert d["getterCode"] == {"repo": "x"} assert d["putterCode"] == {"repo": "y"} + + +# --------------------------------------------------------------------------- +# row_to_index_provider +# --------------------------------------------------------------------------- + +_INDEX_PROVIDER_ROW = { + "did": "did:plc:provider1", + "rkey": "3idx", + "cid": "bafyindex", + "name": "Genomics Index", + "description": "Curated genomics datasets", + "endpoint_url": "https://example.com/skeleton", + "created_at": "2025-01-01T00:00:00Z", +} + + +def test_row_to_index_provider(): + d = row_to_index_provider(_INDEX_PROVIDER_ROW) + assert d["uri"] == "at://did:plc:provider1/science.alt.dataset.index/3idx" + assert d["did"] == "did:plc:provider1" + assert d["name"] == "Genomics Index" + assert d["endpointUrl"] == "https://example.com/skeleton" + assert d["description"] == "Curated genomics datasets" + + +def test_row_to_index_provider_omits_null_description(): + row = {**_INDEX_PROVIDER_ROW, "description": None} + d = row_to_index_provider(row) + assert "description" not in d From 74eaf458fa8fd7d2cce63ff4c1d8c0e012a53ed1 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:44:07 -0800 Subject: [PATCH 06/12] feat: add subscribeChanges WebSocket endpoint for real-time change streaming Introduces an in-memory broadcast event bus (changestream.py) that the ingestion processor publishes to after successful upserts/deletes. A new WebSocket endpoint at /xrpc/science.alt.dataset.subscribeChanges streams these events to subscribers with cursor-based replay from a bounded buffer. 
Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/changestream.py | 152 +++++++++ src/atdata_app/ingestion/jetstream.py | 4 +- src/atdata_app/ingestion/processor.py | 27 +- src/atdata_app/main.py | 4 + src/atdata_app/xrpc/router.py | 4 +- src/atdata_app/xrpc/subscriptions.py | 53 ++++ tests/test_changestream.py | 440 ++++++++++++++++++++++++++ 7 files changed, 681 insertions(+), 3 deletions(-) create mode 100644 src/atdata_app/changestream.py create mode 100644 src/atdata_app/xrpc/subscriptions.py create mode 100644 tests/test_changestream.py diff --git a/src/atdata_app/changestream.py b/src/atdata_app/changestream.py new file mode 100644 index 0000000..3d6510c --- /dev/null +++ b/src/atdata_app/changestream.py @@ -0,0 +1,152 @@ +"""In-memory broadcast channel for real-time change events. + +Provides a pub/sub mechanism that the ingestion processor publishes to +and WebSocket subscribers consume from. Maintains a bounded buffer of +recent events for cursor-based replay. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + +DEFAULT_BUFFER_SIZE = 1000 +DEFAULT_SUBSCRIBER_QUEUE_SIZE = 256 + + +@dataclass +class ChangeEvent: + """A single change event in the stream.""" + + seq: int + type: str # "create", "update", or "delete" + collection: str + did: str + rkey: str + timestamp: str + record: dict[str, Any] | None = None + cid: str | None = None + + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = { + "seq": self.seq, + "type": self.type, + "collection": self.collection, + "did": self.did, + "rkey": self.rkey, + "timestamp": self.timestamp, + } + if self.record is not None: + d["record"] = self.record + if self.cid is not None: + d["cid"] = self.cid + return d + + +@dataclass +class ChangeStream: + """Broadcast channel with bounded replay buffer. 
+ + Thread-safe for asyncio: all mutations happen in the event loop. + """ + + buffer_size: int = DEFAULT_BUFFER_SIZE + subscriber_queue_size: int = DEFAULT_SUBSCRIBER_QUEUE_SIZE + _seq: int = field(default=0, init=False) + _buffer: deque[ChangeEvent] = field(init=False) + _subscribers: dict[int, asyncio.Queue[ChangeEvent]] = field( + default_factory=dict, init=False + ) + _next_sub_id: int = field(default=0, init=False) + + def __post_init__(self) -> None: + self._buffer = deque(maxlen=self.buffer_size) + + def publish(self, event: ChangeEvent) -> None: + """Publish an event to all subscribers and the replay buffer. + + Non-blocking. If a subscriber's queue is full, the event is dropped + for that subscriber (backpressure via disconnect is handled by the + WebSocket handler). + """ + self._seq += 1 + event.seq = self._seq + self._buffer.append(event) + + for sub_id, queue in list(self._subscribers.items()): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning( + "Subscriber %d queue full, dropping event seq=%d", + sub_id, + event.seq, + ) + + def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]: + """Create a new subscriber. Returns (subscriber_id, queue).""" + sub_id = self._next_sub_id + self._next_sub_id += 1 + queue: asyncio.Queue[ChangeEvent] = asyncio.Queue( + maxsize=self.subscriber_queue_size + ) + self._subscribers[sub_id] = queue + logger.debug("Subscriber %d connected (total: %d)", sub_id, len(self._subscribers)) + return sub_id, queue + + def unsubscribe(self, sub_id: int) -> None: + """Remove a subscriber.""" + self._subscribers.pop(sub_id, None) + logger.debug("Subscriber %d disconnected (total: %d)", sub_id, len(self._subscribers)) + + def replay_from(self, cursor: int) -> list[ChangeEvent]: + """Return buffered events with seq > cursor. + + Returns an empty list if the cursor is outside the buffer window. 
+ """ + if not self._buffer: + return [] + + oldest_seq = self._buffer[0].seq + if cursor < oldest_seq - 1: + # Cursor is too old — events between cursor and buffer start were lost + return [] + + return [ev for ev in self._buffer if ev.seq > cursor] + + @property + def current_seq(self) -> int: + return self._seq + + @property + def subscriber_count(self) -> int: + return len(self._subscribers) + + +def make_change_event( + *, + event_type: str, + collection: str, + did: str, + rkey: str, + record: dict[str, Any] | None = None, + cid: str | None = None, +) -> ChangeEvent: + """Factory for creating change events with current timestamp.""" + from datetime import datetime, timezone + + return ChangeEvent( + seq=0, # Assigned by ChangeStream.publish() + type=event_type, + collection=collection, + did=did, + rkey=rkey, + timestamp=datetime.now(timezone.utc).isoformat(), + record=record, + cid=cid, + ) diff --git a/src/atdata_app/ingestion/jetstream.py b/src/atdata_app/ingestion/jetstream.py index 9b765d2..f09711b 100644 --- a/src/atdata_app/ingestion/jetstream.py +++ b/src/atdata_app/ingestion/jetstream.py @@ -52,7 +52,9 @@ async def jetstream_consumer(app: FastAPI) -> None: if event.get("kind") != "commit": continue - await process_commit(pool, event) + await process_commit( + pool, event, getattr(app.state, "change_stream", None) + ) last_time_us = event.get("time_us") msg_count += 1 diff --git a/src/atdata_app/ingestion/processor.py b/src/atdata_app/ingestion/processor.py index fa34c96..40e4a4b 100644 --- a/src/atdata_app/ingestion/processor.py +++ b/src/atdata_app/ingestion/processor.py @@ -8,11 +8,16 @@ import asyncpg from atdata_app import database as db +from atdata_app.changestream import ChangeStream, make_change_event logger = logging.getLogger(__name__) -async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: +async def process_commit( + pool: asyncpg.Pool, + event: dict[str, Any], + change_stream: ChangeStream | None = None, +) -> None: 
"""Process a Jetstream commit event. Expected event format:: @@ -44,6 +49,15 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: if operation == "delete": await db.delete_record(pool, table, did, rkey) logger.debug("Deleted %s %s/%s", collection, did, rkey) + if change_stream is not None: + change_stream.publish( + make_change_event( + event_type="delete", + collection=collection, + did=did, + rkey=rkey, + ) + ) elif operation in ("create", "update"): record = commit["record"] cid = commit.get("cid") @@ -57,5 +71,16 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: elif table == "lenses": await db.upsert_lens(pool, did, rkey, cid, record) logger.debug("Upserted %s %s/%s", collection, did, rkey) + if change_stream is not None: + change_stream.publish( + make_change_event( + event_type=operation, + collection=collection, + did=did, + rkey=rkey, + record=record, + cid=cid, + ) + ) except Exception: logger.exception("Failed to upsert %s %s/%s", collection, did, rkey) diff --git a/src/atdata_app/main.py b/src/atdata_app/main.py index b50bba3..087fde6 100644 --- a/src/atdata_app/main.py +++ b/src/atdata_app/main.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles +from atdata_app.changestream import ChangeStream from atdata_app.config import AppConfig from atdata_app.database import create_pool, run_migrations from atdata_app.frontend import router as frontend_router @@ -30,6 +31,9 @@ async def lifespan(app: FastAPI): config: AppConfig = app.state.config logger.info("Starting atdata-app (DID: %s)", config.service_did) + # Change stream (must be created before background tasks) + app.state.change_stream = ChangeStream() + # Database pool = await create_pool(config.database_url) app.state.db_pool = pool diff --git a/src/atdata_app/xrpc/router.py b/src/atdata_app/xrpc/router.py index 627b382..d929478 100644 --- a/src/atdata_app/xrpc/router.py +++ 
b/src/atdata_app/xrpc/router.py @@ -1,10 +1,12 @@ -"""Combined XRPC router mounting all query and procedure endpoints.""" +"""Combined XRPC router mounting all query, procedure, and subscription endpoints.""" from fastapi import APIRouter from atdata_app.xrpc.procedures import router as procedures_router from atdata_app.xrpc.queries import router as queries_router +from atdata_app.xrpc.subscriptions import router as subscriptions_router router = APIRouter(prefix="/xrpc") router.include_router(queries_router) router.include_router(procedures_router) +router.include_router(subscriptions_router) diff --git a/src/atdata_app/xrpc/subscriptions.py b/src/atdata_app/xrpc/subscriptions.py new file mode 100644 index 0000000..cdcbf60 --- /dev/null +++ b/src/atdata_app/xrpc/subscriptions.py @@ -0,0 +1,53 @@ +"""WebSocket subscription endpoints for real-time change streaming.""" + +from __future__ import annotations + +import json +import logging + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +from atdata_app.changestream import ChangeStream + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.websocket("/science.alt.dataset.subscribeChanges") +async def subscribe_changes(websocket: WebSocket) -> None: + """Stream real-time change events over WebSocket. + + Query parameters: + cursor: Optional sequence number to replay from. 
+ """ + change_stream: ChangeStream = websocket.app.state.change_stream + + await websocket.accept() + + cursor_param = websocket.query_params.get("cursor") + sub_id, queue = change_stream.subscribe() + + try: + # Replay buffered events if cursor provided + if cursor_param is not None: + try: + cursor = int(cursor_param) + except (ValueError, TypeError): + await websocket.close(code=1008, reason="Invalid cursor value") + return + missed = change_stream.replay_from(cursor) + for event in missed: + await websocket.send_text(json.dumps(event.to_dict())) + + # Stream live events + while True: + event = await queue.get() + await websocket.send_text(json.dumps(event.to_dict())) + + except WebSocketDisconnect: + logger.debug("Subscriber %d disconnected", sub_id) + except Exception: + logger.exception("Error in subscriber %d", sub_id) + finally: + change_stream.unsubscribe(sub_id) diff --git a/tests/test_changestream.py b/tests/test_changestream.py new file mode 100644 index 0000000..6413e47 --- /dev/null +++ b/tests/test_changestream.py @@ -0,0 +1,440 @@ +"""Tests for the change stream event bus and WebSocket subscription endpoint.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + +from atdata_app.changestream import ChangeEvent, ChangeStream, make_change_event +from atdata_app.ingestion.processor import process_commit +from atdata_app.xrpc.subscriptions import router as subscriptions_router + + +# --------------------------------------------------------------------------- +# ChangeStream unit tests +# --------------------------------------------------------------------------- + + +class TestChangeStream: + def test_publish_assigns_monotonic_seq(self): + cs = ChangeStream() + ev1 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + ev2 = make_change_event( + 
event_type="update", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="def", + ) + cs.publish(ev1) + cs.publish(ev2) + assert ev1.seq == 1 + assert ev2.seq == 2 + assert cs.current_seq == 2 + + def test_publish_delivers_to_subscribers(self): + cs = ChangeStream() + sub_id, queue = cs.subscribe() + + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + + assert not queue.empty() + received = queue.get_nowait() + assert received.seq == 1 + assert received.did == "did:plc:test" + + def test_multiple_subscribers_receive_events(self): + cs = ChangeStream() + _, q1 = cs.subscribe() + _, q2 = cs.subscribe() + + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + + assert not q1.empty() + assert not q2.empty() + assert q1.get_nowait().seq == 1 + assert q2.get_nowait().seq == 1 + + def test_unsubscribe_removes_subscriber(self): + cs = ChangeStream() + sub_id, queue = cs.subscribe() + assert cs.subscriber_count == 1 + + cs.unsubscribe(sub_id) + assert cs.subscriber_count == 0 + + # Publishing after unsubscribe should not deliver + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + assert queue.empty() + + def test_full_queue_drops_event(self): + cs = ChangeStream(subscriber_queue_size=1) + _, queue = cs.subscribe() + + ev1 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="a", + ) + ev2 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="b", + ) + cs.publish(ev1) + cs.publish(ev2) # Should be dropped (queue full) + + assert queue.qsize() == 1 + assert queue.get_nowait().rkey == "a" + + def test_replay_from_cursor(self): + cs = 
ChangeStream(buffer_size=10) + for i in range(5): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + # Replay from seq 3 — should get events 4 and 5 + replayed = cs.replay_from(3) + assert len(replayed) == 2 + assert replayed[0].seq == 4 + assert replayed[1].seq == 5 + + def test_replay_from_zero_returns_all(self): + cs = ChangeStream(buffer_size=10) + for i in range(3): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + replayed = cs.replay_from(0) + assert len(replayed) == 3 + + def test_replay_cursor_too_old(self): + cs = ChangeStream(buffer_size=3) + for i in range(5): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + # Buffer only holds seq 3, 4, 5 — cursor 1 is too old + replayed = cs.replay_from(1) + assert len(replayed) == 0 + + def test_replay_empty_buffer(self): + cs = ChangeStream() + assert cs.replay_from(0) == [] + + def test_bounded_buffer(self): + cs = ChangeStream(buffer_size=3) + for i in range(10): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + assert len(cs._buffer) == 3 + assert cs._buffer[0].seq == 8 + assert cs._buffer[-1].seq == 10 + + +class TestChangeEvent: + def test_to_dict_create(self): + ev = ChangeEvent( + seq=1, + type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + timestamp="2026-01-01T00:00:00Z", + record={"name": "test"}, + cid="bafytest", + ) + d = ev.to_dict() + assert d["seq"] == 1 + assert d["type"] == "create" + assert d["record"] == {"name": "test"} + assert d["cid"] == "bafytest" + + def test_to_dict_delete_omits_record_and_cid(self): + ev = ChangeEvent( + seq=2, + 
type="delete", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + timestamp="2026-01-01T00:00:00Z", + ) + d = ev.to_dict() + assert "record" not in d + assert "cid" not in d + + +class TestMakeChangeEvent: + def test_creates_event_with_timestamp(self): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + record={"name": "test"}, + cid="bafytest", + ) + assert ev.seq == 0 # Not yet assigned + assert ev.type == "create" + assert ev.timestamp # Should have a timestamp + + +# --------------------------------------------------------------------------- +# Processor integration tests +# --------------------------------------------------------------------------- + +_DB = "atdata_app.database" + + +def _make_event( + did: str = "did:plc:test123", + collection: str = "science.alt.dataset.entry", + operation: str = "create", + rkey: str = "3xyz", + record: dict | None = None, + cid: str = "bafytest", +) -> dict: + commit: dict = { + "rev": "rev1", + "operation": operation, + "collection": collection, + "rkey": rkey, + } + if operation != "delete": + commit["record"] = record or { + "$type": collection, + "name": "test-dataset", + "schemaRef": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0", + "storage": {"$type": "science.alt.dataset.storageHttp", "shards": []}, + "createdAt": "2025-01-01T00:00:00Z", + } + commit["cid"] = cid + + return { + "did": did, + "time_us": 1725911162329308, + "kind": "commit", + "commit": commit, + } + + +@pytest.mark.asyncio +@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) +async def test_processor_publishes_create_event(mock_upsert): + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="create") + await process_commit(pool, event, change_stream=cs) + + assert not queue.empty() + change_event = queue.get_nowait() + assert change_event.type == "create" + assert 
change_event.collection == "science.alt.dataset.entry" + assert change_event.did == "did:plc:test123" + assert change_event.rkey == "3xyz" + assert change_event.record is not None + assert change_event.cid == "bafytest" + + +@pytest.mark.asyncio +@patch(f"{_DB}.delete_record", new_callable=AsyncMock) +async def test_processor_publishes_delete_event(mock_delete): + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="delete") + await process_commit(pool, event, change_stream=cs) + + assert not queue.empty() + change_event = queue.get_nowait() + assert change_event.type == "delete" + assert change_event.collection == "science.alt.dataset.entry" + assert change_event.record is None + assert change_event.cid is None + + +@pytest.mark.asyncio +@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) +async def test_processor_no_event_on_upsert_failure(mock_upsert): + mock_upsert.side_effect = Exception("db error") + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="create") + await process_commit(pool, event, change_stream=cs) + + assert queue.empty() + + +@pytest.mark.asyncio +@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) +async def test_processor_works_without_change_stream(mock_upsert): + """Backward compat: process_commit works when change_stream is None.""" + pool = AsyncMock() + event = _make_event(operation="create") + await process_commit(pool, event) + mock_upsert.assert_called_once() + + +# --------------------------------------------------------------------------- +# WebSocket endpoint tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def app_with_changestream(): + """Minimal app with just the subscriptions router — no DB lifespan.""" + app = FastAPI() + app.state.change_stream = ChangeStream() + app.include_router(subscriptions_router, prefix="/xrpc") + return app + + +def 
test_websocket_subscribe_and_receive(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws: + # Publish an event from another "thread" + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + record={"name": "test"}, + cid="bafytest", + ) + ) + + data = ws.receive_json() + assert data["seq"] == 1 + assert data["type"] == "create" + assert data["collection"] == "science.alt.dataset.entry" + assert data["did"] == "did:plc:test" + assert data["record"] == {"name": "test"} + + +def test_websocket_cursor_replay(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + # Pre-populate buffer + for i in range(5): + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + ) + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges?cursor=3" + ) as ws: + # Should replay events 4 and 5 + msg1 = ws.receive_json() + assert msg1["seq"] == 4 + + msg2 = ws.receive_json() + assert msg2["seq"] == 5 + + +def test_websocket_disconnect_cleanup(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ): + assert cs.subscriber_count == 1 + + # After disconnect, subscriber should be cleaned up + assert cs.subscriber_count == 0 + + +def test_websocket_multiple_subscribers(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws1: + with 
client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws2: + assert cs.subscriber_count == 2 + + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + ) + + d1 = ws1.receive_json() + d2 = ws2.receive_json() + assert d1["seq"] == d2["seq"] == 1 + + assert cs.subscriber_count == 0 From beb9e633e6e6c1ba3c7ebf05608b2445fbbd3946 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:50:01 -0800 Subject: [PATCH 07/12] docs: add subscribeChanges to CHANGELOG Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index edb123d..1d37d0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). - DID document service entry updated from `#atproto_appview` / `AtprotoAppView` to `#atdata_appview` / `AtdataAppView` ### Added +- Add real-time change stream subscribeChanges endpoint (#50) - Dual-hostname DID document support — serve different `did:web` documents for `api.atdata.app` (appview identity) and `atdata.app` (atproto account identity) based on the `Host` header ([#19](https://github.com/forecast-bio/atdata-app/issues/19)) - Host-based route gating middleware — frontend HTML routes are only served on the frontend hostname; the API subdomain serves only XRPC, health, and DID endpoints From 37f224032871368e780a12ad8e68fe4271864fbd Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:52:21 -0800 Subject: [PATCH 08/12] feat: add array format type recognition and ndarray v1.1.0 annotation display Add KNOWN_ARRAY_FORMATS constant and ARRAY_FORMAT_LABELS for the six recognized array format tokens (numpyBytes, parquetBytes, sparseBytes, structuredBytes, arrowTensor, safetensors). 
Update row_to_schema() to surface arrayFormat, dtype, shape, and dimensionNames from the schema body as top-level fields. Update frontend templates (schema detail, schemas list, profile, dataset detail) to display format and annotation info when present. Update MCP server descriptions to mention new formats. Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/frontend/routes.py | 8 ++ .../frontend/templates/dataset.html | 14 +++ .../frontend/templates/profile.html | 3 +- src/atdata_app/frontend/templates/schema.html | 12 +++ .../frontend/templates/schemas.html | 3 +- src/atdata_app/mcp_server.py | 8 +- src/atdata_app/models.py | 39 ++++++++ tests/test_frontend.py | 97 ++++++++++++++++++- tests/test_models.py | 91 +++++++++++++++++ 9 files changed, 269 insertions(+), 6 deletions(-) diff --git a/src/atdata_app/frontend/routes.py b/src/atdata_app/frontend/routes.py index 9a815e4..d528841 100644 --- a/src/atdata_app/frontend/routes.py +++ b/src/atdata_app/frontend/routes.py @@ -105,6 +105,13 @@ async def dataset_detail(request: Request, did: str, rkey: str): if len(parts) == 3: schema_did, _, schema_rkey = parts + # Fetch the referenced schema for inline display of format/annotation info + schema_info = None + if schema_did and schema_rkey: + schema_row = await query_get_schema(pool, schema_did, schema_rkey) + if schema_row: + schema_info = row_to_schema(schema_row) + # Fetch labels pointing to this dataset dataset_uri = entry["uri"] label_rows = await query_labels_for_dataset(pool, dataset_uri) @@ -117,6 +124,7 @@ async def dataset_detail(request: Request, did: str, rkey: str): "entry": entry, "schema_did": schema_did, "schema_rkey": schema_rkey, + "schema_info": schema_info, "labels": labels, }, ) diff --git a/src/atdata_app/frontend/templates/dataset.html b/src/atdata_app/frontend/templates/dataset.html index 4b32fed..78f20b8 100644 --- a/src/atdata_app/frontend/templates/dataset.html +++ b/src/atdata_app/frontend/templates/dataset.html @@ -27,6 +27,20 @@

Details

License{{ entry.license }} {% endif %} Schema{{ entry.schemaRef }} + {% if schema_info %} + {% if schema_info.arrayFormat is defined %} + Array Format{{ schema_info.get("arrayFormatLabel", schema_info.arrayFormat) }} + {% endif %} + {% if schema_info.dtype is defined %} + Data Type{{ schema_info.dtype }} + {% endif %} + {% if schema_info.shape is defined %} + Shape{{ schema_info.shape | join(" × ") }} + {% endif %} + {% if schema_info.dimensionNames is defined %} + Dimensions{{ schema_info.dimensionNames | join(", ") }} + {% endif %} + {% endif %} {% if entry.size %} Size diff --git a/src/atdata_app/frontend/templates/profile.html b/src/atdata_app/frontend/templates/profile.html index 46026b4..04c4e29 100644 --- a/src/atdata_app/frontend/templates/profile.html +++ b/src/atdata_app/frontend/templates/profile.html @@ -34,13 +34,14 @@

{{ entry.name }}

Schemas {% if schemas %} - + {% for s in schemas %} + {% endfor %} diff --git a/src/atdata_app/frontend/templates/schema.html b/src/atdata_app/frontend/templates/schema.html index 281d9f6..16ede21 100644 --- a/src/atdata_app/frontend/templates/schema.html +++ b/src/atdata_app/frontend/templates/schema.html @@ -16,6 +16,18 @@

Details

+ {% if schema.arrayFormat is defined %} + + {% endif %} + {% if schema.dtype is defined %} + + {% endif %} + {% if schema.shape is defined %} + + {% endif %} + {% if schema.dimensionNames is defined %} + + {% endif %} diff --git a/src/atdata_app/frontend/templates/schemas.html b/src/atdata_app/frontend/templates/schemas.html index 268af39..fdf3565 100644 --- a/src/atdata_app/frontend/templates/schemas.html +++ b/src/atdata_app/frontend/templates/schemas.html @@ -7,7 +7,7 @@

Schemas

{% if schemas %}
NameVersionType
NameVersionTypeFormat
{{ s.name }} {{ s.version }} {{ s.schemaType }}{{ s.get("arrayFormatLabel", "") }}
AT-URI{{ schema.uri }}
Type{{ schema.schemaType }}
Array Format{{ schema.get("arrayFormatLabel", schema.arrayFormat) }}
Data Type{{ schema.dtype }}
Shape{{ schema.shape | join(" × ") }}
Dimensions{{ schema.dimensionNames | join(", ") }}
Version{{ schema.version }}
Created{{ schema.createdAt }}
- + {% for s in schemas %} @@ -15,6 +15,7 @@

Schemas

+ diff --git a/src/atdata_app/mcp_server.py b/src/atdata_app/mcp_server.py index 53991eb..a5d2e69 100644 --- a/src/atdata_app/mcp_server.py +++ b/src/atdata_app/mcp_server.py @@ -57,7 +57,10 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[ServerContext]: "ATProto AppView for the science.alt.dataset namespace. " "Use these tools to discover and query scientific datasets, " "schemas, and lenses (bidirectional schema transforms) published " - "on the AT Protocol network." + "on the AT Protocol network. " + "Schemas may specify an arrayFormat (numpyBytes, parquetBytes, " + "sparseBytes, structuredBytes, arrowTensor, safetensors) and " + "ndarray annotations (dtype, shape, dimensionNames)." ), lifespan=server_lifespan, ) @@ -127,7 +130,8 @@ async def get_schema(ctx: Ctx, uri: str) -> dict[str, Any]: uri: AT-URI of the schema (e.g. at://did:plc:abc/science.alt.dataset.schema/my.schema@1.0.0). Returns: - Full schema record including name, version, type, schema body, and description. + Full schema record including name, version, type, schema body, description, + and (when present) arrayFormat, dtype, shape, and dimensionNames. """ sc = _get_ctx(ctx) did, _collection, rkey = parse_at_uri(uri) diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 1cde2f2..72e31da 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -9,6 +9,32 @@ from pydantic import BaseModel +# --------------------------------------------------------------------------- +# Known array format tokens (atdata-lexicon#21) +# --------------------------------------------------------------------------- + +KNOWN_ARRAY_FORMATS: set[str] = { + # Original formats + "numpyBytes", + "parquetBytes", + # New formats + "sparseBytes", + "structuredBytes", + "arrowTensor", + "safetensors", +} + +#: Human-friendly display names for array format tokens. 
+ARRAY_FORMAT_LABELS: dict[str, str] = { + "numpyBytes": "NumPy ndarray", + "parquetBytes": "Parquet", + "sparseBytes": "Sparse matrix (CSR/CSC/COO)", + "structuredBytes": "NumPy structured array", + "arrowTensor": "Arrow tensor IPC", + "safetensors": "Safetensors", +} + + # --------------------------------------------------------------------------- # AT-URI parsing # --------------------------------------------------------------------------- @@ -119,6 +145,19 @@ def row_to_schema(row) -> dict[str, Any]: } if row["description"]: d["description"] = row["description"] + + # Surface array format and ndarray v1.1.0 annotation fields for display + array_format = schema_body.get("arrayFormat") + if array_format: + d["arrayFormat"] = array_format + d["arrayFormatLabel"] = ARRAY_FORMAT_LABELS.get(array_format, array_format) + if schema_body.get("dtype"): + d["dtype"] = schema_body["dtype"] + if schema_body.get("shape"): + d["shape"] = schema_body["shape"] + if schema_body.get("dimensionNames"): + d["dimensionNames"] = schema_body["dimensionNames"] + return d diff --git a/tests/test_frontend.py b/tests/test_frontend.py index 43f714f..7e3d71e 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -41,6 +41,7 @@ def _make_schema_row( did: str = "did:plc:test123", rkey: str = "test@1.0.0", name: str = "TestSchema", + schema_body: str | dict = '{"type": "object"}', ) -> dict: return { "did": did, @@ -49,7 +50,7 @@ def _make_schema_row( "name": name, "version": "1.0.0", "schema_type": "jsonSchema", - "schema_body": '{"type": "object"}', + "schema_body": schema_body, "description": "A test schema", "metadata": None, "created_at": "2025-01-01T00:00:00Z", @@ -140,10 +141,12 @@ async def test_home_search(mock_search): @pytest.mark.asyncio @patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) @patch("atdata_app.frontend.routes.query_get_entry", 
new_callable=AsyncMock) -async def test_dataset_detail(mock_get, mock_labels): +async def test_dataset_detail(mock_get, mock_schema, mock_labels): pool, _conn = _mock_pool() mock_get.return_value = _make_entry_row() + mock_schema.return_value = _make_schema_row() mock_labels.return_value = [_make_label_row()] app = _make_app(pool) transport = ASGITransport(app=app) @@ -260,6 +263,96 @@ async def test_about(mock_counts): assert "did:web:localhost%3A8000" in resp.text +# --------------------------------------------------------------------------- +# Schema detail — array format & ndarray annotations +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +async def test_schema_detail_array_format(mock_get): + pool, _conn = _mock_pool() + mock_get.return_value = _make_schema_row( + schema_body={ + "arrayFormat": "sparseBytes", + "dtype": "float32", + "shape": [100, 200], + "dimensionNames": ["samples", "features"], + }, + ) + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schema/did:plc:test123/test@1.0.0") + assert resp.status_code == 200 + assert "Sparse matrix" in resp.text + assert "float32" in resp.text + assert "100" in resp.text + assert "samples" in resp.text + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +async def test_schema_detail_no_array_format(mock_get): + """Plain schemas should not show array format rows.""" + pool, _conn = _mock_pool() + mock_get.return_value = _make_schema_row() + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schema/did:plc:test123/test@1.0.0") + assert resp.status_code == 200 + assert "Array Format" not in 
resp.text + assert "Data Type" not in resp.text + + +# --------------------------------------------------------------------------- +# Dataset detail — schema format info +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_entry", new_callable=AsyncMock) +async def test_dataset_detail_with_schema_format(mock_entry, mock_schema, mock_labels): + pool, _conn = _mock_pool() + mock_entry.return_value = _make_entry_row() + mock_schema.return_value = _make_schema_row( + did="did:plc:test", + rkey="test@1.0.0", + schema_body={"arrayFormat": "numpyBytes", "dtype": "float64"}, + ) + mock_labels.return_value = [] + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/dataset/did:plc:test123/3xyz") + assert resp.status_code == 200 + assert "NumPy ndarray" in resp.text + assert "float64" in resp.text + + +# --------------------------------------------------------------------------- +# Schemas list — format column +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_list_schemas", new_callable=AsyncMock) +async def test_schemas_list_shows_format(mock_list): + pool, _conn = _mock_pool() + mock_list.return_value = [ + _make_schema_row(schema_body={"arrayFormat": "safetensors"}), + ] + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schemas") + assert resp.status_code == 200 + assert "Safetensors" in resp.text + + # --------------------------------------------------------------------------- # Static 
files # --------------------------------------------------------------------------- diff --git a/tests/test_models.py b/tests/test_models.py index 23f6046..91b3089 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,6 +5,8 @@ import pytest from atdata_app.models import ( + ARRAY_FORMAT_LABELS, + KNOWN_ARRAY_FORMATS, decode_cursor, encode_cursor, make_at_uri, @@ -174,6 +176,95 @@ def test_row_to_schema_json_string_body(): assert d["schema"] == {"type": "object"} +def test_row_to_schema_no_array_format_fields_when_absent(): + """Plain schemas should not gain arrayFormat/ndarray annotation keys.""" + d = row_to_schema(_SCHEMA_ROW) + assert "arrayFormat" not in d + assert "arrayFormatLabel" not in d + assert "dtype" not in d + assert "shape" not in d + assert "dimensionNames" not in d + + +# --------------------------------------------------------------------------- +# row_to_schema — array format types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fmt", sorted(KNOWN_ARRAY_FORMATS)) +def test_row_to_schema_known_array_format(fmt): + """Each known format token should surface arrayFormat and a human label.""" + row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": fmt}} + d = row_to_schema(row) + assert d["arrayFormat"] == fmt + assert d["arrayFormatLabel"] == ARRAY_FORMAT_LABELS[fmt] + + +def test_row_to_schema_unknown_array_format_passes_through(): + """Unknown format tokens are stored and surfaced as-is.""" + row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": "futureFormat"}} + d = row_to_schema(row) + assert d["arrayFormat"] == "futureFormat" + assert d["arrayFormatLabel"] == "futureFormat" + + +# --------------------------------------------------------------------------- +# row_to_schema — ndarray v1.1.0 annotations +# --------------------------------------------------------------------------- + + +def test_row_to_schema_ndarray_annotations(): + """ndarray v1.1.0 annotation fields are 
surfaced at top level.""" + row = { + **_SCHEMA_ROW, + "schema_body": { + "arrayFormat": "numpyBytes", + "dtype": "float32", + "shape": [100, 200], + "dimensionNames": ["samples", "features"], + }, + } + d = row_to_schema(row) + assert d["arrayFormat"] == "numpyBytes" + assert d["dtype"] == "float32" + assert d["shape"] == [100, 200] + assert d["dimensionNames"] == ["samples", "features"] + + +def test_row_to_schema_ndarray_partial_annotations(): + """Only present annotation fields should appear in output.""" + row = { + **_SCHEMA_ROW, + "schema_body": {"arrayFormat": "sparseBytes", "dtype": "int64"}, + } + d = row_to_schema(row) + assert d["dtype"] == "int64" + assert "shape" not in d + assert "dimensionNames" not in d + + +# --------------------------------------------------------------------------- +# KNOWN_ARRAY_FORMATS constant +# --------------------------------------------------------------------------- + + +def test_known_array_formats_contains_all_expected(): + expected = { + "numpyBytes", + "parquetBytes", + "sparseBytes", + "structuredBytes", + "arrowTensor", + "safetensors", + } + assert KNOWN_ARRAY_FORMATS == expected + + +def test_array_format_labels_covers_all_known(): + """Every known format should have a human-readable label.""" + assert set(ARRAY_FORMAT_LABELS.keys()) == KNOWN_ARRAY_FORMATS + + # --------------------------------------------------------------------------- # row_to_label # --------------------------------------------------------------------------- From 60e9b4bc1fb1e89801a5e51c77382a4f44601fea Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 08:14:07 -0800 Subject: [PATCH 09/12] release: prepare v0.4.0b1 Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 18 ++++++++++++++++-- pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 019429c..ed398b6 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project adheres to [Semantic Versioning](https://semver.org/). +## [0.4.0b1] - 2026-02-26 + +### Added + +- `sendInteractions` XRPC procedure for anonymous usage telemetry — fire-and-forget reporting of download, citation, and derivative events on datasets ([#21](https://github.com/forecast-bio/atdata-app/issues/21)) +- Skeleton/hydration pattern for third-party dataset indexes — `getIndexSkeleton`, `getIndex`, `listIndexes`, and `publishIndex` endpoints following Bluesky's feed generator model ([#20](https://github.com/forecast-bio/atdata-app/issues/20)) +- `subscribeChanges` WebSocket endpoint for real-time change streaming — in-memory event bus broadcasts create/update/delete events to subscribers with cursor-based replay ([#22](https://github.com/forecast-bio/atdata-app/issues/22)) +- Array format type recognition (`sparseBytes`, `structuredBytes`, `arrowTensor`, `safetensors`) and ndarray v1.1.0 annotation display (`dtype`, `shape`, `dimensionNames`) in frontend templates ([#30](https://github.com/forecast-bio/atdata-app/issues/30)) +- `atdata-lexicon` git submodule at `lexicons/` pinned to v0.2.1b1 for reference and CI validation ([#27](https://github.com/forecast-bio/atdata-app/issues/27)) +- CI checkout steps now initialize submodules + +### Changed + +- Ingestion processor refactored to use `UPSERT_FNS` dispatch dict instead of if/elif chain +- Index provider records (`science.alt.dataset.index`) added to `COLLECTION_TABLE_MAP` for firehose ingestion + ## [0.3.0b1] - 2026-02-22 ### Changed @@ -19,8 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/). 
- DID document service entry updated from `#atproto_appview` / `AtprotoAppView` to `#atdata_appview` / `AtdataAppView` ### Added -- Add real-time change stream subscribeChanges endpoint (#50) -- Add sendInteractions XRPC procedure for usage telemetry (#35) - Dual-hostname DID document support — serve different `did:web` documents for `api.atdata.app` (appview identity) and `atdata.app` (atproto account identity) based on the `Host` header ([#19](https://github.com/forecast-bio/atdata-app/issues/19)) - Host-based route gating middleware — frontend HTML routes are only served on the frontend hostname; the API subdomain serves only XRPC, health, and DID endpoints diff --git a/pyproject.toml b/pyproject.toml index 708c0ac..32b3071 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "atdata-app" -version = "0.3.0b1" +version = "0.4.0b1" description = "ATProto AppView for science.alt.dataset" readme = "README.md" authors = [ diff --git a/uv.lock b/uv.lock index 72b071a..2dfa852 100644 --- a/uv.lock +++ b/uv.lock @@ -75,7 +75,7 @@ wheels = [ [[package]] name = "atdata-app" -version = "0.3.0b1" +version = "0.4.0b1" source = { editable = "." 
} dependencies = [ { name = "asyncpg" }, From 991e6df711106304ec31e39045524535c22a3cd0 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 09:47:12 -0800 Subject: [PATCH 10/12] fix: address critical findings from adversarial review - SSRF: validate endpoint URLs with DNS resolution and private IP blocking at fetch time (queries.py) and ingestion time (database.py) - Auth: add service auth to sendInteractions endpoint - Backpressure: track dropped subscribers in ChangeStream instead of silently losing events; close WebSocket with code 4000 on drop - Subscriber limits: cap ChangeStream to 1000 subscribers, reject with WebSocket close code 1013 when full - Replay dedup: track last replayed seq to avoid sending duplicate events from both replay buffer and live queue - Keepalive: fix broken loop structure so timeout re-enters event loop - Task GC: retain asyncio.Task references to prevent garbage collection of fire-and-forget analytics tasks - Skeleton cap: enforce requested limit on items returned by external index providers - Remove dead timestamp validation code from sendInteractions - Sanitize error messages to avoid leaking internal URLs Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/changestream.py | 25 ++++++++++--- src/atdata_app/database.py | 28 ++++++++++++++- src/atdata_app/xrpc/procedures.py | 25 ++----------- src/atdata_app/xrpc/queries.py | 46 ++++++++++++++++++++++-- src/atdata_app/xrpc/subscriptions.py | 32 ++++++++++++++--- tests/test_analytics.py | 54 ++++++++++++++++++++-------- 6 files changed, 160 insertions(+), 50 deletions(-) diff --git a/src/atdata_app/changestream.py b/src/atdata_app/changestream.py index 3d6510c..17c3593 100644 --- a/src/atdata_app/changestream.py +++ b/src/atdata_app/changestream.py @@ -17,6 +17,7 @@ DEFAULT_BUFFER_SIZE = 1000 DEFAULT_SUBSCRIBER_QUEUE_SIZE = 256 +DEFAULT_MAX_SUBSCRIBERS = 1000 @dataclass @@ -57,11 +58,13 @@ class ChangeStream: 
buffer_size: int = DEFAULT_BUFFER_SIZE subscriber_queue_size: int = DEFAULT_SUBSCRIBER_QUEUE_SIZE + max_subscribers: int = DEFAULT_MAX_SUBSCRIBERS _seq: int = field(default=0, init=False) _buffer: deque[ChangeEvent] = field(init=False) _subscribers: dict[int, asyncio.Queue[ChangeEvent]] = field( default_factory=dict, init=False ) + _dropped_subs: set[int] = field(default_factory=set, init=False) _next_sub_id: int = field(default=0, init=False) def __post_init__(self) -> None: @@ -70,9 +73,8 @@ def __post_init__(self) -> None: def publish(self, event: ChangeEvent) -> None: """Publish an event to all subscribers and the replay buffer. - Non-blocking. If a subscriber's queue is full, the event is dropped - for that subscriber (backpressure via disconnect is handled by the - WebSocket handler). + Non-blocking. If a subscriber's queue is full, the subscriber is + marked as dropped so the WebSocket handler can close the connection. """ self._seq += 1 event.seq = self._seq @@ -83,13 +85,21 @@ def publish(self, event: ChangeEvent) -> None: queue.put_nowait(event) except asyncio.QueueFull: logger.warning( - "Subscriber %d queue full, dropping event seq=%d", + "Subscriber %d queue full at seq=%d — marking for disconnect", sub_id, event.seq, ) + self._dropped_subs.add(sub_id) def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]: - """Create a new subscriber. Returns (subscriber_id, queue).""" + """Create a new subscriber. Returns (subscriber_id, queue). + + Raises ``RuntimeError`` if the maximum subscriber count is reached. 
+ """ + if len(self._subscribers) >= self.max_subscribers: + raise RuntimeError( + f"Maximum subscriber count ({self.max_subscribers}) reached" + ) sub_id = self._next_sub_id self._next_sub_id += 1 queue: asyncio.Queue[ChangeEvent] = asyncio.Queue( @@ -102,8 +112,13 @@ def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]: def unsubscribe(self, sub_id: int) -> None: """Remove a subscriber.""" self._subscribers.pop(sub_id, None) + self._dropped_subs.discard(sub_id) logger.debug("Subscriber %d disconnected (total: %d)", sub_id, len(self._subscribers)) + def is_dropped(self, sub_id: int) -> bool: + """Return True if the subscriber was dropped due to backpressure.""" + return sub_id in self._dropped_subs + def replay_from(self, cursor: int) -> list[ChangeEvent]: """Return buffered events with seq > cursor. diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index f6adb85..02e2657 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -204,6 +204,21 @@ async def upsert_lens( ) +def _is_safe_endpoint_url(url: str) -> bool: + """Return True if *url* is HTTPS with no credentials or private IPs.""" + from urllib.parse import urlparse + + try: + parsed = urlparse(url) + except ValueError: + return False + if parsed.scheme != "https" or not parsed.hostname: + return False + if parsed.username or parsed.password: + return False + return True + + async def upsert_index_provider( pool: asyncpg.Pool, did: str, @@ -211,6 +226,12 @@ async def upsert_index_provider( cid: str | None, record: dict[str, Any], ) -> None: + endpoint_url = record.get("endpointUrl", "") + if not _is_safe_endpoint_url(endpoint_url): + logger.warning( + "Rejected index provider %s/%s: unsafe endpoint URL %s", did, rkey, endpoint_url + ) + return async with pool.acquire() as conn: await conn.execute( """ @@ -686,6 +707,9 @@ async def record_analytics_event( logger.warning("Failed to record analytics event %s", event_type, exc_info=True) +_background_tasks: 
set[asyncio.Task] = set() + + def fire_analytics_event( pool: asyncpg.Pool, event_type: str, @@ -694,9 +718,11 @@ def fire_analytics_event( query_params: dict[str, Any] | None = None, ) -> None: """Fire-and-forget analytics recording. Does not block the caller.""" - asyncio.create_task( + task = asyncio.create_task( record_analytics_event(pool, event_type, target_did, target_rkey, query_params) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) PERIOD_INTERVALS: dict[str, timedelta] = { diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index 7aa79f9..3c46f92 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from datetime import datetime from typing import Any from urllib.parse import urlparse @@ -319,14 +318,11 @@ async def publish_index(request: Request) -> dict[str, Any]: _MAX_INTERACTIONS_BATCH = 100 -def _validate_iso8601(value: str) -> None: - """Raise ValueError if *value* is not a valid ISO 8601 datetime string.""" - # datetime.fromisoformat handles the common subset we accept - datetime.fromisoformat(value) - - @router.post("/science.alt.dataset.sendInteractions") async def send_interactions(request: Request) -> dict[str, Any]: + """Record dataset interaction events (anonymous, fire-and-forget).""" + await verify_service_auth(request, "science.alt.dataset.sendInteractions") + pool = request.app.state.db_pool body = await request.json() @@ -365,21 +361,6 @@ async def send_interactions(request: Request) -> dict[str, Any]: detail=f"interactions[{i}]: invalid AT-URI: {dataset_uri}", ) - timestamp = item.get("timestamp") - if timestamp is not None: - if not isinstance(timestamp, str): - raise HTTPException( - status_code=400, - detail=f"interactions[{i}]: timestamp must be a string", - ) - try: - _validate_iso8601(timestamp) - except ValueError: - raise HTTPException( - status_code=400, - 
detail=f"interactions[{i}]: invalid ISO 8601 timestamp: {timestamp}", - ) - # All valid — fire analytics events for item in interactions: did, _collection, rkey = parse_at_uri(item["datasetUri"]) diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py index 218aa44..5adf8ec 100644 --- a/src/atdata_app/xrpc/queries.py +++ b/src/atdata_app/xrpc/queries.py @@ -7,6 +7,10 @@ from fastapi import APIRouter, HTTPException, Query, Request +import ipaddress +import socket +from urllib.parse import urlparse + import httpx from atdata_app import get_resolver @@ -374,21 +378,55 @@ async def list_indexes( # --------------------------------------------------------------------------- +def _validate_endpoint_url(url: str) -> None: + """Reject URLs that could cause SSRF (private IPs, non-HTTPS, credentials). + + Raises ``HTTPException`` (400) if the URL is unsafe. + """ + try: + parsed = urlparse(url) + except ValueError: + raise HTTPException(status_code=400, detail="Malformed endpoint URL") + + if parsed.scheme != "https": + raise HTTPException(status_code=400, detail="Endpoint URL must use HTTPS") + if not parsed.hostname: + raise HTTPException(status_code=400, detail="Endpoint URL missing hostname") + if parsed.username or parsed.password: + raise HTTPException(status_code=400, detail="Endpoint URL must not contain credentials") + + # Resolve hostname and block private/reserved IP ranges + try: + infos = socket.getaddrinfo(parsed.hostname, None, proto=socket.IPPROTO_TCP) + except socket.gaierror: + raise HTTPException(status_code=502, detail="Index provider hostname unresolvable") + + for _family, _type, _proto, _canonname, sockaddr in infos: + ip = ipaddress.ip_address(sockaddr[0]) + if ip.is_private or ip.is_reserved or ip.is_loopback or ip.is_link_local: + raise HTTPException( + status_code=400, + detail="Endpoint URL must not resolve to a private/reserved IP address", + ) + + async def _fetch_skeleton( endpoint_url: str, cursor: str | None, limit: int, ) 
-> dict[str, Any]: """Fetch skeleton from an upstream index provider.""" + _validate_endpoint_url(endpoint_url) + params: dict[str, Any] = {"limit": limit} if cursor: params["cursor"] = cursor async with httpx.AsyncClient(timeout=10.0) as http: try: resp = await http.get(endpoint_url, params=params) - except httpx.HTTPError as e: + except (httpx.HTTPError, ValueError) as e: raise HTTPException( - status_code=502, detail=f"Index provider unreachable: {e}" + status_code=502, detail="Index provider unreachable" ) from e if resp.status_code != 200: raise HTTPException( @@ -399,12 +437,14 @@ async def _fetch_skeleton( data = resp.json() except (ValueError, KeyError) as e: raise HTTPException( - status_code=502, detail=f"Invalid response from index provider: {e}" + status_code=502, detail="Invalid response from index provider" ) from e if not isinstance(data.get("items"), list): raise HTTPException( status_code=502, detail="Index provider response missing 'items' array" ) + # Cap items to the requested limit regardless of what upstream returns + data["items"] = data["items"][:limit] return data diff --git a/src/atdata_app/xrpc/subscriptions.py b/src/atdata_app/xrpc/subscriptions.py index cdcbf60..80bfff4 100644 --- a/src/atdata_app/xrpc/subscriptions.py +++ b/src/atdata_app/xrpc/subscriptions.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import json import logging @@ -26,10 +27,17 @@ async def subscribe_changes(websocket: WebSocket) -> None: await websocket.accept() cursor_param = websocket.query_params.get("cursor") - sub_id, queue = change_stream.subscribe() try: - # Replay buffered events if cursor provided + sub_id, queue = change_stream.subscribe() + except RuntimeError: + await websocket.close(code=1013, reason="Too many subscribers") + return + + try: + # Replay buffered events if cursor provided, tracking last seq + # to deduplicate against events that also landed in the live queue. 
+ last_replayed_seq = 0 if cursor_param is not None: try: cursor = int(cursor_param) @@ -39,11 +47,27 @@ async def subscribe_changes(websocket: WebSocket) -> None: missed = change_stream.replay_from(cursor) for event in missed: await websocket.send_text(json.dumps(event.to_dict())) + last_replayed_seq = event.seq - # Stream live events + # Stream live events with periodic keepalive on idle while True: - event = await queue.get() + try: + event = await asyncio.wait_for(queue.get(), timeout=30.0) + except asyncio.TimeoutError: + # No events for 30s — send keepalive to detect dead connections + await websocket.send_text(json.dumps({"type": "keepalive"})) + continue + + # Deduplicate events already sent during replay + if event.seq <= last_replayed_seq: + continue await websocket.send_text(json.dumps(event.to_dict())) + # Check if we were marked as dropped due to backpressure + if change_stream.is_dropped(sub_id): + await websocket.close( + code=4000, reason="Backpressure: events were dropped" + ) + return except WebSocketDisconnect: logger.debug("Subscriber %d disconnected", sub_id) diff --git a/tests/test_analytics.py b/tests/test_analytics.py index 8df0a14..1ed9d04 100644 --- a/tests/test_analytics.py +++ b/tests/test_analytics.py @@ -394,8 +394,10 @@ def test_get_entry_stats_response_defaults_interactions(): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_valid_batch(mock_fire, config, pool): +async def test_send_interactions_valid_batch(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -424,8 +426,10 @@ async def test_send_interactions_valid_batch(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) 
@patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_empty_array(mock_fire, config, pool): +async def test_send_interactions_empty_array(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -440,8 +444,10 @@ async def test_send_interactions_empty_array(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_invalid_uri(mock_fire, config, pool): +async def test_send_interactions_invalid_uri(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -460,8 +466,10 @@ async def test_send_interactions_invalid_uri(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_invalid_type(mock_fire, config, pool): +async def test_send_interactions_invalid_type(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -483,8 +491,10 @@ async def test_send_interactions_invalid_type(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_batch_size_exceeded(mock_fire, config, pool): +async def test_send_interactions_batch_size_exceeded(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") 
app = _mock_app(config, pool) transport = ASGITransport(app=app) interactions = [ @@ -503,8 +513,11 @@ async def test_send_interactions_batch_size_exceeded(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_invalid_timestamp(mock_fire, config, pool): +async def test_send_interactions_ignores_timestamp(mock_fire, mock_auth, config, pool): + """Timestamp field is accepted but not validated (informational only).""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -521,14 +534,15 @@ async def test_send_interactions_invalid_timestamp(mock_fire, config, pool): }, ) - assert resp.status_code == 400 - assert "invalid ISO 8601 timestamp" in resp.json()["detail"] - mock_fire.assert_not_called() + assert resp.status_code == 200 + assert mock_fire.call_count == 1 @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_missing_dataset_uri(mock_fire, config, pool): +async def test_send_interactions_missing_dataset_uri(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -547,8 +561,10 @@ async def test_send_interactions_missing_dataset_uri(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_not_an_array(mock_fire, config, pool): +async def test_send_interactions_not_an_array(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = 
_mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -563,8 +579,10 @@ async def test_send_interactions_not_an_array(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_all_three_types(mock_fire, config, pool): +async def test_send_interactions_all_three_types(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -593,9 +611,11 @@ async def test_send_interactions_all_three_types(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_missing_key(mock_fire, config, pool): +async def test_send_interactions_missing_key(mock_fire, mock_auth, config, pool): """Body without 'interactions' key should return 400.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: @@ -610,9 +630,11 @@ async def test_send_interactions_missing_key(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_non_dict_item(mock_fire, config, pool): +async def test_send_interactions_non_dict_item(mock_fire, mock_auth, config, pool): """Non-object items in the interactions array should return 400.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as 
client: @@ -627,9 +649,11 @@ async def test_send_interactions_non_dict_item(mock_fire, config, pool): @pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) @patch(f"{_PROC}.fire_analytics_event") -async def test_send_interactions_boundary_at_max(mock_fire, config, pool): +async def test_send_interactions_boundary_at_max(mock_fire, mock_auth, config, pool): """Exactly 100 interactions (the maximum) should succeed.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") app = _mock_app(config, pool) transport = ASGITransport(app=app) interactions = [ From 1260f6e7f64d437c0da78df9b81bf73c32b9e8b6 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:53:55 -0800 Subject: [PATCH 11/12] fix: address warning-level findings from adversarial review - W2: Cap upstream skeleton response to 1 MiB to prevent memory exhaustion from malicious index providers - W3: Guard query_get_entries with a 100-key limit to prevent unbounded OR-clause queries - W4: Whitelist skeleton item fields to only 'uri', preventing injection of unexpected fields by upstream providers - W6: Validate skeleton cursor passthrough (length cap, no null bytes) - W8: Validate that sendInteractions datasetUri references science.alt.dataset.entry collection specifically - W14: Prevent javascript:/data: URI XSS in storage URL href by only rendering http(s) URLs as clickable links - W15: Guard template join filter with iterable checks to prevent crashes on malformed shape/dimensionNames data - W16: Add missing ingestion test for index_providers collection - Harden publishIndex URL validation: reject credentials and fragments Co-Authored-By: Claude Opus 4.6 --- src/atdata_app/database.py | 7 +++++ .../frontend/templates/dataset.html | 8 +++-- src/atdata_app/frontend/templates/schema.html | 4 +-- src/atdata_app/xrpc/procedures.py | 19 ++++++++++-- src/atdata_app/xrpc/queries.py | 29 +++++++++++++++++++ 
tests/test_ingestion.py | 20 +++++++++++++ 6 files changed, 80 insertions(+), 7 deletions(-) diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index 02e2657..6543530 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -366,11 +366,18 @@ async def query_get_entry( ) +_MAX_GET_ENTRIES_KEYS = 100 + + async def query_get_entries( pool: asyncpg.Pool, keys: list[tuple[str, str]] ) -> list[asyncpg.Record]: if not keys: return [] + if len(keys) > _MAX_GET_ENTRIES_KEYS: + raise ValueError( + f"query_get_entries: too many keys ({len(keys)}), max {_MAX_GET_ENTRIES_KEYS}" + ) conditions = " OR ".join( f"(did = ${i * 2 + 1} AND rkey = ${i * 2 + 2})" for i in range(len(keys)) ) diff --git a/src/atdata_app/frontend/templates/dataset.html b/src/atdata_app/frontend/templates/dataset.html index 78f20b8..ddd4760 100644 --- a/src/atdata_app/frontend/templates/dataset.html +++ b/src/atdata_app/frontend/templates/dataset.html @@ -34,10 +34,10 @@

Details

{% if schema_info.dtype is defined %}
{% endif %} - {% if schema_info.shape is defined %} + {% if schema_info.shape is defined and schema_info.shape is iterable and schema_info.shape is not string %} {% endif %} - {% if schema_info.dimensionNames is defined %} + {% if schema_info.dimensionNames is defined and schema_info.dimensionNames is iterable and schema_info.dimensionNames is not string %} {% endif %} {% endif %} @@ -63,7 +63,11 @@

Storage

{% if entry.storage.url is defined %} + {% if entry.storage.url.startswith("https://") or entry.storage.url.startswith("http://") %} + {% else %} + + {% endif %} {% endif %}
NameVersionTypeDescriptionPublisher
NameVersionTypeFormatDescriptionPublisher
{{ s.name }} {{ s.version }} {{ s.schemaType }}{{ s.get("arrayFormatLabel", "") }} {{ s.get("description", "") }} {{ s.did[:20] }}…
Data Type{{ schema_info.dtype }}
Shape{{ schema_info.shape | join(" × ") }}
Dimensions{{ schema_info.dimensionNames | join(", ") }}
Type{{ entry.storage.get("$type", "unknown") }}
URL{{ entry.storage.url }}
URL{{ entry.storage.url }}
diff --git a/src/atdata_app/frontend/templates/schema.html b/src/atdata_app/frontend/templates/schema.html index 16ede21..3df1859 100644 --- a/src/atdata_app/frontend/templates/schema.html +++ b/src/atdata_app/frontend/templates/schema.html @@ -22,10 +22,10 @@

Details

{% if schema.dtype is defined %} Data Type{{ schema.dtype }} {% endif %} - {% if schema.shape is defined %} + {% if schema.shape is defined and schema.shape is iterable and schema.shape is not string %} Shape{{ schema.shape | join(" × ") }} {% endif %} - {% if schema.dimensionNames is defined %} + {% if schema.dimensionNames is defined and schema.dimensionNames is iterable and schema.dimensionNames is not string %} Dimensions{{ schema.dimensionNames | join(", ") }} {% endif %} Version{{ schema.version }} diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index 3c46f92..7a418d0 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -294,12 +294,20 @@ async def publish_index(request: Request) -> dict[str, Any]: if field not in record: raise HTTPException(status_code=400, detail=f"Missing required field: {field}") - # Validate endpoint URL is HTTPS + # Validate endpoint URL is HTTPS with no credentials or fragments parsed = urlparse(record["endpointUrl"]) - if parsed.scheme != "https" or not parsed.netloc: + if parsed.scheme != "https" or not parsed.hostname: raise HTTPException( status_code=400, detail="endpointUrl must be a valid HTTPS URL" ) + if parsed.username or parsed.password: + raise HTTPException( + status_code=400, detail="endpointUrl must not contain credentials" + ) + if parsed.fragment: + raise HTTPException( + status_code=400, detail="endpointUrl must not contain a fragment" + ) record["$type"] = "science.alt.dataset.index" @@ -354,12 +362,17 @@ async def send_interactions(request: Request) -> dict[str, Any]: status_code=400, detail=f"interactions[{i}]: datasetUri is required" ) try: - parse_at_uri(dataset_uri) + _did, _col, _rkey = parse_at_uri(dataset_uri) except ValueError: raise HTTPException( status_code=400, detail=f"interactions[{i}]: invalid AT-URI: {dataset_uri}", ) + if _col != "science.alt.dataset.entry": + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: 
datasetUri must reference a dataset entry", + ) # All valid — fire analytics events for item in interactions: diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py index 5adf8ec..ff22e1d 100644 --- a/src/atdata_app/xrpc/queries.py +++ b/src/atdata_app/xrpc/queries.py @@ -410,6 +410,10 @@ def _validate_endpoint_url(url: str) -> None: ) +_MAX_SKELETON_RESPONSE_BYTES = 1_048_576 # 1 MiB +_MAX_SKELETON_CURSOR_LEN = 512 + + async def _fetch_skeleton( endpoint_url: str, cursor: str | None, @@ -418,6 +422,10 @@ async def _fetch_skeleton( """Fetch skeleton from an upstream index provider.""" _validate_endpoint_url(endpoint_url) + if cursor is not None: + if len(cursor) > _MAX_SKELETON_CURSOR_LEN or "\x00" in cursor: + raise HTTPException(status_code=400, detail="Invalid cursor value") + params: dict[str, Any] = {"limit": limit} if cursor: params["cursor"] = cursor @@ -433,6 +441,16 @@ async def _fetch_skeleton( status_code=502, detail=f"Index provider returned {resp.status_code}", ) + # Reject oversized responses to prevent memory exhaustion + content_length = resp.headers.get("content-length") + if content_length and int(content_length) > _MAX_SKELETON_RESPONSE_BYTES: + raise HTTPException( + status_code=502, detail="Index provider response too large" + ) + if len(resp.content) > _MAX_SKELETON_RESPONSE_BYTES: + raise HTTPException( + status_code=502, detail="Index provider response too large" + ) try: data = resp.json() except (ValueError, KeyError) as e: @@ -445,6 +463,17 @@ async def _fetch_skeleton( ) # Cap items to the requested limit regardless of what upstream returns data["items"] = data["items"][:limit] + # Whitelist item fields — only pass through 'uri' to prevent injection + data["items"] = [{"uri": item.get("uri", "")} for item in data["items"]] + # Validate upstream cursor + upstream_cursor = data.get("cursor") + if upstream_cursor is not None: + if ( + not isinstance(upstream_cursor, str) + or len(upstream_cursor) > 
_MAX_SKELETON_CURSOR_LEN + or "\x00" in upstream_cursor + ): + data["cursor"] = None return data diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 490934f..fd31238 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -142,6 +142,26 @@ async def test_process_commit_lens(): ) +@pytest.mark.asyncio +async def test_process_commit_index_provider(): + patcher, mock_upsert = _patch_upsert("index_providers") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.index", + record={ + "$type": "science.alt.dataset.index", + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) + + @pytest.mark.asyncio async def test_process_commit_update(): """Update operations should route to the same upsert function as create.""" From 0de1047aa075aa04d0555309902686031b83a531 Mon Sep 17 00:00:00 2001 From: Maxine Levesque <170461181+maxinelevesque@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:28:51 -0800 Subject: [PATCH 12/12] docs: update CHANGELOG with security and fix sections from adversarial review Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed398b6..e35941f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/). 
- Ingestion processor refactored to use `UPSERT_FNS` dispatch dict instead of if/elif chain - Index provider records (`science.alt.dataset.index`) added to `COLLECTION_TABLE_MAP` for firehose ingestion +### Security + +- **SSRF protection**: Skeleton fetch now validates endpoint URLs with DNS resolution and blocks private/reserved IP ranges at both fetch time and firehose ingestion time +- **Auth**: `sendInteractions` endpoint now requires ATProto service auth (was previously unauthenticated) +- **XSS**: Storage URLs in dataset detail pages are only rendered as clickable links when using `http(s)://` schemes, preventing `javascript:` URI injection +- **Input validation**: `publishIndex` rejects endpoint URLs containing embedded credentials or fragments; `sendInteractions` validates that URIs reference the `science.alt.dataset.entry` collection + +### Fixed + +- **ChangeStream backpressure**: Subscribers that fall behind are now tracked and explicitly disconnected with WebSocket close code 4000, instead of silently dropping events +- **ChangeStream subscriber limit**: Capped at 1000 concurrent subscribers; new connections receive close code 1013 when full +- **WebSocket keepalive**: Restructured the `subscribeChanges` event loop so the 30-second idle keepalive correctly re-enters the processing loop (was previously broken) +- **Replay deduplication**: Track last replayed sequence number to prevent duplicate events when replay buffer overlaps with the live queue +- **Task GC**: Fire-and-forget analytics tasks now retain references to prevent garbage collection before completion +- **Skeleton response cap**: Enforce the requested `limit` on items returned by external index providers, and cap response body size to 1 MiB +- **Skeleton item sanitization**: Whitelist upstream skeleton items to only the `uri` field; validate cursor strings for length and null bytes +- **Query guard**: `query_get_entries` now rejects requests with more than 100 keys to prevent unbounded 
OR-clause queries +- **Template robustness**: `shape` and `dimensionNames` join filters now guard against non-iterable data from malformed firehose records +- `sendInteractions` no longer validates the `timestamp` field (it is accepted as informational only; an invalid value previously returned HTTP 400), and the `_validate_iso8601` helper was removed + ## [0.3.0b1] - 2026-02-22 ### Changed