diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95cd225..4526e04 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true @@ -28,6 +30,8 @@ jobs: python-version: ["3.12", "3.13"] steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true @@ -56,6 +60,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Apply schema env: PGHOST: localhost @@ -105,6 +111,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4c9d7e5..48eb356 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,6 +12,8 @@ jobs: id-token: write steps: - uses: actions/checkout@v4 + with: + submodules: true - uses: astral-sh/setup-uv@v5 with: enable-cache: true diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d4faf98 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lexicons"] + path = lexicons + url = https://github.com/forecast-bio/atdata-lexicon.git diff --git a/CHANGELOG.md b/CHANGELOG.md index edb123d..e35941f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,42 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project adheres to [Semantic Versioning](https://semver.org/). +## [0.4.0b1] - 2026-02-26 + +### Added + +- `sendInteractions` XRPC procedure for anonymous usage telemetry — fire-and-forget reporting of download, citation, and derivative events on datasets ([#21](https://github.com/forecast-bio/atdata-app/issues/21)) +- Skeleton/hydration pattern for third-party dataset indexes — `getIndexSkeleton`, `getIndex`, `listIndexes`, and `publishIndex` endpoints following Bluesky's feed generator model ([#20](https://github.com/forecast-bio/atdata-app/issues/20)) +- `subscribeChanges` WebSocket endpoint for real-time change streaming — in-memory event bus broadcasts create/update/delete events to subscribers with cursor-based replay ([#22](https://github.com/forecast-bio/atdata-app/issues/22)) +- Array format type recognition (`sparseBytes`, `structuredBytes`, `arrowTensor`, `safetensors`) and ndarray v1.1.0 annotation display (`dtype`, `shape`, `dimensionNames`) in frontend templates ([#30](https://github.com/forecast-bio/atdata-app/issues/30)) +- `atdata-lexicon` git submodule at `lexicons/` pinned to v0.2.1b1 for reference and CI validation ([#27](https://github.com/forecast-bio/atdata-app/issues/27)) +- CI checkout steps now initialize submodules + +### Changed + +- Ingestion processor refactored to use `UPSERT_FNS` dispatch dict instead of if/elif chain +- Index provider records (`science.alt.dataset.index`) added to `COLLECTION_TABLE_MAP` for firehose ingestion + +### Security + +- **SSRF protection**: Skeleton fetch now validates endpoint URLs with DNS resolution and blocks private/reserved IP ranges at both fetch time and firehose ingestion time +- **Auth**: `sendInteractions` endpoint now requires ATProto service auth (was previously unauthenticated) +- **XSS**: Storage URLs in dataset detail pages are only rendered as clickable links when using `http(s)://` schemes, 
preventing `javascript:` URI injection +- **Input validation**: `publishIndex` rejects endpoint URLs containing embedded credentials or fragments; `sendInteractions` validates that URIs reference the `science.alt.dataset.entry` collection + +### Fixed + +- **ChangeStream backpressure**: Subscribers that fall behind are now tracked and explicitly disconnected with WebSocket close code 4000, instead of silently dropping events +- **ChangeStream subscriber limit**: Capped at 1000 concurrent subscribers; new connections receive close code 1013 when full +- **WebSocket keepalive**: Restructured the `subscribeChanges` event loop so the 30-second idle keepalive correctly re-enters the processing loop (was previously broken) +- **Replay deduplication**: Track last replayed sequence number to prevent duplicate events when replay buffer overlaps with the live queue +- **Task GC**: Fire-and-forget analytics tasks now retain references to prevent garbage collection before completion +- **Skeleton response cap**: Enforce the requested `limit` on items returned by external index providers, and cap response body size to 1 MiB +- **Skeleton item sanitization**: Whitelist upstream skeleton items to only the `uri` field; validate cursor strings for length and null bytes +- **Query guard**: `query_get_entries` now rejects requests with more than 100 keys to prevent unbounded OR-clause queries +- **Template robustness**: `shape` and `dimensionNames` join filters now guard against non-iterable data from malformed firehose records +- Removed dead `_validate_iso8601` timestamp validation code from `sendInteractions` + ## [0.3.0b1] - 2026-02-22 ### Changed diff --git a/CLAUDE.md b/CLAUDE.md index 6020760..c5b8285 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,16 @@ uv run ruff check src/ tests/ uv run uvicorn atdata_app.main:app --reload ``` +## Lexicon Submodule + +The `lexicons/` directory is a git submodule pointing to [forecast-bio/atdata-lexicon](https://github.com/forecast-bio/atdata-lexicon), which contains the authoritative `science.alt.dataset.*` lexicon definitions. The submodule is for reference and CI validation — the Python source code does not read lexicon files at runtime. + +After cloning, initialize the submodule: + +```bash +git submodule update --init +``` + ## Architecture ### Data Flow diff --git a/README.md b/README.md index 00d8bdf..33ac782 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ ATProto Network # Install dependencies uv sync --dev +# Initialize the lexicon submodule +git submodule update --init + # Set up PostgreSQL (schema auto-applies on startup) createdb atdata_app @@ -135,6 +138,16 @@ uv run ruff check src/ tests/ Tests mock all external dependencies (database, HTTP, identity resolution) using `unittest.mock.AsyncMock`. HTTP endpoint tests use httpx `ASGITransport` for in-process testing without a running server. +### Lexicon Definitions + +The `lexicons/` directory is a [git submodule](https://github.com/forecast-bio/atdata-lexicon) containing the authoritative `science.alt.dataset.*` lexicon schemas. Initialize it with: + +```bash +git submodule update --init +``` + +The lexicons are for reference and CI validation. The Python source code uses hardcoded NSID constants and does not read the lexicon JSON files at runtime. 
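Because the lexicons are vendored but never imported at runtime, the submodule and the hardcoded NSID constants can silently drift apart. A CI check along the following lines would close that gap. This is a hypothetical sketch rather than code from this repository: it assumes each lexicon JSON document carries its NSID in a top-level `id` field, as ATProto lexicon files do, and that the submodule keeps them under `lexicons/`.

```python
# check_lexicons.py: hypothetical CI helper, not part of this codebase.
import json
from pathlib import Path

from atdata_app.database import COLLECTION_TABLE_MAP


def published_nsids(root: Path = Path("lexicons")) -> set[str]:
    """Collect the NSID (top-level "id" field) of every lexicon JSON file."""
    return {json.loads(p.read_text())["id"] for p in root.rglob("*.json")}


if __name__ == "__main__":
    # Every record collection the AppView ingests should have a published
    # lexicon definition in the pinned submodule.
    missing = set(COLLECTION_TABLE_MAP) - published_nsids()
    if missing:
        raise SystemExit(f"NSIDs without lexicon definitions: {sorted(missing)}")
    print("All ingested record NSIDs have lexicon definitions.")
```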
+ ## License MIT diff --git a/lexicons b/lexicons new file mode 160000 index 0000000..ece47ae --- /dev/null +++ b/lexicons @@ -0,0 +1 @@ +Subproject commit ece47ae2158c3226bdf2ac280695db32ea5ad336 diff --git a/pyproject.toml b/pyproject.toml index 708c0ac..32b3071 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "atdata-app" -version = "0.3.0b1" +version = "0.4.0b1" description = "ATProto AppView for science.alt.dataset" readme = "README.md" authors = [ diff --git a/src/atdata_app/changestream.py b/src/atdata_app/changestream.py new file mode 100644 index 0000000..17c3593 --- /dev/null +++ b/src/atdata_app/changestream.py @@ -0,0 +1,167 @@ +"""In-memory broadcast channel for real-time change events. + +Provides a pub/sub mechanism that the ingestion processor publishes to +and WebSocket subscribers consume from. Maintains a bounded buffer of +recent events for cursor-based replay. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + +DEFAULT_BUFFER_SIZE = 1000 +DEFAULT_SUBSCRIBER_QUEUE_SIZE = 256 +DEFAULT_MAX_SUBSCRIBERS = 1000 + + +@dataclass +class ChangeEvent: + """A single change event in the stream.""" + + seq: int + type: str # "create", "update", or "delete" + collection: str + did: str + rkey: str + timestamp: str + record: dict[str, Any] | None = None + cid: str | None = None + + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = { + "seq": self.seq, + "type": self.type, + "collection": self.collection, + "did": self.did, + "rkey": self.rkey, + "timestamp": self.timestamp, + } + if self.record is not None: + d["record"] = self.record + if self.cid is not None: + d["cid"] = self.cid + return d + + +@dataclass +class ChangeStream: + """Broadcast channel with bounded replay buffer. + + Thread-safe for asyncio: all mutations happen in the event loop. + """ + + buffer_size: int = DEFAULT_BUFFER_SIZE + subscriber_queue_size: int = DEFAULT_SUBSCRIBER_QUEUE_SIZE + max_subscribers: int = DEFAULT_MAX_SUBSCRIBERS + _seq: int = field(default=0, init=False) + _buffer: deque[ChangeEvent] = field(init=False) + _subscribers: dict[int, asyncio.Queue[ChangeEvent]] = field( + default_factory=dict, init=False + ) + _dropped_subs: set[int] = field(default_factory=set, init=False) + _next_sub_id: int = field(default=0, init=False) + + def __post_init__(self) -> None: + self._buffer = deque(maxlen=self.buffer_size) + + def publish(self, event: ChangeEvent) -> None: + """Publish an event to all subscribers and the replay buffer. + + Non-blocking. If a subscriber's queue is full, the subscriber is + marked as dropped so the WebSocket handler can close the connection. + """ + self._seq += 1 + event.seq = self._seq + self._buffer.append(event) + + for sub_id, queue in list(self._subscribers.items()): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning( + "Subscriber %d queue full at seq=%d — marking for disconnect", + sub_id, + event.seq, + ) + self._dropped_subs.add(sub_id) + + def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]: + """Create a new subscriber. Returns (subscriber_id, queue). + + Raises ``RuntimeError`` if the maximum subscriber count is reached. 
+ """ + if len(self._subscribers) >= self.max_subscribers: + raise RuntimeError( + f"Maximum subscriber count ({self.max_subscribers}) reached" + ) + sub_id = self._next_sub_id + self._next_sub_id += 1 + queue: asyncio.Queue[ChangeEvent] = asyncio.Queue( + maxsize=self.subscriber_queue_size + ) + self._subscribers[sub_id] = queue + logger.debug("Subscriber %d connected (total: %d)", sub_id, len(self._subscribers)) + return sub_id, queue + + def unsubscribe(self, sub_id: int) -> None: + """Remove a subscriber.""" + self._subscribers.pop(sub_id, None) + self._dropped_subs.discard(sub_id) + logger.debug("Subscriber %d disconnected (total: %d)", sub_id, len(self._subscribers)) + + def is_dropped(self, sub_id: int) -> bool: + """Return True if the subscriber was dropped due to backpressure.""" + return sub_id in self._dropped_subs + + def replay_from(self, cursor: int) -> list[ChangeEvent]: + """Return buffered events with seq > cursor. + + Returns an empty list if the cursor is outside the buffer window. + """ + if not self._buffer: + return [] + + oldest_seq = self._buffer[0].seq + if cursor < oldest_seq - 1: + # Cursor is too old — events between cursor and buffer start were lost + return [] + + return [ev for ev in self._buffer if ev.seq > cursor] + + @property + def current_seq(self) -> int: + return self._seq + + @property + def subscriber_count(self) -> int: + return len(self._subscribers) + + +def make_change_event( + *, + event_type: str, + collection: str, + did: str, + rkey: str, + record: dict[str, Any] | None = None, + cid: str | None = None, +) -> ChangeEvent: + """Factory for creating change events with current timestamp.""" + from datetime import datetime, timezone + + return ChangeEvent( + seq=0, # Assigned by ChangeStream.publish() + type=event_type, + collection=collection, + did=did, + rkey=rkey, + timestamp=datetime.now(timezone.utc).isoformat(), + record=record, + cid=cid, + ) diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py index b0ea925..6543530 100644 --- a/src/atdata_app/database.py +++ b/src/atdata_app/database.py @@ -20,6 +20,7 @@ f"{LEXICON_NAMESPACE}.entry": "entries", f"{LEXICON_NAMESPACE}.label": "labels", f"{LEXICON_NAMESPACE}.lens": "lenses", + f"{LEXICON_NAMESPACE}.index": "index_providers", } @@ -203,6 +204,57 @@ async def upsert_lens( ) +def _is_safe_endpoint_url(url: str) -> bool: + """Return True if *url* is HTTPS with no credentials or private IPs.""" + from urllib.parse import urlparse + + try: + parsed = urlparse(url) + except ValueError: + return False + if parsed.scheme != "https" or not parsed.hostname: + return False + if parsed.username or parsed.password: + return False + return True + + +async def upsert_index_provider( + pool: asyncpg.Pool, + did: str, + rkey: str, + cid: str | None, + record: dict[str, Any], +) -> None: + endpoint_url = record.get("endpointUrl", "") + if not _is_safe_endpoint_url(endpoint_url): + logger.warning( + "Rejected index provider %s/%s: unsafe endpoint URL %s", did, rkey, endpoint_url + ) + return + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO index_providers (did, rkey, cid, name, description, endpoint_url, + created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (did, rkey) DO UPDATE SET + cid = EXCLUDED.cid, + name = EXCLUDED.name, + description = EXCLUDED.description, + endpoint_url = EXCLUDED.endpoint_url, + indexed_at = NOW() + """, + did, + rkey, + cid, + record["name"], + record.get("description"), + record["endpointUrl"], + record.get("createdAt", 
""), + ) + + async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) -> None: if table not in COLLECTION_TABLE_MAP.values(): return @@ -219,6 +271,7 @@ async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) -> "entries": upsert_entry, "labels": upsert_label, "lenses": upsert_lens, + "index_providers": upsert_index_provider, } @@ -313,11 +366,18 @@ async def query_get_entry( ) +_MAX_GET_ENTRIES_KEYS = 100 + + async def query_get_entries( pool: asyncpg.Pool, keys: list[tuple[str, str]] ) -> list[asyncpg.Record]: if not keys: return [] + if len(keys) > _MAX_GET_ENTRIES_KEYS: + raise ValueError( + f"query_get_entries: too many keys ({len(keys)}), max {_MAX_GET_ENTRIES_KEYS}" + ) conditions = " OR ".join( f"(did = ${i * 2 + 1} AND rkey = ${i * 2 + 2})" for i in range(len(keys)) ) @@ -449,6 +509,49 @@ async def query_list_lenses( ) +async def query_get_index_provider( + pool: asyncpg.Pool, did: str, rkey: str +) -> asyncpg.Record | None: + async with pool.acquire() as conn: + return await conn.fetchrow( + "SELECT * FROM index_providers WHERE did = $1 AND rkey = $2", did, rkey + ) + + +async def query_list_index_providers( + pool: asyncpg.Pool, + repo: str | None = None, + limit: int = 50, + cursor_did: str | None = None, + cursor_rkey: str | None = None, + cursor_indexed_at: str | None = None, +) -> list[asyncpg.Record]: + conditions: list[str] = [] + params: list[Any] = [] + idx = 1 + + if repo: + conditions.append(f"did = ${idx}") + params.append(repo) + idx += 1 + + if cursor_indexed_at and cursor_did and cursor_rkey: + conditions.append( + f"(indexed_at, did, rkey) < (${idx}, ${idx + 1}, ${idx + 2})" + ) + params.extend([datetime.fromisoformat(cursor_indexed_at), cursor_did, cursor_rkey]) + idx += 3 + + where = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(limit) + + async with pool.acquire() as conn: + return await conn.fetch( + f"SELECT * FROM index_providers {where} ORDER BY indexed_at DESC, did DESC, rkey DESC LIMIT ${idx}", + *params, + ) + + async def query_search_datasets( pool: asyncpg.Pool, q: str, @@ -544,13 +647,17 @@ async def query_search_lenses( ) +async def _fetch_record_counts(conn: asyncpg.Connection) -> dict[str, int]: + counts = {} + for collection, table in COLLECTION_TABLE_MAP.items(): + row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608 + counts[collection] = row["cnt"] + return counts + + async def query_record_counts(pool: asyncpg.Pool) -> dict[str, int]: async with pool.acquire() as conn: - counts = {} - for collection, table in COLLECTION_TABLE_MAP.items(): - row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608 - counts[collection] = row["cnt"] - return counts + return await _fetch_record_counts(conn) async def query_labels_for_dataset( @@ -607,6 +714,9 @@ async def record_analytics_event( logger.warning("Failed to record analytics event %s", event_type, exc_info=True) +_background_tasks: set[asyncio.Task] = set() + + def fire_analytics_event( pool: asyncpg.Pool, event_type: str, @@ -615,9 +725,11 @@ def fire_analytics_event( query_params: dict[str, Any] | None = None, ) -> None: """Fire-and-forget analytics recording. 
Does not block the caller.""" - asyncio.create_task( + task = asyncio.create_task( record_analytics_event(pool, event_type, target_did, target_rkey, query_params) ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) PERIOD_INTERVALS: dict[str, timedelta] = { @@ -692,11 +804,7 @@ async def query_analytics_summary( {"term": r["term"], "count": r["count"]} for r in term_rows ] - # Record counts - counts = {} - for collection, table in COLLECTION_TABLE_MAP.items(): - c = await conn.fetchrow(f"SELECT COUNT(*) AS cnt FROM {table}") # noqa: S608 - counts[collection] = c["cnt"] + counts = await _fetch_record_counts(conn) return { "totalViews": total_views, @@ -720,7 +828,10 @@ async def query_entry_stats( """ SELECT COUNT(*) FILTER (WHERE event_type = 'view_entry') AS views, - COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances + COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances, + COUNT(*) FILTER (WHERE event_type = 'download') AS downloads, + COUNT(*) FILTER (WHERE event_type = 'citation') AS citations, + COUNT(*) FILTER (WHERE event_type = 'derivative') AS derivatives FROM analytics_events WHERE target_did = $1 AND target_rkey = $2 AND created_at >= NOW() - $3::interval @@ -732,6 +843,9 @@ async def query_entry_stats( return { "views": row["views"], "searchAppearances": row["search_appearances"], + "downloads": row["downloads"], + "citations": row["citations"], + "derivatives": row["derivatives"], "period": period, } @@ -750,6 +864,8 @@ async def query_active_publishers(pool: asyncpg.Pool, days: int = 30) -> int: SELECT did FROM labels WHERE indexed_at >= NOW() - $1::interval UNION SELECT did FROM lenses WHERE indexed_at >= NOW() - $1::interval + UNION + SELECT did FROM index_providers WHERE indexed_at >= NOW() - $1::interval ) sub """, interval, diff --git a/src/atdata_app/frontend/routes.py b/src/atdata_app/frontend/routes.py index 9a815e4..dac49f2 100644 --- a/src/atdata_app/frontend/routes.py +++ b/src/atdata_app/frontend/routes.py @@ -22,6 +22,7 @@ ) from atdata_app.models import ( maybe_cursor, + parse_at_uri, parse_cursor, row_to_entry, row_to_label, @@ -100,10 +101,18 @@ async def dataset_detail(request: Request, did: str, rkey: str): schema_did = "" schema_rkey = "" schema_ref = entry.get("schemaRef", "") - if schema_ref.startswith("at://"): - parts = schema_ref[5:].split("/", 2) - if len(parts) == 3: - schema_did, _, schema_rkey = parts + if schema_ref: + try: + schema_did, _, schema_rkey = parse_at_uri(schema_ref) + except ValueError: + pass + + # Fetch the referenced schema for inline display of format/annotation info + schema_info = None + if schema_did and schema_rkey: + schema_row = await query_get_schema(pool, schema_did, schema_rkey) + if schema_row: + schema_info = row_to_schema(schema_row) # Fetch labels pointing to this dataset dataset_uri = entry["uri"] @@ -117,6 +126,7 @@ async def dataset_detail(request: Request, did: str, rkey: str): "entry": entry, "schema_did": schema_did, "schema_rkey": schema_rkey, + "schema_info": schema_info, "labels": labels, }, ) diff --git a/src/atdata_app/frontend/templates/dataset.html b/src/atdata_app/frontend/templates/dataset.html index 4b32fed..ddd4760 100644 --- a/src/atdata_app/frontend/templates/dataset.html +++ b/src/atdata_app/frontend/templates/dataset.html @@ -27,6 +27,20 @@

 <h2>Details</h2>
 <table>
   <tr><th>License</th><td>{{ entry.license }}</td></tr>
   {% endif %}
   <tr><th>Schema</th><td>{{ entry.schemaRef }}</td></tr>
+  {% if schema_info %}
+  {% if schema_info.arrayFormat is defined %}
+  <tr><th>Array Format</th><td>{{ schema_info.get("arrayFormatLabel", schema_info.arrayFormat) }}</td></tr>
+  {% endif %}
+  {% if schema_info.dtype is defined %}
+  <tr><th>Data Type</th><td>{{ schema_info.dtype }}</td></tr>
+  {% endif %}
+  {% if schema_info.shape is defined and schema_info.shape is iterable and schema_info.shape is not string %}
+  <tr><th>Shape</th><td>{{ schema_info.shape | join(" × ") }}</td></tr>
+  {% endif %}
+  {% if schema_info.dimensionNames is defined and schema_info.dimensionNames is iterable and schema_info.dimensionNames is not string %}
+  <tr><th>Dimensions</th><td>{{ schema_info.dimensionNames | join(", ") }}</td></tr>
+  {% endif %}
+  {% endif %}
   {% if entry.size %}
   <tr><th>Size</th>
@@ -49,7 +63,11 @@
 <h2>Storage</h2>
 <table>
   <tr><th>Type</th><td>{{ entry.storage.get("$type", "unknown") }}</td></tr>
   {% if entry.storage.url is defined %}
+  {% if entry.storage.url.startswith("https://") or entry.storage.url.startswith("http://") %}
   <tr><th>URL</th><td><a href="{{ entry.storage.url }}">{{ entry.storage.url }}</a></td></tr>
+  {% else %}
+  <tr><th>URL</th><td>{{ entry.storage.url }}</td></tr>
+  {% endif %}
   {% endif %}
diff --git a/src/atdata_app/frontend/templates/profile.html b/src/atdata_app/frontend/templates/profile.html
index 46026b4..04c4e29 100644
--- a/src/atdata_app/frontend/templates/profile.html
+++ b/src/atdata_app/frontend/templates/profile.html
@@ -34,13 +34,14 @@
 <h2>{{ entry.name }}</h2>
 <h2>Schemas</h2>
 {% if schemas %}
-
+
 {% for s in schemas %}
+
 {% endfor %}
diff --git a/src/atdata_app/frontend/templates/schema.html b/src/atdata_app/frontend/templates/schema.html
index 281d9f6..3df1859 100644
--- a/src/atdata_app/frontend/templates/schema.html
+++ b/src/atdata_app/frontend/templates/schema.html
@@ -16,6 +16,18 @@
 <h2>Details</h2>
 <table>
   <tr><th>AT-URI</th><td>{{ schema.uri }}</td></tr>
   <tr><th>Type</th><td>{{ schema.schemaType }}</td></tr>
+  {% if schema.arrayFormat is defined %}
+  <tr><th>Array Format</th><td>{{ schema.get("arrayFormatLabel", schema.arrayFormat) }}</td></tr>
+  {% endif %}
+  {% if schema.dtype is defined %}
+  <tr><th>Data Type</th><td>{{ schema.dtype }}</td></tr>
+  {% endif %}
+  {% if schema.shape is defined and schema.shape is iterable and schema.shape is not string %}
+  <tr><th>Shape</th><td>{{ schema.shape | join(" × ") }}</td></tr>
+  {% endif %}
+  {% if schema.dimensionNames is defined and schema.dimensionNames is iterable and schema.dimensionNames is not string %}
+  <tr><th>Dimensions</th><td>{{ schema.dimensionNames | join(", ") }}</td></tr>
+  {% endif %}
   <tr><th>Version</th><td>{{ schema.version }}</td></tr>
   <tr><th>Created</th><td>{{ schema.createdAt }}</td></tr>
diff --git a/src/atdata_app/frontend/templates/schemas.html b/src/atdata_app/frontend/templates/schemas.html
index 268af39..fdf3565 100644
--- a/src/atdata_app/frontend/templates/schemas.html
+++ b/src/atdata_app/frontend/templates/schemas.html
@@ -7,7 +7,7 @@
 <h1>Schemas</h1>
 {% if schemas %}
 <table>
-  <tr><th>Name</th><th>Version</th><th>Type</th></tr>
+  <tr><th>Name</th><th>Version</th><th>Type</th><th>Format</th></tr>
 {% for s in schemas %}
@@ -15,6 +15,7 @@
   <td>{{ s.name }}</td>
   <td>{{ s.version }}</td>
   <td>{{ s.schemaType }}</td>
+  <td>{{ s.get("arrayFormatLabel", "") }}</td>
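The `is iterable and is not string` guards in these templates are load-bearing: firehose records are arbitrary publisher-supplied JSON, so `shape` can arrive as a scalar or a bare string, and piping either into `join` would raise at render time. A standalone sketch of the behavior the guards rely on, using Jinja2 directly; the template string is illustrative, not the app's actual markup:

```python
# Hypothetical demo of the template guard; requires only jinja2.
from jinja2 import Environment

row = Environment().from_string(
    "{% if schema.shape is defined and schema.shape is iterable "
    "and schema.shape is not string %}"
    "Shape: {{ schema.shape | join(' × ') }}"
    "{% else %}Shape: n/a{% endif %}"
)

print(row.render(schema={"shape": [64, 64, 3]}))  # Shape: 64 × 64 × 3
print(row.render(schema={"shape": "64x64"}))      # n/a: strings are iterable, hence the extra test
print(row.render(schema={"shape": 42}))           # n/a: ints are not iterable
```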
+ diff --git a/src/atdata_app/ingestion/jetstream.py b/src/atdata_app/ingestion/jetstream.py index 9b765d2..f09711b 100644 --- a/src/atdata_app/ingestion/jetstream.py +++ b/src/atdata_app/ingestion/jetstream.py @@ -52,7 +52,9 @@ async def jetstream_consumer(app: FastAPI) -> None: if event.get("kind") != "commit": continue - await process_commit(pool, event) + await process_commit( + pool, event, getattr(app.state, "change_stream", None) + ) last_time_us = event.get("time_us") msg_count += 1 diff --git a/src/atdata_app/ingestion/processor.py b/src/atdata_app/ingestion/processor.py index fa34c96..aab3f2b 100644 --- a/src/atdata_app/ingestion/processor.py +++ b/src/atdata_app/ingestion/processor.py @@ -8,11 +8,16 @@ import asyncpg from atdata_app import database as db +from atdata_app.changestream import ChangeStream, make_change_event logger = logging.getLogger(__name__) -async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: +async def process_commit( + pool: asyncpg.Pool, + event: dict[str, Any], + change_stream: ChangeStream | None = None, +) -> None: """Process a Jetstream commit event. Expected event format:: @@ -44,18 +49,31 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None: if operation == "delete": await db.delete_record(pool, table, did, rkey) logger.debug("Deleted %s %s/%s", collection, did, rkey) + if change_stream is not None: + change_stream.publish( + make_change_event( + event_type="delete", + collection=collection, + did=did, + rkey=rkey, + ) + ) elif operation in ("create", "update"): record = commit["record"] cid = commit.get("cid") try: - if table == "schemas": - await db.upsert_schema(pool, did, rkey, cid, record) - elif table == "entries": - await db.upsert_entry(pool, did, rkey, cid, record) - elif table == "labels": - await db.upsert_label(pool, did, rkey, cid, record) - elif table == "lenses": - await db.upsert_lens(pool, did, rkey, cid, record) + await db.UPSERT_FNS[table](pool, did, rkey, cid, record) logger.debug("Upserted %s %s/%s", collection, did, rkey) + if change_stream is not None: + change_stream.publish( + make_change_event( + event_type=operation, + collection=collection, + did=did, + rkey=rkey, + record=record, + cid=cid, + ) + ) except Exception: logger.exception("Failed to upsert %s %s/%s", collection, did, rkey) diff --git a/src/atdata_app/main.py b/src/atdata_app/main.py index b50bba3..087fde6 100644 --- a/src/atdata_app/main.py +++ b/src/atdata_app/main.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles +from atdata_app.changestream import ChangeStream from atdata_app.config import AppConfig from atdata_app.database import create_pool, run_migrations from atdata_app.frontend import router as frontend_router @@ -30,6 +31,9 @@ async def lifespan(app: FastAPI): config: AppConfig = app.state.config logger.info("Starting atdata-app (DID: %s)", config.service_did) + # Change stream (must be created before background tasks) + app.state.change_stream = ChangeStream() + # Database pool = await create_pool(config.database_url) app.state.db_pool = pool diff --git a/src/atdata_app/mcp_server.py b/src/atdata_app/mcp_server.py index 53991eb..a5d2e69 100644 --- a/src/atdata_app/mcp_server.py +++ b/src/atdata_app/mcp_server.py @@ -57,7 +57,10 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[ServerContext]: "ATProto AppView for the science.alt.dataset namespace. 
" "Use these tools to discover and query scientific datasets, " "schemas, and lenses (bidirectional schema transforms) published " - "on the AT Protocol network." + "on the AT Protocol network. " + "Schemas may specify an arrayFormat (numpyBytes, parquetBytes, " + "sparseBytes, structuredBytes, arrowTensor, safetensors) and " + "ndarray annotations (dtype, shape, dimensionNames)." ), lifespan=server_lifespan, ) @@ -127,7 +130,8 @@ async def get_schema(ctx: Ctx, uri: str) -> dict[str, Any]: uri: AT-URI of the schema (e.g. at://did:plc:abc/science.alt.dataset.schema/my.schema@1.0.0). Returns: - Full schema record including name, version, type, schema body, and description. + Full schema record including name, version, type, schema body, description, + and (when present) arrayFormat, dtype, shape, and dimensionNames. """ sc = _get_ctx(ctx) did, _collection, rkey = parse_at_uri(uri) diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py index 1cde2f2..8793590 100644 --- a/src/atdata_app/models.py +++ b/src/atdata_app/models.py @@ -9,6 +9,32 @@ from pydantic import BaseModel +# --------------------------------------------------------------------------- +# Known array format tokens (atdata-lexicon#21) +# --------------------------------------------------------------------------- + +KNOWN_ARRAY_FORMATS: set[str] = { + # Original formats + "numpyBytes", + "parquetBytes", + # New formats + "sparseBytes", + "structuredBytes", + "arrowTensor", + "safetensors", +} + +#: Human-friendly display names for array format tokens. +ARRAY_FORMAT_LABELS: dict[str, str] = { + "numpyBytes": "NumPy ndarray", + "parquetBytes": "Parquet", + "sparseBytes": "Sparse matrix (CSR/CSC/COO)", + "structuredBytes": "NumPy structured array", + "arrowTensor": "Arrow tensor IPC", + "safetensors": "Safetensors", +} + + # --------------------------------------------------------------------------- # AT-URI parsing # --------------------------------------------------------------------------- @@ -119,6 +145,19 @@ def row_to_schema(row) -> dict[str, Any]: } if row["description"]: d["description"] = row["description"] + + # Surface array format and ndarray v1.1.0 annotation fields for display + array_format = schema_body.get("arrayFormat") + if array_format: + d["arrayFormat"] = array_format + d["arrayFormatLabel"] = ARRAY_FORMAT_LABELS.get(array_format, array_format) + if schema_body.get("dtype"): + d["dtype"] = schema_body["dtype"] + if schema_body.get("shape"): + d["shape"] = schema_body["shape"] + if schema_body.get("dimensionNames"): + d["dimensionNames"] = schema_body["dimensionNames"] + return d @@ -127,6 +166,7 @@ def row_to_label(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + "did": row["did"], "name": row["name"], "datasetUri": row["dataset_uri"], "createdAt": row["created_at"], @@ -138,6 +178,21 @@ def row_to_label(row) -> dict[str, Any]: return d +def row_to_index_provider(row) -> dict[str, Any]: + uri = make_at_uri(row["did"], "science.alt.dataset.index", row["rkey"]) + d: dict[str, Any] = { + "uri": uri, + "cid": row["cid"], + "did": row["did"], + "name": row["name"], + "endpointUrl": row["endpoint_url"], + "createdAt": row["created_at"], + } + if row["description"]: + d["description"] = row["description"] + return d + + def row_to_lens(row) -> dict[str, Any]: uri = make_at_uri(row["did"], "science.alt.dataset.lens", row["rkey"]) getter_code = row["getter_code"] @@ -150,6 +205,7 @@ def row_to_lens(row) -> dict[str, Any]: d: dict[str, Any] = { "uri": uri, "cid": row["cid"], + 
"did": row["did"], "name": row["name"], "sourceSchema": row["source_schema"], "targetSchema": row["target_schema"], @@ -236,4 +292,22 @@ class GetAnalyticsResponse(BaseModel): class GetEntryStatsResponse(BaseModel): views: int searchAppearances: int + downloads: int = 0 + citations: int = 0 + derivatives: int = 0 period: str + + +class ListIndexesResponse(BaseModel): + indexes: list[dict[str, Any]] + cursor: str | None = None + + +class IndexSkeletonResponse(BaseModel): + items: list[dict[str, Any]] + cursor: str | None = None + + +class IndexResponse(BaseModel): + items: list[dict[str, Any]] + cursor: str | None = None diff --git a/src/atdata_app/sql/schema.sql b/src/atdata_app/sql/schema.sql index 17f1729..422b379 100644 --- a/src/atdata_app/sql/schema.sql +++ b/src/atdata_app/sql/schema.sql @@ -144,6 +144,22 @@ CREATE INDEX IF NOT EXISTS idx_analytics_events_type_created CREATE INDEX IF NOT EXISTS idx_analytics_events_target ON analytics_events (target_did, target_rkey, event_type); +-- Index providers (science.alt.dataset.index) +CREATE TABLE IF NOT EXISTS index_providers ( + did TEXT NOT NULL, + rkey TEXT NOT NULL, + cid TEXT, + name TEXT NOT NULL, + description TEXT, + endpoint_url TEXT NOT NULL, + created_at TEXT NOT NULL, + indexed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (did, rkey) +); + +CREATE INDEX IF NOT EXISTS idx_index_providers_did ON index_providers (did); +CREATE INDEX IF NOT EXISTS idx_index_providers_indexed_at ON index_providers (indexed_at DESC); + -- Pre-aggregated analytics counters (avoids expensive COUNT on events table) CREATE TABLE IF NOT EXISTS analytics_counters ( target_did TEXT NOT NULL, diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py index c066987..7a418d0 100644 --- a/src/atdata_app/xrpc/procedures.py +++ b/src/atdata_app/xrpc/procedures.py @@ -8,6 +8,7 @@ import logging from typing import Any +from urllib.parse import urlparse import httpx from fastapi import APIRouter, HTTPException, Request @@ -15,6 +16,7 @@ from atdata_app import get_resolver from atdata_app.auth import verify_service_auth from atdata_app.database import ( + fire_analytics_event, query_get_entry, query_get_schema, query_record_exists, @@ -268,3 +270,113 @@ async def publish_lens(request: Request) -> dict[str, Any]: pds, pds_token, auth.iss, "science.alt.dataset.lens", record, rkey ) return {"uri": result.get("uri"), "cid": result.get("cid")} + + +# --------------------------------------------------------------------------- +# publishIndex +# --------------------------------------------------------------------------- + + +@router.post("/science.alt.dataset.publishIndex") +async def publish_index(request: Request) -> dict[str, Any]: + auth = await verify_service_auth(request, "science.alt.dataset.publishIndex") + pds_token = _require_pds_token(request) + + body = await request.json() + record = body.get("record", {}) + rkey = body.get("rkey") + + record_type = record.get("$type", "") + if record_type and record_type != "science.alt.dataset.index": + raise HTTPException(status_code=400, detail="Invalid $type for index") + + for field in ("name", "endpointUrl", "createdAt"): + if field not in record: + raise HTTPException(status_code=400, detail=f"Missing required field: {field}") + + # Validate endpoint URL is HTTPS with no credentials or fragments + parsed = urlparse(record["endpointUrl"]) + if parsed.scheme != "https" or not parsed.hostname: + raise HTTPException( + status_code=400, detail="endpointUrl must be a valid HTTPS URL" + ) + if 
parsed.username or parsed.password: + raise HTTPException( + status_code=400, detail="endpointUrl must not contain credentials" + ) + if parsed.fragment: + raise HTTPException( + status_code=400, detail="endpointUrl must not contain a fragment" + ) + + record["$type"] = "science.alt.dataset.index" + + pds = await _resolve_pds(auth.iss) + result = await _proxy_create_record( + pds, pds_token, auth.iss, "science.alt.dataset.index", record, rkey + ) + return {"uri": result.get("uri"), "cid": result.get("cid")} + + +# --------------------------------------------------------------------------- +# sendInteractions +# --------------------------------------------------------------------------- + +_VALID_INTERACTION_TYPES = frozenset({"download", "citation", "derivative"}) +_MAX_INTERACTIONS_BATCH = 100 + + +@router.post("/science.alt.dataset.sendInteractions") +async def send_interactions(request: Request) -> dict[str, Any]: + """Record dataset interaction events (anonymous, fire-and-forget).""" + await verify_service_auth(request, "science.alt.dataset.sendInteractions") + + pool = request.app.state.db_pool + + body = await request.json() + interactions = body.get("interactions") + if not isinstance(interactions, list): + raise HTTPException(status_code=400, detail="interactions must be an array") + + if len(interactions) > _MAX_INTERACTIONS_BATCH: + raise HTTPException( + status_code=400, + detail=f"Batch size exceeds maximum of {_MAX_INTERACTIONS_BATCH}", + ) + + for i, item in enumerate(interactions): + if not isinstance(item, dict): + raise HTTPException(status_code=400, detail=f"interactions[{i}]: must be an object") + + itype = item.get("type") + if itype not in _VALID_INTERACTION_TYPES: + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: invalid type '{itype}', " + f"must be one of: {', '.join(sorted(_VALID_INTERACTION_TYPES))}", + ) + + dataset_uri = item.get("datasetUri") + if not isinstance(dataset_uri, str): + raise HTTPException( + status_code=400, detail=f"interactions[{i}]: datasetUri is required" + ) + try: + _did, _col, _rkey = parse_at_uri(dataset_uri) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: invalid AT-URI: {dataset_uri}", + ) + if _col != "science.alt.dataset.entry": + raise HTTPException( + status_code=400, + detail=f"interactions[{i}]: datasetUri must reference a dataset entry", + ) + + # All valid — fire analytics events + for item in interactions: + did, _collection, rkey = parse_at_uri(item["datasetUri"]) + fire_analytics_event(pool, item["type"], target_did=did, target_rkey=rkey) + + return {} diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py index df5d15c..ff22e1d 100644 --- a/src/atdata_app/xrpc/queries.py +++ b/src/atdata_app/xrpc/queries.py @@ -7,6 +7,12 @@ from fastapi import APIRouter, HTTPException, Query, Request +import ipaddress +import socket +from urllib.parse import urlparse + +import httpx + from atdata_app import get_resolver from atdata_app.database import ( COLLECTION_TABLE_MAP, @@ -16,8 +22,10 @@ query_entry_stats, query_get_entries, query_get_entry, + query_get_index_provider, query_get_schema, query_list_entries, + query_list_index_providers, query_list_lenses, query_list_schemas, query_record_counts, @@ -32,7 +40,10 @@ GetEntriesResponse, GetEntryResponse, GetEntryStatsResponse, + IndexResponse, + IndexSkeletonResponse, ListEntriesResponse, + ListIndexesResponse, ListLensesResponse, ListSchemasResponse, ResolveBlobsResponse, @@ -44,6 +55,7 @@ 
parse_at_uri, parse_cursor, row_to_entry, + row_to_index_provider, row_to_label, row_to_lens, row_to_schema, @@ -340,6 +352,203 @@ async def search_lenses( ) +# --------------------------------------------------------------------------- +# listIndexes +# --------------------------------------------------------------------------- + + +@router.get("/science.alt.dataset.listIndexes") +async def list_indexes( + request: Request, + repo: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), + cursor: str | None = Query(None), +) -> ListIndexesResponse: + pool = request.app.state.db_pool + c_at, c_did, c_rkey = parse_cursor(cursor) + rows = await query_list_index_providers(pool, repo, limit, c_did, c_rkey, c_at) + return ListIndexesResponse( + indexes=[row_to_index_provider(r) for r in rows], + cursor=maybe_cursor(rows, limit), + ) + + +# --------------------------------------------------------------------------- +# getIndexSkeleton +# --------------------------------------------------------------------------- + + +def _validate_endpoint_url(url: str) -> None: + """Reject URLs that could cause SSRF (private IPs, non-HTTPS, credentials). + + Raises ``HTTPException`` (400) if the URL is unsafe. + """ + try: + parsed = urlparse(url) + except ValueError: + raise HTTPException(status_code=400, detail="Malformed endpoint URL") + + if parsed.scheme != "https": + raise HTTPException(status_code=400, detail="Endpoint URL must use HTTPS") + if not parsed.hostname: + raise HTTPException(status_code=400, detail="Endpoint URL missing hostname") + if parsed.username or parsed.password: + raise HTTPException(status_code=400, detail="Endpoint URL must not contain credentials") + + # Resolve hostname and block private/reserved IP ranges + try: + infos = socket.getaddrinfo(parsed.hostname, None, proto=socket.IPPROTO_TCP) + except socket.gaierror: + raise HTTPException(status_code=502, detail="Index provider hostname unresolvable") + + for _family, _type, _proto, _canonname, sockaddr in infos: + ip = ipaddress.ip_address(sockaddr[0]) + if ip.is_private or ip.is_reserved or ip.is_loopback or ip.is_link_local: + raise HTTPException( + status_code=400, + detail="Endpoint URL must not resolve to a private/reserved IP address", + ) + + +_MAX_SKELETON_RESPONSE_BYTES = 1_048_576 # 1 MiB +_MAX_SKELETON_CURSOR_LEN = 512 + + +async def _fetch_skeleton( + endpoint_url: str, + cursor: str | None, + limit: int, +) -> dict[str, Any]: + """Fetch skeleton from an upstream index provider.""" + _validate_endpoint_url(endpoint_url) + + if cursor is not None: + if len(cursor) > _MAX_SKELETON_CURSOR_LEN or "\x00" in cursor: + raise HTTPException(status_code=400, detail="Invalid cursor value") + + params: dict[str, Any] = {"limit": limit} + if cursor: + params["cursor"] = cursor + async with httpx.AsyncClient(timeout=10.0) as http: + try: + resp = await http.get(endpoint_url, params=params) + except (httpx.HTTPError, ValueError) as e: + raise HTTPException( + status_code=502, detail="Index provider unreachable" + ) from e + if resp.status_code != 200: + raise HTTPException( + status_code=502, + detail=f"Index provider returned {resp.status_code}", + ) + # Reject oversized responses to prevent memory exhaustion + content_length = resp.headers.get("content-length") + if content_length and int(content_length) > _MAX_SKELETON_RESPONSE_BYTES: + raise HTTPException( + status_code=502, detail="Index provider response too large" + ) + if len(resp.content) > _MAX_SKELETON_RESPONSE_BYTES: + raise HTTPException( + status_code=502, 
detail="Index provider response too large" + ) + try: + data = resp.json() + except (ValueError, KeyError) as e: + raise HTTPException( + status_code=502, detail="Invalid response from index provider" + ) from e + if not isinstance(data.get("items"), list): + raise HTTPException( + status_code=502, detail="Index provider response missing 'items' array" + ) + # Cap items to the requested limit regardless of what upstream returns + data["items"] = data["items"][:limit] + # Whitelist item fields — only pass through 'uri' to prevent injection + data["items"] = [{"uri": item.get("uri", "")} for item in data["items"]] + # Validate upstream cursor + upstream_cursor = data.get("cursor") + if upstream_cursor is not None: + if ( + not isinstance(upstream_cursor, str) + or len(upstream_cursor) > _MAX_SKELETON_CURSOR_LEN + or "\x00" in upstream_cursor + ): + data["cursor"] = None + return data + + +@router.get("/science.alt.dataset.getIndexSkeleton") +async def get_index_skeleton( + request: Request, + index: str = Query(...), + cursor: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), +) -> IndexSkeletonResponse: + pool = request.app.state.db_pool + try: + did, _, rkey = parse_at_uri(index) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid AT-URI for index") + provider = await query_get_index_provider(pool, did, rkey) + if not provider: + raise HTTPException(status_code=404, detail="Index provider not found") + + data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit) + return IndexSkeletonResponse( + items=data["items"], + cursor=data.get("cursor"), + ) + + +# --------------------------------------------------------------------------- +# getIndex (hydrated) +# --------------------------------------------------------------------------- + + +@router.get("/science.alt.dataset.getIndex") +async def get_index( + request: Request, + index: str = Query(...), + cursor: str | None = Query(None), + limit: int = Query(50, ge=1, le=100), +) -> IndexResponse: + pool = request.app.state.db_pool + try: + did, _, rkey = parse_at_uri(index) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid AT-URI for index") + provider = await query_get_index_provider(pool, did, rkey) + if not provider: + raise HTTPException(status_code=404, detail="Index provider not found") + + data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit) + + # Parse URIs from skeleton items and hydrate + keys: list[tuple[str, str]] = [] + for item in data["items"]: + uri = item.get("uri", "") + try: + entry_did, _, entry_rkey = parse_at_uri(uri) + keys.append((entry_did, entry_rkey)) + except ValueError: + continue # skip malformed URIs + + rows = await query_get_entries(pool, keys) + + # Build a lookup map to preserve skeleton ordering + row_map = {(r["did"], r["rkey"]): r for r in rows} + hydrated = [] + for entry_did, entry_rkey in keys: + row = row_map.get((entry_did, entry_rkey)) + if row: + hydrated.append(row_to_entry(row)) + + return IndexResponse( + items=hydrated, + cursor=data.get("cursor"), + ) + + # --------------------------------------------------------------------------- # describeService # --------------------------------------------------------------------------- diff --git a/src/atdata_app/xrpc/router.py b/src/atdata_app/xrpc/router.py index 627b382..d929478 100644 --- a/src/atdata_app/xrpc/router.py +++ b/src/atdata_app/xrpc/router.py @@ -1,10 +1,12 @@ -"""Combined XRPC router mounting all query and procedure endpoints.""" +"""Combined 
XRPC router mounting all query, procedure, and subscription endpoints.""" from fastapi import APIRouter from atdata_app.xrpc.procedures import router as procedures_router from atdata_app.xrpc.queries import router as queries_router +from atdata_app.xrpc.subscriptions import router as subscriptions_router router = APIRouter(prefix="/xrpc") router.include_router(queries_router) router.include_router(procedures_router) +router.include_router(subscriptions_router) diff --git a/src/atdata_app/xrpc/subscriptions.py b/src/atdata_app/xrpc/subscriptions.py new file mode 100644 index 0000000..80bfff4 --- /dev/null +++ b/src/atdata_app/xrpc/subscriptions.py @@ -0,0 +1,77 @@ +"""WebSocket subscription endpoints for real-time change streaming.""" + +from __future__ import annotations + +import asyncio +import json +import logging + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +from atdata_app.changestream import ChangeStream + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.websocket("/science.alt.dataset.subscribeChanges") +async def subscribe_changes(websocket: WebSocket) -> None: + """Stream real-time change events over WebSocket. + + Query parameters: + cursor: Optional sequence number to replay from. + """ + change_stream: ChangeStream = websocket.app.state.change_stream + + await websocket.accept() + + cursor_param = websocket.query_params.get("cursor") + + try: + sub_id, queue = change_stream.subscribe() + except RuntimeError: + await websocket.close(code=1013, reason="Too many subscribers") + return + + try: + # Replay buffered events if cursor provided, tracking last seq + # to deduplicate against events that also landed in the live queue. + last_replayed_seq = 0 + if cursor_param is not None: + try: + cursor = int(cursor_param) + except (ValueError, TypeError): + await websocket.close(code=1008, reason="Invalid cursor value") + return + missed = change_stream.replay_from(cursor) + for event in missed: + await websocket.send_text(json.dumps(event.to_dict())) + last_replayed_seq = event.seq + + # Stream live events with periodic keepalive on idle + while True: + try: + event = await asyncio.wait_for(queue.get(), timeout=30.0) + except asyncio.TimeoutError: + # No events for 30s — send keepalive to detect dead connections + await websocket.send_text(json.dumps({"type": "keepalive"})) + continue + + # Deduplicate events already sent during replay + if event.seq <= last_replayed_seq: + continue + await websocket.send_text(json.dumps(event.to_dict())) + # Check if we were marked as dropped due to backpressure + if change_stream.is_dropped(sub_id): + await websocket.close( + code=4000, reason="Backpressure: events were dropped" + ) + return + + except WebSocketDisconnect: + logger.debug("Subscriber %d disconnected", sub_id) + except Exception: + logger.exception("Error in subscriber %d", sub_id) + finally: + change_stream.unsubscribe(sub_id) diff --git a/tests/test_analytics.py b/tests/test_analytics.py index 6be9707..1ed9d04 100644 --- a/tests/test_analytics.py +++ b/tests/test_analytics.py @@ -8,7 +8,6 @@ import pytest from httpx import ASGITransport, AsyncClient -from atdata_app.config import AppConfig from atdata_app.database import ( fire_analytics_event, record_analytics_event, @@ -28,17 +27,12 @@ # --------------------------------------------------------------------------- -@pytest.fixture -def config() -> AppConfig: - return AppConfig(dev_mode=True, hostname="localhost", port=8000) - - @pytest.fixture def pool() -> AsyncMock: return AsyncMock() 
-def _mock_app(config: AppConfig, pool: AsyncMock): +def _mock_app(config, pool): """Create a FastAPI app with mocked lifespan (no real DB).""" app = create_app(config) app.state.db_pool = pool @@ -374,3 +368,303 @@ def test_describe_service_response_without_analytics(): recordCount={"science.alt.dataset.entry": 10}, ) assert resp.analytics is None + + +def test_get_entry_stats_response_with_interactions(): + resp = GetEntryStatsResponse( + views=10, searchAppearances=3, downloads=5, citations=2, derivatives=1, period="week" + ) + assert resp.downloads == 5 + assert resp.citations == 2 + assert resp.derivatives == 1 + + +def test_get_entry_stats_response_defaults_interactions(): + resp = GetEntryStatsResponse(views=10, searchAppearances=3, period="week") + assert resp.downloads == 0 + assert resp.citations == 0 + assert resp.derivatives == 0 + + +# --------------------------------------------------------------------------- +# sendInteractions endpoint +# --------------------------------------------------------------------------- + +_PROC = "atdata_app.xrpc.procedures" + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_valid_batch(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + }, + { + "type": "citation", + "datasetUri": "at://did:plc:def/science.alt.dataset.entry/4abc", + "timestamp": "2025-06-01T12:00:00Z", + }, + ] + }, + ) + + assert resp.status_code == 200 + assert resp.json() == {} + assert mock_fire.call_count == 2 + mock_fire.assert_any_call(pool, "download", target_did="did:plc:abc", target_rkey="3xyz") + mock_fire.assert_any_call(pool, "citation", target_did="did:plc:def", target_rkey="4abc") + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_empty_array(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": []}, + ) + + assert resp.status_code == 200 + assert resp.json() == {} + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_invalid_uri(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + {"type": "download", "datasetUri": "https://not-an-at-uri"}, + ] + }, + ) + + assert resp.status_code == 400 + assert "invalid AT-URI" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio 
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_invalid_type(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "bookmark", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + }, + ] + }, + ) + + assert resp.status_code == 400 + assert "invalid type" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_batch_size_exceeded(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + interactions = [ + {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"} + for _ in range(101) + ] + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": interactions}, + ) + + assert resp.status_code == 400 + assert "Batch size exceeds maximum" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_ignores_timestamp(mock_fire, mock_auth, config, pool): + """Timestamp field is accepted but not validated (informational only).""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz", + "timestamp": "not-a-date", + }, + ] + }, + ) + + assert resp.status_code == 200 + assert mock_fire.call_count == 1 + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_missing_dataset_uri(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + {"type": "download"}, + ] + }, + ) + + assert resp.status_code == 400 + assert "datasetUri is required" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_not_an_array(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + 
"/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": "not-an-array"}, + ) + + assert resp.status_code == 400 + assert "interactions must be an array" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_all_three_types(mock_fire, mock_auth, config, pool): + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={ + "interactions": [ + { + "type": "download", + "datasetUri": "at://did:plc:a/science.alt.dataset.entry/1", + }, + { + "type": "citation", + "datasetUri": "at://did:plc:b/science.alt.dataset.entry/2", + }, + { + "type": "derivative", + "datasetUri": "at://did:plc:c/science.alt.dataset.entry/3", + }, + ] + }, + ) + + assert resp.status_code == 200 + assert mock_fire.call_count == 3 + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_missing_key(mock_fire, mock_auth, config, pool): + """Body without 'interactions' key should return 400.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"data": []}, + ) + + assert resp.status_code == 400 + assert "interactions must be an array" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_non_dict_item(mock_fire, mock_auth, config, pool): + """Non-object items in the interactions array should return 400.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": ["not-a-dict"]}, + ) + + assert resp.status_code == 400 + assert "must be an object" in resp.json()["detail"] + mock_fire.assert_not_called() + + +@pytest.mark.asyncio +@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock) +@patch(f"{_PROC}.fire_analytics_event") +async def test_send_interactions_boundary_at_max(mock_fire, mock_auth, config, pool): + """Exactly 100 interactions (the maximum) should succeed.""" + mock_auth.return_value = MagicMock(iss="did:plc:caller") + app = _mock_app(config, pool) + transport = ASGITransport(app=app) + interactions = [ + {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"} + for _ in range(100) + ] + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.sendInteractions", + json={"interactions": interactions}, + ) + + assert resp.status_code == 200 + assert mock_fire.call_count == 100 diff --git a/tests/test_changestream.py b/tests/test_changestream.py new file mode 100644 index 0000000..48e4318 --- /dev/null +++ b/tests/test_changestream.py @@ -0,0 
+1,442 @@ +"""Tests for the change stream event bus and WebSocket subscription endpoint.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + +from atdata_app.changestream import ChangeEvent, ChangeStream, make_change_event +from atdata_app.ingestion.processor import process_commit +from atdata_app.xrpc.subscriptions import router as subscriptions_router + + +# --------------------------------------------------------------------------- +# ChangeStream unit tests +# --------------------------------------------------------------------------- + + +class TestChangeStream: + def test_publish_assigns_monotonic_seq(self): + cs = ChangeStream() + ev1 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + ev2 = make_change_event( + event_type="update", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="def", + ) + cs.publish(ev1) + cs.publish(ev2) + assert ev1.seq == 1 + assert ev2.seq == 2 + assert cs.current_seq == 2 + + def test_publish_delivers_to_subscribers(self): + cs = ChangeStream() + sub_id, queue = cs.subscribe() + + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + + assert not queue.empty() + received = queue.get_nowait() + assert received.seq == 1 + assert received.did == "did:plc:test" + + def test_multiple_subscribers_receive_events(self): + cs = ChangeStream() + _, q1 = cs.subscribe() + _, q2 = cs.subscribe() + + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + + assert not q1.empty() + assert not q2.empty() + assert q1.get_nowait().seq == 1 + assert q2.get_nowait().seq == 1 + + def test_unsubscribe_removes_subscriber(self): + cs = ChangeStream() + sub_id, queue = cs.subscribe() + assert cs.subscriber_count == 1 + + cs.unsubscribe(sub_id) + assert cs.subscriber_count == 0 + + # Publishing after unsubscribe should not deliver + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + cs.publish(ev) + assert queue.empty() + + def test_full_queue_drops_event(self): + cs = ChangeStream(subscriber_queue_size=1) + _, queue = cs.subscribe() + + ev1 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="a", + ) + ev2 = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="b", + ) + cs.publish(ev1) + cs.publish(ev2) # Should be dropped (queue full) + + assert queue.qsize() == 1 + assert queue.get_nowait().rkey == "a" + + def test_replay_from_cursor(self): + cs = ChangeStream(buffer_size=10) + for i in range(5): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + # Replay from seq 3 — should get events 4 and 5 + replayed = cs.replay_from(3) + assert len(replayed) == 2 + assert replayed[0].seq == 4 + assert replayed[1].seq == 5 + + def test_replay_from_zero_returns_all(self): + cs = ChangeStream(buffer_size=10) + for i in range(3): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + replayed = 
cs.replay_from(0) + assert len(replayed) == 3 + + def test_replay_cursor_too_old(self): + cs = ChangeStream(buffer_size=3) + for i in range(5): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + # Buffer only holds seq 3, 4, 5 — cursor 1 is too old + replayed = cs.replay_from(1) + assert len(replayed) == 0 + + def test_replay_empty_buffer(self): + cs = ChangeStream() + assert cs.replay_from(0) == [] + + def test_bounded_buffer(self): + cs = ChangeStream(buffer_size=3) + for i in range(10): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + cs.publish(ev) + + assert len(cs._buffer) == 3 + assert cs._buffer[0].seq == 8 + assert cs._buffer[-1].seq == 10 + + +class TestChangeEvent: + def test_to_dict_create(self): + ev = ChangeEvent( + seq=1, + type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + timestamp="2026-01-01T00:00:00Z", + record={"name": "test"}, + cid="bafytest", + ) + d = ev.to_dict() + assert d["seq"] == 1 + assert d["type"] == "create" + assert d["record"] == {"name": "test"} + assert d["cid"] == "bafytest" + + def test_to_dict_delete_omits_record_and_cid(self): + ev = ChangeEvent( + seq=2, + type="delete", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + timestamp="2026-01-01T00:00:00Z", + ) + d = ev.to_dict() + assert "record" not in d + assert "cid" not in d + + +class TestMakeChangeEvent: + def test_creates_event_with_timestamp(self): + ev = make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + record={"name": "test"}, + cid="bafytest", + ) + assert ev.seq == 0 # Not yet assigned + assert ev.type == "create" + assert ev.timestamp # Should have a timestamp + + +# --------------------------------------------------------------------------- +# Processor integration tests +# --------------------------------------------------------------------------- + +_DB = "atdata_app.database" + + +def _make_event( + did: str = "did:plc:test123", + collection: str = "science.alt.dataset.entry", + operation: str = "create", + rkey: str = "3xyz", + record: dict | None = None, + cid: str = "bafytest", +) -> dict: + commit: dict = { + "rev": "rev1", + "operation": operation, + "collection": collection, + "rkey": rkey, + } + if operation != "delete": + commit["record"] = record or { + "$type": collection, + "name": "test-dataset", + "schemaRef": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0", + "storage": {"$type": "science.alt.dataset.storageHttp", "shards": []}, + "createdAt": "2025-01-01T00:00:00Z", + } + commit["cid"] = cid + + return { + "did": did, + "time_us": 1725911162329308, + "kind": "commit", + "commit": commit, + } + + +@pytest.mark.asyncio +async def test_processor_publishes_create_event(): + mock_upsert = AsyncMock() + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="create") + with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}): + await process_commit(pool, event, change_stream=cs) + + assert not queue.empty() + change_event = queue.get_nowait() + assert change_event.type == "create" + assert change_event.collection == "science.alt.dataset.entry" + assert change_event.did == "did:plc:test123" + assert change_event.rkey == "3xyz" + assert change_event.record is not None + assert 
change_event.cid == "bafytest" + + +@pytest.mark.asyncio +@patch(f"{_DB}.delete_record", new_callable=AsyncMock) +async def test_processor_publishes_delete_event(mock_delete): + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="delete") + await process_commit(pool, event, change_stream=cs) + + assert not queue.empty() + change_event = queue.get_nowait() + assert change_event.type == "delete" + assert change_event.collection == "science.alt.dataset.entry" + assert change_event.record is None + assert change_event.cid is None + + +@pytest.mark.asyncio +async def test_processor_no_event_on_upsert_failure(): + mock_upsert = AsyncMock(side_effect=Exception("db error")) + pool = AsyncMock() + cs = ChangeStream() + _, queue = cs.subscribe() + + event = _make_event(operation="create") + with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}): + await process_commit(pool, event, change_stream=cs) + + assert queue.empty() + + +@pytest.mark.asyncio +async def test_processor_works_without_change_stream(): + """Backward compat: process_commit works when change_stream is None.""" + mock_upsert = AsyncMock() + pool = AsyncMock() + event = _make_event(operation="create") + with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}): + await process_commit(pool, event) + mock_upsert.assert_called_once() + + +# --------------------------------------------------------------------------- +# WebSocket endpoint tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def app_with_changestream(): + """Minimal app with just the subscriptions router — no DB lifespan.""" + app = FastAPI() + app.state.change_stream = ChangeStream() + app.include_router(subscriptions_router, prefix="/xrpc") + return app + + +def test_websocket_subscribe_and_receive(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws: + # Publish an event from another "thread" + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + record={"name": "test"}, + cid="bafytest", + ) + ) + + data = ws.receive_json() + assert data["seq"] == 1 + assert data["type"] == "create" + assert data["collection"] == "science.alt.dataset.entry" + assert data["did"] == "did:plc:test" + assert data["record"] == {"name": "test"} + + +def test_websocket_cursor_replay(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + # Pre-populate buffer + for i in range(5): + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey=str(i), + ) + ) + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges?cursor=3" + ) as ws: + # Should replay events 4 and 5 + msg1 = ws.receive_json() + assert msg1["seq"] == 4 + + msg2 = ws.receive_json() + assert msg2["seq"] == 5 + + +def test_websocket_disconnect_cleanup(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ): + assert cs.subscriber_count == 1 + + # After disconnect, subscriber should be cleaned up + assert cs.subscriber_count == 0 + + 
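The unit tests above pin down the `ChangeStream` contract precisely: sequence numbers are assigned monotonically at publish time, each subscriber gets a bounded queue that drops events when full, replay is served from a bounded buffer, and a cursor older than the buffer yields nothing. Below is a minimal sketch that satisfies those assertions. The public surface (`publish`, `subscribe`, `unsubscribe`, `replay_from`, `current_seq`, `subscriber_count`) mirrors what the tests call; the internals (an `asyncio.Queue` per subscriber, a `deque` replay buffer) are assumptions, not necessarily how `atdata_app.changestream` is actually written.

```python
# Minimal ChangeStream sketch satisfying the assertions above. The public
# surface is pinned down by the tests; asyncio.Queue and deque are assumed.
import asyncio
import itertools
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class ChangeEvent:
    seq: int
    type: str
    collection: str
    did: str
    rkey: str
    timestamp: str
    record: dict | None = None
    cid: str | None = None

    def to_dict(self) -> dict:
        d = {
            "seq": self.seq, "type": self.type, "collection": self.collection,
            "did": self.did, "rkey": self.rkey, "timestamp": self.timestamp,
        }
        if self.record is not None:  # deletes omit record/cid entirely
            d["record"] = self.record
        if self.cid is not None:
            d["cid"] = self.cid
        return d


def make_change_event(*, event_type, collection, did, rkey, record=None, cid=None):
    # seq=0 here; ChangeStream.publish() assigns the real sequence number.
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    return ChangeEvent(0, event_type, collection, did, rkey, ts, record, cid)


class ChangeStream:
    def __init__(self, buffer_size: int = 1024, subscriber_queue_size: int = 256):
        self.current_seq = 0
        self._buffer: deque = deque(maxlen=buffer_size)
        self._subscribers: dict[int, asyncio.Queue] = {}
        self._ids = itertools.count(1)
        self._queue_size = subscriber_queue_size

    @property
    def subscriber_count(self) -> int:
        return len(self._subscribers)

    def subscribe(self) -> tuple[int, asyncio.Queue]:
        sub_id = next(self._ids)
        q: asyncio.Queue = asyncio.Queue(maxsize=self._queue_size)
        self._subscribers[sub_id] = q
        return sub_id, q

    def unsubscribe(self, sub_id: int) -> None:
        self._subscribers.pop(sub_id, None)

    def publish(self, event: ChangeEvent) -> None:
        self.current_seq += 1
        event.seq = self.current_seq
        self._buffer.append(event)
        for q in self._subscribers.values():
            try:
                q.put_nowait(event)
            except asyncio.QueueFull:
                pass  # slow subscriber: drop; the WS layer disconnects laggards

    def replay_from(self, cursor: int) -> list:
        if not self._buffer:
            return []
        if cursor < self._buffer[0].seq - 1:
            return []  # cursor predates the buffer: cannot replay faithfully
        return [ev for ev in self._buffer if ev.seq > cursor]
```

Note the too-old cursor check: replaying from a cursor that predates the oldest buffered event returns an empty list rather than a partial replay with a gap in it, which is exactly what `test_replay_cursor_too_old` demands.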
+def test_websocket_multiple_subscribers(app_with_changestream): + app = app_with_changestream + cs: ChangeStream = app.state.change_stream + + with TestClient(app) as client: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws1: + with client.websocket_connect( + "/xrpc/science.alt.dataset.subscribeChanges" + ) as ws2: + assert cs.subscriber_count == 2 + + cs.publish( + make_change_event( + event_type="create", + collection="science.alt.dataset.entry", + did="did:plc:test", + rkey="abc", + ) + ) + + d1 = ws1.receive_json() + d2 = ws2.receive_json() + assert d1["seq"] == d2["seq"] == 1 + + assert cs.subscriber_count == 0 diff --git a/tests/test_frontend.py b/tests/test_frontend.py index 43f714f..7e3d71e 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -41,6 +41,7 @@ def _make_schema_row( did: str = "did:plc:test123", rkey: str = "test@1.0.0", name: str = "TestSchema", + schema_body: str | dict = '{"type": "object"}', ) -> dict: return { "did": did, @@ -49,7 +50,7 @@ def _make_schema_row( "name": name, "version": "1.0.0", "schema_type": "jsonSchema", - "schema_body": '{"type": "object"}', + "schema_body": schema_body, "description": "A test schema", "metadata": None, "created_at": "2025-01-01T00:00:00Z", @@ -140,10 +141,12 @@ async def test_home_search(mock_search): @pytest.mark.asyncio @patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) @patch("atdata_app.frontend.routes.query_get_entry", new_callable=AsyncMock) -async def test_dataset_detail(mock_get, mock_labels): +async def test_dataset_detail(mock_get, mock_schema, mock_labels): pool, _conn = _mock_pool() mock_get.return_value = _make_entry_row() + mock_schema.return_value = _make_schema_row() mock_labels.return_value = [_make_label_row()] app = _make_app(pool) transport = ASGITransport(app=app) @@ -260,6 +263,96 @@ async def test_about(mock_counts): assert "did:web:localhost%3A8000" in resp.text +# --------------------------------------------------------------------------- +# Schema detail — array format & ndarray annotations +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +async def test_schema_detail_array_format(mock_get): + pool, _conn = _mock_pool() + mock_get.return_value = _make_schema_row( + schema_body={ + "arrayFormat": "sparseBytes", + "dtype": "float32", + "shape": [100, 200], + "dimensionNames": ["samples", "features"], + }, + ) + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schema/did:plc:test123/test@1.0.0") + assert resp.status_code == 200 + assert "Sparse matrix" in resp.text + assert "float32" in resp.text + assert "100" in resp.text + assert "samples" in resp.text + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +async def test_schema_detail_no_array_format(mock_get): + """Plain schemas should not show array format rows.""" + pool, _conn = _mock_pool() + mock_get.return_value = _make_schema_row() + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schema/did:plc:test123/test@1.0.0") + assert resp.status_code == 200 + assert 
"Array Format" not in resp.text + assert "Data Type" not in resp.text + + +# --------------------------------------------------------------------------- +# Dataset detail — schema format info +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock) +@patch("atdata_app.frontend.routes.query_get_entry", new_callable=AsyncMock) +async def test_dataset_detail_with_schema_format(mock_entry, mock_schema, mock_labels): + pool, _conn = _mock_pool() + mock_entry.return_value = _make_entry_row() + mock_schema.return_value = _make_schema_row( + did="did:plc:test", + rkey="test@1.0.0", + schema_body={"arrayFormat": "numpyBytes", "dtype": "float64"}, + ) + mock_labels.return_value = [] + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/dataset/did:plc:test123/3xyz") + assert resp.status_code == 200 + assert "NumPy ndarray" in resp.text + assert "float64" in resp.text + + +# --------------------------------------------------------------------------- +# Schemas list — format column +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.frontend.routes.query_list_schemas", new_callable=AsyncMock) +async def test_schemas_list_shows_format(mock_list): + pool, _conn = _mock_pool() + mock_list.return_value = [ + _make_schema_row(schema_body={"arrayFormat": "safetensors"}), + ] + app = _make_app(pool) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/schemas") + assert resp.status_code == 200 + assert "Safetensors" in resp.text + + # --------------------------------------------------------------------------- # Static files # --------------------------------------------------------------------------- diff --git a/tests/test_index.py b/tests/test_index.py new file mode 100644 index 0000000..15c84a8 --- /dev/null +++ b/tests/test_index.py @@ -0,0 +1,541 @@ +"""Tests for index provider endpoints (skeleton/hydration pattern).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from httpx import ASGITransport, AsyncClient + +from atdata_app.config import AppConfig +from atdata_app.ingestion.processor import process_commit +from atdata_app.main import create_app + +_DB = "atdata_app.database" +_QUERIES = "atdata_app.xrpc.queries" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_app() -> tuple: + config = AppConfig(dev_mode=True, hostname="localhost", port=8000) + app = create_app(config) + pool = AsyncMock() + app.state.db_pool = pool + return app, pool + + +def _index_provider_row( + did: str = "did:plc:provider1", + rkey: str = "3abc", + endpoint_url: str = "https://example.com/skeleton", + name: str = "Genomics Index", + description: str = "Curated genomics datasets", +) -> dict: + """Simulate an asyncpg Record as a dict-like object.""" + return MagicMock( + **{ + "__getitem__": lambda self, key: { + "did": did, + "rkey": rkey, + "cid": "bafyindex", + "name": name, + "description": description, + 
"endpoint_url": endpoint_url, + "created_at": "2025-01-01T00:00:00Z", + "indexed_at": "2025-01-01T00:00:00+00:00", + }[key], + } + ) + + +def _entry_row(did: str = "did:plc:author1", rkey: str = "3xyz") -> MagicMock: + return MagicMock( + **{ + "__getitem__": lambda self, key: { + "did": did, + "rkey": rkey, + "cid": "bafyentry", + "name": "test-dataset", + "schema_ref": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0", + "storage": '{"$type": "science.alt.dataset.storageHttp", "shards": []}', + "description": None, + "tags": None, + "license": None, + "size_samples": None, + "size_bytes": None, + "size_shards": None, + "created_at": "2025-01-01T00:00:00Z", + "indexed_at": "2025-01-01T00:00:00+00:00", + }[key], + } + ) + + +# --------------------------------------------------------------------------- +# getIndexSkeleton +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_success(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + skeleton_response = { + "items": [{"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}], + "cursor": "next123", + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["uri"] == "at://did:plc:a/science.alt.dataset.entry/3xyz" + assert data["cursor"] == "next123" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_not_found(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = None + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_index_skeleton_invalid_uri(): + app, pool = _make_app() + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "not-a-uri"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_upstream_error(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + mock_resp = MagicMock() + mock_resp.status_code = 500 + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = 
mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_upstream_unreachable(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection refused") + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_skeleton_invalid_response(mock_get_provider): + """Upstream returns JSON without 'items' array.""" + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"bad": "data"} + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndexSkeleton", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# getIndex (hydrated) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock) +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_hydrated(mock_get_provider, mock_get_entries): + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")] + + skeleton_response = { + "items": [ + {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}, + ], + "cursor": "next456", + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = 
mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["name"] == "test-dataset" + assert data["cursor"] == "next456" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock) +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_omits_missing_entries(mock_get_provider, mock_get_entries): + """Entries not in the DB should be silently omitted.""" + app, pool = _make_app() + mock_get_provider.return_value = _index_provider_row() + # Return only one of two requested entries + mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")] + + skeleton_response = { + "items": [ + {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}, + {"uri": "at://did:plc:b/science.alt.dataset.entry/3deleted"}, + ], + } + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = skeleton_response + + with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_resp + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock) +async def test_get_index_not_found(mock_get_provider): + app, pool = _make_app() + mock_get_provider.return_value = None + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.getIndex", + params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"}, + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# listIndexes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes(mock_list): + app, pool = _make_app() + mock_list.return_value = [_index_provider_row()] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/xrpc/science.alt.dataset.listIndexes") + assert resp.status_code == 200 + data = resp.json() + assert len(data["indexes"]) == 1 + assert data["indexes"][0]["name"] == "Genomics Index" + assert data["indexes"][0]["endpointUrl"] == "https://example.com/skeleton" + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes_empty(mock_list): + app, pool = _make_app() + mock_list.return_value = [] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: 
+ resp = await client.get("/xrpc/science.alt.dataset.listIndexes") + assert resp.status_code == 200 + data = resp.json() + assert data["indexes"] == [] + assert data["cursor"] is None + + +@pytest.mark.asyncio +@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock) +async def test_list_indexes_with_repo_filter(mock_list): + app, pool = _make_app() + mock_list.return_value = [] + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/xrpc/science.alt.dataset.listIndexes", + params={"repo": "did:plc:provider1"}, + ) + assert resp.status_code == 200 + mock_list.assert_called_once_with(pool, "did:plc:provider1", 50, None, None, None) + + +# --------------------------------------------------------------------------- +# publishIndex +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +@patch("atdata_app.xrpc.procedures._resolve_pds", new_callable=AsyncMock) +@patch("atdata_app.xrpc.procedures._proxy_create_record", new_callable=AsyncMock) +async def test_publish_index(mock_proxy, mock_pds, mock_auth): + app, pool = _make_app() + + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + mock_pds.return_value = "https://pds.example.com" + mock_proxy.return_value = { + "uri": "at://did:plc:publisher1/science.alt.dataset.index/3abc", + "cid": "bafynew", + } + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["uri"] == "at://did:plc:publisher1/science.alt.dataset.index/3abc" + assert data["cid"] == "bafynew" + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_missing_field(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Test", + "createdAt": "2025-01-01T00:00:00Z", + # missing endpointUrl + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 400 + assert "endpointUrl" in resp.json()["detail"] + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_http_url_rejected(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "name": "Bad Index", + "endpointUrl": "http://insecure.example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert 
resp.status_code == 400 + assert "HTTPS" in resp.json()["detail"] + + +@pytest.mark.asyncio +@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock) +async def test_publish_index_invalid_type(mock_auth): + app, pool = _make_app() + mock_auth.return_value = MagicMock(iss="did:plc:publisher1") + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/xrpc/science.alt.dataset.publishIndex", + json={ + "record": { + "$type": "science.alt.dataset.entry", + "name": "Wrong Type", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + }, + headers={ + "Authorization": "Bearer test-token", + "X-PDS-Auth": "pds-token", + }, + ) + assert resp.status_code == 400 + assert "Invalid $type" in resp.json()["detail"] + + +# --------------------------------------------------------------------------- +# Ingestion: index provider records +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_process_commit_index_provider(): + mock_upsert = AsyncMock() + pool = AsyncMock() + event = { + "did": "did:plc:provider1", + "time_us": 1725911162329308, + "kind": "commit", + "commit": { + "rev": "rev1", + "operation": "create", + "collection": "science.alt.dataset.index", + "rkey": "3abc", + "record": { + "$type": "science.alt.dataset.index", + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + "cid": "bafyindex", + }, + } + with patch.dict(f"{_DB}.UPSERT_FNS", {"index_providers": mock_upsert}): + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:provider1", "3abc", "bafyindex", event["commit"]["record"] + ) + + +@pytest.mark.asyncio +@patch(f"{_DB}.delete_record", new_callable=AsyncMock) +async def test_process_commit_delete_index_provider(mock_delete): + pool = AsyncMock() + event = { + "did": "did:plc:provider1", + "time_us": 1725911162329308, + "kind": "commit", + "commit": { + "rev": "rev1", + "operation": "delete", + "collection": "science.alt.dataset.index", + "rkey": "3abc", + }, + } + await process_commit(pool, event) + mock_delete.assert_called_once_with(pool, "index_providers", "did:plc:provider1", "3abc") diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 2de2b49..fd31238 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -8,7 +8,6 @@ from atdata_app.ingestion.processor import process_commit -# All patches target the `db` module reference used inside processor.py _DB = "atdata_app.database" @@ -44,15 +43,22 @@ def _make_event( } +def _patch_upsert(table: str): + """Patch a single entry in UPSERT_FNS by table name.""" + mock = AsyncMock() + return patch.dict(f"{_DB}.UPSERT_FNS", {table: mock}), mock + + @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_create(mock_upsert): - pool = AsyncMock() - event = _make_event(operation="create") - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_create(): + patcher, mock_upsert = _patch_upsert("entries") + with patcher: + pool = AsyncMock() + event = _make_event(operation="create") + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", 
event["commit"]["record"] + ) @pytest.mark.asyncio @@ -72,85 +78,109 @@ async def test_process_commit_delete(mock_delete): @pytest.mark.asyncio -@patch(f"{_DB}.upsert_schema", new_callable=AsyncMock) -async def test_process_commit_schema(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.schema", - record={ - "$type": "science.alt.dataset.schema", - "name": "TestSchema", - "version": "1.0.0", - "schemaType": "jsonSchema", - "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"}, - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_schema(): + patcher, mock_upsert = _patch_upsert("schemas") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.schema", + record={ + "$type": "science.alt.dataset.schema", + "name": "TestSchema", + "version": "1.0.0", + "schemaType": "jsonSchema", + "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"}, + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_label", new_callable=AsyncMock) -async def test_process_commit_label(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.label", - record={ - "$type": "science.alt.dataset.label", - "name": "mnist", - "datasetUri": "at://did:plc:test/science.alt.dataset.entry/3xyz", - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_label(): + patcher, mock_upsert = _patch_upsert("labels") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.label", + record={ + "$type": "science.alt.dataset.label", + "name": "mnist", + "datasetUri": "at://did:plc:test/science.alt.dataset.entry/3xyz", + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_lens", new_callable=AsyncMock) -async def test_process_commit_lens(mock_upsert): - pool = AsyncMock() - event = _make_event( - collection="science.alt.dataset.lens", - record={ - "$type": "science.alt.dataset.lens", - "name": "test-lens", - "sourceSchema": "at://did:plc:test/science.alt.dataset.schema/a@1.0.0", - "targetSchema": "at://did:plc:test/science.alt.dataset.schema/b@1.0.0", - "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"}, - "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"}, - "createdAt": "2025-01-01T00:00:00Z", - }, - ) - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) +async def test_process_commit_lens(): + patcher, mock_upsert = _patch_upsert("lenses") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.lens", + record={ + "$type": "science.alt.dataset.lens", + "name": "test-lens", + "sourceSchema": 
"at://did:plc:test/science.alt.dataset.schema/a@1.0.0", + "targetSchema": "at://did:plc:test/science.alt.dataset.schema/b@1.0.0", + "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"}, + "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"}, + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) + + +@pytest.mark.asyncio +async def test_process_commit_index_provider(): + patcher, mock_upsert = _patch_upsert("index_providers") + with patcher: + pool = AsyncMock() + event = _make_event( + collection="science.alt.dataset.index", + record={ + "$type": "science.alt.dataset.index", + "name": "Genomics Index", + "endpointUrl": "https://example.com/skeleton", + "createdAt": "2025-01-01T00:00:00Z", + }, + ) + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_update(mock_upsert): +async def test_process_commit_update(): """Update operations should route to the same upsert function as create.""" - pool = AsyncMock() - event = _make_event(operation="update") - await process_commit(pool, event) - mock_upsert.assert_called_once_with( - pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] - ) + patcher, mock_upsert = _patch_upsert("entries") + with patcher: + pool = AsyncMock() + event = _make_event(operation="update") + await process_commit(pool, event) + mock_upsert.assert_called_once_with( + pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"] + ) @pytest.mark.asyncio -@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock) -async def test_process_commit_upsert_error_is_caught(mock_upsert): +async def test_process_commit_upsert_error_is_caught(): """Upsert failures should be logged, not raised.""" - mock_upsert.side_effect = Exception("db error") - pool = AsyncMock() - event = _make_event(operation="create") - # Should not raise - await process_commit(pool, event) + mock = AsyncMock(side_effect=Exception("db error")) + with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock}): + pool = AsyncMock() + event = _make_event(operation="create") + # Should not raise + await process_commit(pool, event) diff --git a/tests/test_integration.py b/tests/test_integration.py index 485ee27..2cccd1f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -162,6 +162,7 @@ async def test_all_expected_tables_exist(self, db_pool): "entries", "labels", "lenses", + "index_providers", "cursor_state", "analytics_events", "analytics_counters", @@ -196,6 +197,8 @@ async def test_expected_indexes_exist(self, db_pool): "idx_lenses_did", "idx_analytics_events_type_created", "idx_analytics_events_target", + "idx_index_providers_did", + "idx_index_providers_indexed_at", } async with db_pool.acquire() as conn: rows = await conn.fetch( diff --git a/tests/test_models.py b/tests/test_models.py index 23f6046..a09fc1e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,11 +5,14 @@ import pytest from atdata_app.models import ( + ARRAY_FORMAT_LABELS, + KNOWN_ARRAY_FORMATS, decode_cursor, encode_cursor, make_at_uri, parse_at_uri, row_to_entry, + row_to_index_provider, row_to_label, row_to_lens, row_to_schema, @@ -174,6 +177,95 @@ def 
test_row_to_schema_json_string_body(): assert d["schema"] == {"type": "object"} +def test_row_to_schema_no_array_format_fields_when_absent(): + """Plain schemas should not gain arrayFormat/ndarray annotation keys.""" + d = row_to_schema(_SCHEMA_ROW) + assert "arrayFormat" not in d + assert "arrayFormatLabel" not in d + assert "dtype" not in d + assert "shape" not in d + assert "dimensionNames" not in d + + +# --------------------------------------------------------------------------- +# row_to_schema — array format types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fmt", sorted(KNOWN_ARRAY_FORMATS)) +def test_row_to_schema_known_array_format(fmt): + """Each known format token should surface arrayFormat and a human label.""" + row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": fmt}} + d = row_to_schema(row) + assert d["arrayFormat"] == fmt + assert d["arrayFormatLabel"] == ARRAY_FORMAT_LABELS[fmt] + + +def test_row_to_schema_unknown_array_format_passes_through(): + """Unknown format tokens are stored and surfaced as-is.""" + row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": "futureFormat"}} + d = row_to_schema(row) + assert d["arrayFormat"] == "futureFormat" + assert d["arrayFormatLabel"] == "futureFormat" + + +# --------------------------------------------------------------------------- +# row_to_schema — ndarray v1.1.0 annotations +# --------------------------------------------------------------------------- + + +def test_row_to_schema_ndarray_annotations(): + """ndarray v1.1.0 annotation fields are surfaced at top level.""" + row = { + **_SCHEMA_ROW, + "schema_body": { + "arrayFormat": "numpyBytes", + "dtype": "float32", + "shape": [100, 200], + "dimensionNames": ["samples", "features"], + }, + } + d = row_to_schema(row) + assert d["arrayFormat"] == "numpyBytes" + assert d["dtype"] == "float32" + assert d["shape"] == [100, 200] + assert d["dimensionNames"] == ["samples", "features"] + + +def test_row_to_schema_ndarray_partial_annotations(): + """Only present annotation fields should appear in output.""" + row = { + **_SCHEMA_ROW, + "schema_body": {"arrayFormat": "sparseBytes", "dtype": "int64"}, + } + d = row_to_schema(row) + assert d["dtype"] == "int64" + assert "shape" not in d + assert "dimensionNames" not in d + + +# --------------------------------------------------------------------------- +# KNOWN_ARRAY_FORMATS constant +# --------------------------------------------------------------------------- + + +def test_known_array_formats_contains_all_expected(): + expected = { + "numpyBytes", + "parquetBytes", + "sparseBytes", + "structuredBytes", + "arrowTensor", + "safetensors", + } + assert KNOWN_ARRAY_FORMATS == expected + + +def test_array_format_labels_covers_all_known(): + """Every known format should have a human-readable label.""" + assert set(ARRAY_FORMAT_LABELS.keys()) == KNOWN_ARRAY_FORMATS + + # --------------------------------------------------------------------------- # row_to_label # --------------------------------------------------------------------------- @@ -193,6 +285,7 @@ def test_row_to_schema_json_string_body(): def test_row_to_label(): d = row_to_label(_LABEL_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.label/3lbl" + assert d["did"] == "did:plc:abc" assert d["datasetUri"] == _LABEL_ROW["dataset_uri"] assert d["version"] == "1.0.0" assert d["description"] == "First version" @@ -227,6 +320,7 @@ def test_row_to_label_omits_optional_fields(): def test_row_to_lens(): d = 
row_to_lens(_LENS_ROW) assert d["uri"] == "at://did:plc:abc/science.alt.dataset.lens/3lens" + assert d["did"] == "did:plc:abc" assert d["sourceSchema"] == _LENS_ROW["source_schema"] assert d["getterCode"] == _LENS_ROW["getter_code"] assert d["description"] == "Transforms A to B" @@ -249,3 +343,33 @@ def test_row_to_lens_json_string_code(): d = row_to_lens(row) assert d["getterCode"] == {"repo": "x"} assert d["putterCode"] == {"repo": "y"} + + +# --------------------------------------------------------------------------- +# row_to_index_provider +# --------------------------------------------------------------------------- + +_INDEX_PROVIDER_ROW = { + "did": "did:plc:provider1", + "rkey": "3idx", + "cid": "bafyindex", + "name": "Genomics Index", + "description": "Curated genomics datasets", + "endpoint_url": "https://example.com/skeleton", + "created_at": "2025-01-01T00:00:00Z", +} + + +def test_row_to_index_provider(): + d = row_to_index_provider(_INDEX_PROVIDER_ROW) + assert d["uri"] == "at://did:plc:provider1/science.alt.dataset.index/3idx" + assert d["did"] == "did:plc:provider1" + assert d["name"] == "Genomics Index" + assert d["endpointUrl"] == "https://example.com/skeleton" + assert d["description"] == "Curated genomics datasets" + + +def test_row_to_index_provider_omits_null_description(): + row = {**_INDEX_PROVIDER_ROW, "description": None} + d = row_to_index_provider(row) + assert "description" not in d diff --git a/uv.lock b/uv.lock index 72b071a..2dfa852 100644 --- a/uv.lock +++ b/uv.lock @@ -75,7 +75,7 @@ wheels = [ [[package]] name = "atdata-app" -version = "0.3.0b1" +version = "0.4.0b1" source = { editable = "." } dependencies = [ { name = "asyncpg" },
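The `test_models.py` additions specify the new model-layer surface: a closed `KNOWN_ARRAY_FORMATS` set of six tokens, a human-readable label for every member with unknown tokens passing through unchanged, ndarray annotations surfaced only when present, and a `row_to_index_provider` that builds the `at://` URI and omits a null description. Here is a sketch consistent with those assertions; only the labels "NumPy ndarray", "Sparse matrix", and "Safetensors" are asserted by the frontend tests, so the remaining label wording (and the extra `cid`/`createdAt` output keys) is guessed, and the `surface_array_format` helper name is hypothetical.

```python
# Model-layer sketch consistent with the test_models.py assertions above.
# Three labels are confirmed by the frontend tests; the rest are guesses.
KNOWN_ARRAY_FORMATS = {
    "numpyBytes", "parquetBytes", "sparseBytes",
    "structuredBytes", "arrowTensor", "safetensors",
}

ARRAY_FORMAT_LABELS = {
    "numpyBytes": "NumPy ndarray",          # asserted in test_frontend.py
    "sparseBytes": "Sparse matrix",         # asserted in test_frontend.py
    "safetensors": "Safetensors",           # asserted in test_frontend.py
    "parquetBytes": "Parquet",              # wording assumed
    "structuredBytes": "Structured array",  # wording assumed
    "arrowTensor": "Arrow tensor",          # wording assumed
}

_NDARRAY_KEYS = ("dtype", "shape", "dimensionNames")


def surface_array_format(schema_body: dict, out: dict) -> None:
    """Copy arrayFormat plus ndarray v1.1.0 annotations into a schema view.

    Hypothetical helper; row_to_schema presumably inlines this logic.
    """
    fmt = schema_body.get("arrayFormat")
    if fmt is not None:
        out["arrayFormat"] = fmt
        # Unknown future tokens pass through as their own label.
        out["arrayFormatLabel"] = ARRAY_FORMAT_LABELS.get(fmt, fmt)
    for key in _NDARRAY_KEYS:
        if key in schema_body:  # only surface annotations that are present
            out[key] = schema_body[key]


def row_to_index_provider(row) -> dict:
    d = {
        "uri": f"at://{row['did']}/science.alt.dataset.index/{row['rkey']}",
        "did": row["did"],
        "cid": row["cid"],                 # not asserted; assumed in output
        "name": row["name"],
        "endpointUrl": row["endpoint_url"],
        "createdAt": row["created_at"],    # not asserted; assumed in output
    }
    if row["description"] is not None:
        d["description"] = row["description"]  # omitted entirely when null
    return d
```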
-      <tr><th>Name</th><th>Version</th><th>Type</th><th>Description</th><th>Publisher</th></tr>
+      <tr><th>Name</th><th>Version</th><th>Type</th><th>Format</th><th>Description</th><th>Publisher</th></tr>
       <tr>
         <td>{{ s.name }}</td>
         <td>{{ s.version }}</td>
         <td>{{ s.schemaType }}</td>
+        <td>{{ s.get("arrayFormatLabel", "") }}</td>
         <td>{{ s.get("description", "") }}</td>
         <td>{{ s.did[:20] }}…</td>
       </tr>
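Finally, the `test_ingestion.py` refactor and the processor integration tests in `test_changestream.py` together pin down the dispatch shape of `process_commit`: creates and updates route through a `UPSERT_FNS` entry keyed by table name, deletes go to `delete_record`, upsert failures are swallowed so one bad record cannot kill the firehose consumer, and a change event is published only after the database write succeeds. A sketch under those constraints follows; the mapping contents and call signatures come from the patch targets and `assert_called_once_with(...)` calls above, while the logging and module layout are assumptions.

```python
# Dispatch sketch for process_commit, constrained by the patch targets
# (atdata_app.database.UPSERT_FNS / delete_record) and the call signatures
# asserted in the ingestion tests. Logging and error detail are assumed.
import logging

from atdata_app import database as db
from atdata_app.changestream import ChangeStream, make_change_event

log = logging.getLogger(__name__)

# Collection-to-table mapping (contents inferred from the tests above).
COLLECTION_TABLE_MAP = {
    "science.alt.dataset.entry": "entries",
    "science.alt.dataset.schema": "schemas",
    "science.alt.dataset.label": "labels",
    "science.alt.dataset.lens": "lenses",
    "science.alt.dataset.index": "index_providers",
}


async def process_commit(
    pool, event: dict, change_stream: ChangeStream | None = None
) -> None:
    commit = event["commit"]
    table = COLLECTION_TABLE_MAP.get(commit["collection"])
    if table is None:
        return  # not a collection this AppView indexes

    did, rkey, op = event["did"], commit["rkey"], commit["operation"]
    try:
        if op == "delete":
            await db.delete_record(pool, table, did, rkey)
        else:
            # create and update share one upsert path
            await db.UPSERT_FNS[table](pool, did, rkey, commit["cid"], commit["record"])
    except Exception:
        # One bad record must not kill the firehose consumer.
        log.exception("failed to process %s for %s/%s", op, did, rkey)
        return

    if change_stream is not None:
        is_delete = op == "delete"
        change_stream.publish(
            make_change_event(
                event_type=op,
                collection=commit["collection"],
                did=did,
                rkey=rkey,
                record=None if is_delete else commit.get("record"),
                cid=None if is_delete else commit.get("cid"),
            )
        )
```

Looking up `db.UPSERT_FNS[table]` and `db.delete_record` as attributes at call time, rather than importing the functions directly, is what lets the tests swap implementations with `patch.dict` and `patch` against the `atdata_app.database` module.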