diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 95cd225..4526e04 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -15,6 +15,8 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: true
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
@@ -28,6 +30,8 @@ jobs:
python-version: ["3.12", "3.13"]
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: true
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
@@ -56,6 +60,8 @@ jobs:
--health-retries 5
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: true
- name: Apply schema
env:
PGHOST: localhost
@@ -105,6 +111,8 @@ jobs:
--health-retries 5
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: true
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 4c9d7e5..48eb356 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -12,6 +12,8 @@ jobs:
id-token: write
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: true
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..d4faf98
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "lexicons"]
+ path = lexicons
+ url = https://github.com/forecast-bio/atdata-lexicon.git
diff --git a/CHANGELOG.md b/CHANGELOG.md
index edb123d..e35941f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,42 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/),
and this project adheres to [Semantic Versioning](https://semver.org/).
+## [0.4.0b1] - 2026-02-26
+
+### Added
+
+- `sendInteractions` XRPC procedure for anonymous usage telemetry — fire-and-forget reporting of download, citation, and derivative events on datasets ([#21](https://github.com/forecast-bio/atdata-app/issues/21))
+- Skeleton/hydration pattern for third-party dataset indexes — `getIndexSkeleton`, `getIndex`, `listIndexes`, and `publishIndex` endpoints following Bluesky's feed generator model ([#20](https://github.com/forecast-bio/atdata-app/issues/20))
+- `subscribeChanges` WebSocket endpoint for real-time change streaming — in-memory event bus broadcasts create/update/delete events to subscribers with cursor-based replay ([#22](https://github.com/forecast-bio/atdata-app/issues/22))
+- Array format type recognition (`sparseBytes`, `structuredBytes`, `arrowTensor`, `safetensors`) and ndarray v1.1.0 annotation display (`dtype`, `shape`, `dimensionNames`) in frontend templates ([#30](https://github.com/forecast-bio/atdata-app/issues/30))
+- `atdata-lexicon` git submodule at `lexicons/` pinned to v0.2.1b1 for reference and CI validation ([#27](https://github.com/forecast-bio/atdata-app/issues/27))
+- CI checkout steps now initialize submodules
+
+### Changed
+
+- Ingestion processor refactored to use `UPSERT_FNS` dispatch dict instead of if/elif chain
+- Index provider records (`science.alt.dataset.index`) added to `COLLECTION_TABLE_MAP` for firehose ingestion
+
+### Security
+
+- **SSRF protection**: Skeleton fetch validates endpoint URLs with DNS resolution and blocks private/reserved IP ranges; firehose ingestion rejects non-HTTPS endpoint URLs and URLs with embedded credentials
+- **Auth**: `sendInteractions` endpoint now requires ATProto service auth (was previously unauthenticated)
+- **XSS**: Storage URLs in dataset detail pages are only rendered as clickable links when using `http(s)://` schemes, preventing `javascript:` URI injection
+- **Input validation**: `publishIndex` rejects endpoint URLs containing embedded credentials or fragments; `sendInteractions` validates that URIs reference the `science.alt.dataset.entry` collection
+
+### Fixed
+
+- **ChangeStream backpressure**: Subscribers that fall behind are now tracked and explicitly disconnected with WebSocket close code 4000, instead of silently dropping events
+- **ChangeStream subscriber limit**: Capped at 1000 concurrent subscribers; new connections receive close code 1013 when full
+- **WebSocket keepalive**: Restructured the `subscribeChanges` event loop so the 30-second idle keepalive correctly re-enters the processing loop (was previously broken)
+- **Replay deduplication**: Track last replayed sequence number to prevent duplicate events when replay buffer overlaps with the live queue
+- **Task GC**: Fire-and-forget analytics tasks now retain references to prevent garbage collection before completion
+- **Skeleton response cap**: Enforce the requested `limit` on items returned by external index providers, and cap response body size to 1 MiB
+- **Skeleton item sanitization**: Whitelist upstream skeleton items to only the `uri` field; validate cursor strings for length and null bytes
+- **Query guard**: `query_get_entries` now rejects requests with more than 100 keys to prevent unbounded OR-clause queries
+- **Template robustness**: `shape` and `dimensionNames` join filters now guard against non-iterable data from malformed firehose records
+- Removed dead `_validate_iso8601` timestamp validation code from `sendInteractions`
+
## [0.3.0b1] - 2026-02-22
### Changed
diff --git a/CLAUDE.md b/CLAUDE.md
index 6020760..c5b8285 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -28,6 +28,16 @@ uv run ruff check src/ tests/
uv run uvicorn atdata_app.main:app --reload
```
+## Lexicon Submodule
+
+The `lexicons/` directory is a git submodule pointing to [forecast-bio/atdata-lexicon](https://github.com/forecast-bio/atdata-lexicon), which contains the authoritative `science.alt.dataset.*` lexicon definitions. The submodule is for reference and CI validation — the Python source code does not read lexicon files at runtime.
+
+After cloning, initialize the submodule:
+
+```bash
+git submodule update --init
+```
+
## Architecture
### Data Flow
diff --git a/README.md b/README.md
index 00d8bdf..33ac782 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,9 @@ ATProto Network
# Install dependencies
uv sync --dev
+# Initialize the lexicon submodule
+git submodule update --init
+
# Set up PostgreSQL (schema auto-applies on startup)
createdb atdata_app
@@ -135,6 +138,16 @@ uv run ruff check src/ tests/
Tests mock all external dependencies (database, HTTP, identity resolution) using `unittest.mock.AsyncMock`. HTTP endpoint tests use httpx `ASGITransport` for in-process testing without a running server.
+### Lexicon Definitions
+
+The `lexicons/` directory is a [git submodule](https://github.com/forecast-bio/atdata-lexicon) containing the authoritative `science.alt.dataset.*` lexicon schemas. Initialize it with:
+
+```bash
+git submodule update --init
+```
+
+The lexicons are for reference and CI validation. The Python source code uses hardcoded NSID constants and does not read the lexicon JSON files at runtime.
+
## License
MIT
diff --git a/lexicons b/lexicons
new file mode 160000
index 0000000..ece47ae
--- /dev/null
+++ b/lexicons
@@ -0,0 +1 @@
+Subproject commit ece47ae2158c3226bdf2ac280695db32ea5ad336
diff --git a/pyproject.toml b/pyproject.toml
index 708c0ac..32b3071 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "atdata-app"
-version = "0.3.0b1"
+version = "0.4.0b1"
description = "ATProto AppView for science.alt.dataset"
readme = "README.md"
authors = [
diff --git a/src/atdata_app/changestream.py b/src/atdata_app/changestream.py
new file mode 100644
index 0000000..17c3593
--- /dev/null
+++ b/src/atdata_app/changestream.py
@@ -0,0 +1,167 @@
+"""In-memory broadcast channel for real-time change events.
+
+Provides a pub/sub mechanism that the ingestion processor publishes to
+and WebSocket subscribers consume from. Maintains a bounded buffer of
+recent events for cursor-based replay.
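+
+A minimal usage sketch (the specific values below are illustrative)::
+
+    stream = ChangeStream()
+    sub_id, queue = stream.subscribe()
+    stream.publish(
+        make_change_event(
+            event_type="create",
+            collection="science.alt.dataset.entry",
+            did="did:plc:example",
+            rkey="3abc",
+        )
+    )
+    event = queue.get_nowait()  # seq is assigned by publish(); the first event gets seq 1
+    stream.unsubscribe(sub_id)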
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections import deque
+from dataclasses import dataclass, field
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_BUFFER_SIZE = 1000
+DEFAULT_SUBSCRIBER_QUEUE_SIZE = 256
+DEFAULT_MAX_SUBSCRIBERS = 1000
+
+
+@dataclass
+class ChangeEvent:
+ """A single change event in the stream."""
+
+ seq: int
+ type: str # "create", "update", or "delete"
+ collection: str
+ did: str
+ rkey: str
+ timestamp: str
+ record: dict[str, Any] | None = None
+ cid: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
+ d: dict[str, Any] = {
+ "seq": self.seq,
+ "type": self.type,
+ "collection": self.collection,
+ "did": self.did,
+ "rkey": self.rkey,
+ "timestamp": self.timestamp,
+ }
+ if self.record is not None:
+ d["record"] = self.record
+ if self.cid is not None:
+ d["cid"] = self.cid
+ return d
+
+
+@dataclass
+class ChangeStream:
+ """Broadcast channel with bounded replay buffer.
+
+    Not thread-safe: all mutations are expected to run on a single asyncio event loop.
+ """
+
+ buffer_size: int = DEFAULT_BUFFER_SIZE
+ subscriber_queue_size: int = DEFAULT_SUBSCRIBER_QUEUE_SIZE
+ max_subscribers: int = DEFAULT_MAX_SUBSCRIBERS
+ _seq: int = field(default=0, init=False)
+ _buffer: deque[ChangeEvent] = field(init=False)
+ _subscribers: dict[int, asyncio.Queue[ChangeEvent]] = field(
+ default_factory=dict, init=False
+ )
+ _dropped_subs: set[int] = field(default_factory=set, init=False)
+ _next_sub_id: int = field(default=0, init=False)
+
+ def __post_init__(self) -> None:
+ self._buffer = deque(maxlen=self.buffer_size)
+
+ def publish(self, event: ChangeEvent) -> None:
+ """Publish an event to all subscribers and the replay buffer.
+
+ Non-blocking. If a subscriber's queue is full, the subscriber is
+ marked as dropped so the WebSocket handler can close the connection.
+ """
+ self._seq += 1
+ event.seq = self._seq
+ self._buffer.append(event)
+
+ for sub_id, queue in list(self._subscribers.items()):
+ try:
+ queue.put_nowait(event)
+ except asyncio.QueueFull:
+ logger.warning(
+ "Subscriber %d queue full at seq=%d — marking for disconnect",
+ sub_id,
+ event.seq,
+ )
+ self._dropped_subs.add(sub_id)
+
+ def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]:
+ """Create a new subscriber. Returns (subscriber_id, queue).
+
+ Raises ``RuntimeError`` if the maximum subscriber count is reached.
+ """
+ if len(self._subscribers) >= self.max_subscribers:
+ raise RuntimeError(
+ f"Maximum subscriber count ({self.max_subscribers}) reached"
+ )
+ sub_id = self._next_sub_id
+ self._next_sub_id += 1
+ queue: asyncio.Queue[ChangeEvent] = asyncio.Queue(
+ maxsize=self.subscriber_queue_size
+ )
+ self._subscribers[sub_id] = queue
+ logger.debug("Subscriber %d connected (total: %d)", sub_id, len(self._subscribers))
+ return sub_id, queue
+
+ def unsubscribe(self, sub_id: int) -> None:
+ """Remove a subscriber."""
+ self._subscribers.pop(sub_id, None)
+ self._dropped_subs.discard(sub_id)
+ logger.debug("Subscriber %d disconnected (total: %d)", sub_id, len(self._subscribers))
+
+ def is_dropped(self, sub_id: int) -> bool:
+ """Return True if the subscriber was dropped due to backpressure."""
+ return sub_id in self._dropped_subs
+
+ def replay_from(self, cursor: int) -> list[ChangeEvent]:
+ """Return buffered events with seq > cursor.
+
+ Returns an empty list if the cursor is outside the buffer window.
+ """
+ if not self._buffer:
+ return []
+
+ oldest_seq = self._buffer[0].seq
+ if cursor < oldest_seq - 1:
+ # Cursor is too old — events between cursor and buffer start were lost
+ return []
+
+ return [ev for ev in self._buffer if ev.seq > cursor]
+
+ @property
+ def current_seq(self) -> int:
+ return self._seq
+
+ @property
+ def subscriber_count(self) -> int:
+ return len(self._subscribers)
+
+
+def make_change_event(
+ *,
+ event_type: str,
+ collection: str,
+ did: str,
+ rkey: str,
+ record: dict[str, Any] | None = None,
+ cid: str | None = None,
+) -> ChangeEvent:
+ """Factory for creating change events with current timestamp."""
+ from datetime import datetime, timezone
+
+ return ChangeEvent(
+ seq=0, # Assigned by ChangeStream.publish()
+ type=event_type,
+ collection=collection,
+ did=did,
+ rkey=rkey,
+ timestamp=datetime.now(timezone.utc).isoformat(),
+ record=record,
+ cid=cid,
+ )
diff --git a/src/atdata_app/database.py b/src/atdata_app/database.py
index b0ea925..6543530 100644
--- a/src/atdata_app/database.py
+++ b/src/atdata_app/database.py
@@ -20,6 +20,7 @@
f"{LEXICON_NAMESPACE}.entry": "entries",
f"{LEXICON_NAMESPACE}.label": "labels",
f"{LEXICON_NAMESPACE}.lens": "lenses",
+ f"{LEXICON_NAMESPACE}.index": "index_providers",
}
@@ -203,6 +204,57 @@ async def upsert_lens(
)
+def _is_safe_endpoint_url(url: str) -> bool:
+ """Return True if *url* is HTTPS with no credentials or private IPs."""
+ from urllib.parse import urlparse
+
+ try:
+ parsed = urlparse(url)
+ except ValueError:
+ return False
+ if parsed.scheme != "https" or not parsed.hostname:
+ return False
+ if parsed.username or parsed.password:
+ return False
+ return True
+
+
+async def upsert_index_provider(
+ pool: asyncpg.Pool,
+ did: str,
+ rkey: str,
+ cid: str | None,
+ record: dict[str, Any],
+) -> None:
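+    """Insert or update an index provider row; rows with unsafe endpoint URLs are skipped."""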
+ endpoint_url = record.get("endpointUrl", "")
+ if not _is_safe_endpoint_url(endpoint_url):
+ logger.warning(
+ "Rejected index provider %s/%s: unsafe endpoint URL %s", did, rkey, endpoint_url
+ )
+ return
+ async with pool.acquire() as conn:
+ await conn.execute(
+ """
+ INSERT INTO index_providers (did, rkey, cid, name, description, endpoint_url,
+ created_at)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT (did, rkey) DO UPDATE SET
+ cid = EXCLUDED.cid,
+ name = EXCLUDED.name,
+ description = EXCLUDED.description,
+ endpoint_url = EXCLUDED.endpoint_url,
+ indexed_at = NOW()
+ """,
+ did,
+ rkey,
+ cid,
+ record["name"],
+ record.get("description"),
+ record["endpointUrl"],
+ record.get("createdAt", ""),
+ )
+
+
async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) -> None:
if table not in COLLECTION_TABLE_MAP.values():
return
@@ -219,6 +271,7 @@ async def delete_record(pool: asyncpg.Pool, table: str, did: str, rkey: str) ->
"entries": upsert_entry,
"labels": upsert_label,
"lenses": upsert_lens,
+ "index_providers": upsert_index_provider,
}
@@ -313,11 +366,18 @@ async def query_get_entry(
)
+_MAX_GET_ENTRIES_KEYS = 100
+
+
async def query_get_entries(
pool: asyncpg.Pool, keys: list[tuple[str, str]]
) -> list[asyncpg.Record]:
if not keys:
return []
+ if len(keys) > _MAX_GET_ENTRIES_KEYS:
+ raise ValueError(
+ f"query_get_entries: too many keys ({len(keys)}), max {_MAX_GET_ENTRIES_KEYS}"
+ )
conditions = " OR ".join(
f"(did = ${i * 2 + 1} AND rkey = ${i * 2 + 2})" for i in range(len(keys))
)
@@ -449,6 +509,49 @@ async def query_list_lenses(
)
+async def query_get_index_provider(
+ pool: asyncpg.Pool, did: str, rkey: str
+) -> asyncpg.Record | None:
+ async with pool.acquire() as conn:
+ return await conn.fetchrow(
+ "SELECT * FROM index_providers WHERE did = $1 AND rkey = $2", did, rkey
+ )
+
+
+async def query_list_index_providers(
+ pool: asyncpg.Pool,
+ repo: str | None = None,
+ limit: int = 50,
+ cursor_did: str | None = None,
+ cursor_rkey: str | None = None,
+ cursor_indexed_at: str | None = None,
+) -> list[asyncpg.Record]:
+ conditions: list[str] = []
+ params: list[Any] = []
+ idx = 1
+
+ if repo:
+ conditions.append(f"did = ${idx}")
+ params.append(repo)
+ idx += 1
+
+ if cursor_indexed_at and cursor_did and cursor_rkey:
+ conditions.append(
+ f"(indexed_at, did, rkey) < (${idx}, ${idx + 1}, ${idx + 2})"
+ )
+ params.extend([datetime.fromisoformat(cursor_indexed_at), cursor_did, cursor_rkey])
+ idx += 3
+
+ where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ params.append(limit)
+
+ async with pool.acquire() as conn:
+ return await conn.fetch(
+ f"SELECT * FROM index_providers {where} ORDER BY indexed_at DESC, did DESC, rkey DESC LIMIT ${idx}",
+ *params,
+ )
+
+
async def query_search_datasets(
pool: asyncpg.Pool,
q: str,
@@ -544,13 +647,17 @@ async def query_search_lenses(
)
+async def _fetch_record_counts(conn: asyncpg.Connection) -> dict[str, int]:
+ counts = {}
+ for collection, table in COLLECTION_TABLE_MAP.items():
+ row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608
+ counts[collection] = row["cnt"]
+ return counts
+
+
async def query_record_counts(pool: asyncpg.Pool) -> dict[str, int]:
async with pool.acquire() as conn:
- counts = {}
- for collection, table in COLLECTION_TABLE_MAP.items():
- row = await conn.fetchrow(f"SELECT COUNT(*) as cnt FROM {table}") # noqa: S608
- counts[collection] = row["cnt"]
- return counts
+ return await _fetch_record_counts(conn)
async def query_labels_for_dataset(
@@ -607,6 +714,9 @@ async def record_analytics_event(
logger.warning("Failed to record analytics event %s", event_type, exc_info=True)
+_background_tasks: set[asyncio.Task] = set()
+
+
def fire_analytics_event(
pool: asyncpg.Pool,
event_type: str,
@@ -615,9 +725,11 @@ def fire_analytics_event(
query_params: dict[str, Any] | None = None,
) -> None:
"""Fire-and-forget analytics recording. Does not block the caller."""
- asyncio.create_task(
+ task = asyncio.create_task(
record_analytics_event(pool, event_type, target_did, target_rkey, query_params)
)
+ _background_tasks.add(task)
+ task.add_done_callback(_background_tasks.discard)
PERIOD_INTERVALS: dict[str, timedelta] = {
@@ -692,11 +804,7 @@ async def query_analytics_summary(
{"term": r["term"], "count": r["count"]} for r in term_rows
]
- # Record counts
- counts = {}
- for collection, table in COLLECTION_TABLE_MAP.items():
- c = await conn.fetchrow(f"SELECT COUNT(*) AS cnt FROM {table}") # noqa: S608
- counts[collection] = c["cnt"]
+ counts = await _fetch_record_counts(conn)
return {
"totalViews": total_views,
@@ -720,7 +828,10 @@ async def query_entry_stats(
"""
SELECT
COUNT(*) FILTER (WHERE event_type = 'view_entry') AS views,
- COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances
+ COUNT(*) FILTER (WHERE event_type = 'search') AS search_appearances,
+ COUNT(*) FILTER (WHERE event_type = 'download') AS downloads,
+ COUNT(*) FILTER (WHERE event_type = 'citation') AS citations,
+ COUNT(*) FILTER (WHERE event_type = 'derivative') AS derivatives
FROM analytics_events
WHERE target_did = $1 AND target_rkey = $2
AND created_at >= NOW() - $3::interval
@@ -732,6 +843,9 @@ async def query_entry_stats(
return {
"views": row["views"],
"searchAppearances": row["search_appearances"],
+ "downloads": row["downloads"],
+ "citations": row["citations"],
+ "derivatives": row["derivatives"],
"period": period,
}
@@ -750,6 +864,8 @@ async def query_active_publishers(pool: asyncpg.Pool, days: int = 30) -> int:
SELECT did FROM labels WHERE indexed_at >= NOW() - $1::interval
UNION
SELECT did FROM lenses WHERE indexed_at >= NOW() - $1::interval
+ UNION
+ SELECT did FROM index_providers WHERE indexed_at >= NOW() - $1::interval
) sub
""",
interval,
diff --git a/src/atdata_app/frontend/routes.py b/src/atdata_app/frontend/routes.py
index 9a815e4..dac49f2 100644
--- a/src/atdata_app/frontend/routes.py
+++ b/src/atdata_app/frontend/routes.py
@@ -22,6 +22,7 @@
)
from atdata_app.models import (
maybe_cursor,
+ parse_at_uri,
parse_cursor,
row_to_entry,
row_to_label,
@@ -100,10 +101,18 @@ async def dataset_detail(request: Request, did: str, rkey: str):
schema_did = ""
schema_rkey = ""
schema_ref = entry.get("schemaRef", "")
- if schema_ref.startswith("at://"):
- parts = schema_ref[5:].split("/", 2)
- if len(parts) == 3:
- schema_did, _, schema_rkey = parts
+ if schema_ref:
+ try:
+ schema_did, _, schema_rkey = parse_at_uri(schema_ref)
+ except ValueError:
+ pass
+
+ # Fetch the referenced schema for inline display of format/annotation info
+ schema_info = None
+ if schema_did and schema_rkey:
+ schema_row = await query_get_schema(pool, schema_did, schema_rkey)
+ if schema_row:
+ schema_info = row_to_schema(schema_row)
# Fetch labels pointing to this dataset
dataset_uri = entry["uri"]
@@ -117,6 +126,7 @@ async def dataset_detail(request: Request, did: str, rkey: str):
"entry": entry,
"schema_did": schema_did,
"schema_rkey": schema_rkey,
+ "schema_info": schema_info,
"labels": labels,
},
)
diff --git a/src/atdata_app/frontend/templates/dataset.html b/src/atdata_app/frontend/templates/dataset.html
index 4b32fed..ddd4760 100644
--- a/src/atdata_app/frontend/templates/dataset.html
+++ b/src/atdata_app/frontend/templates/dataset.html
@@ -27,6 +27,20 @@
Details
| License | {{ entry.license }} |
{% endif %}
| Schema | {{ entry.schemaRef }} |
+ {% if schema_info %}
+ {% if schema_info.arrayFormat is defined %}
+ | Array Format | {{ schema_info.get("arrayFormatLabel", schema_info.arrayFormat) }} |
+ {% endif %}
+ {% if schema_info.dtype is defined %}
+ | Data Type | {{ schema_info.dtype }} |
+ {% endif %}
+ {% if schema_info.shape is defined and schema_info.shape is iterable and schema_info.shape is not string %}
+ | Shape | {{ schema_info.shape | join(" × ") }} |
+ {% endif %}
+ {% if schema_info.dimensionNames is defined and schema_info.dimensionNames is iterable and schema_info.dimensionNames is not string %}
+ | Dimensions | {{ schema_info.dimensionNames | join(", ") }} |
+ {% endif %}
+ {% endif %}
{% if entry.size %}
| Size |
@@ -49,7 +63,11 @@ Storage
| Type | {{ entry.storage.get("$type", "unknown") }} |
{% if entry.storage.url is defined %}
+ {% if entry.storage.url.startswith("https://") or entry.storage.url.startswith("http://") %}
| URL | {{ entry.storage.url }} |
+ {% else %}
+ | URL | {{ entry.storage.url }} |
+ {% endif %}
{% endif %}
diff --git a/src/atdata_app/frontend/templates/profile.html b/src/atdata_app/frontend/templates/profile.html
index 46026b4..04c4e29 100644
--- a/src/atdata_app/frontend/templates/profile.html
+++ b/src/atdata_app/frontend/templates/profile.html
@@ -34,13 +34,14 @@ Schemas
{% if schemas %}
- | Name | Version | Type |
+ | Name | Version | Type | Format |
{% for s in schemas %}
| {{ s.name }} |
{{ s.version }} |
{{ s.schemaType }} |
+ {{ s.get("arrayFormatLabel", "") }} |
{% endfor %}
diff --git a/src/atdata_app/frontend/templates/schema.html b/src/atdata_app/frontend/templates/schema.html
index 281d9f6..3df1859 100644
--- a/src/atdata_app/frontend/templates/schema.html
+++ b/src/atdata_app/frontend/templates/schema.html
@@ -16,6 +16,18 @@ Details
| AT-URI | {{ schema.uri }} |
| Type | {{ schema.schemaType }} |
+ {% if schema.arrayFormat is defined %}
+ | Array Format | {{ schema.get("arrayFormatLabel", schema.arrayFormat) }} |
+ {% endif %}
+ {% if schema.dtype is defined %}
+ | Data Type | {{ schema.dtype }} |
+ {% endif %}
+ {% if schema.shape is defined and schema.shape is iterable and schema.shape is not string %}
+ | Shape | {{ schema.shape | join(" × ") }} |
+ {% endif %}
+ {% if schema.dimensionNames is defined and schema.dimensionNames is iterable and schema.dimensionNames is not string %}
+ | Dimensions | {{ schema.dimensionNames | join(", ") }} |
+ {% endif %}
| Version | {{ schema.version }} |
| Created | {{ schema.createdAt }} |
diff --git a/src/atdata_app/frontend/templates/schemas.html b/src/atdata_app/frontend/templates/schemas.html
index 268af39..fdf3565 100644
--- a/src/atdata_app/frontend/templates/schemas.html
+++ b/src/atdata_app/frontend/templates/schemas.html
@@ -7,7 +7,7 @@ Schemas
{% if schemas %}
- | Name | Version | Type | Description | Publisher |
+ | Name | Version | Type | Format | Description | Publisher |
{% for s in schemas %}
@@ -15,6 +15,7 @@ Schemas
{{ s.name }} |
{{ s.version }} |
{{ s.schemaType }} |
+ {{ s.get("arrayFormatLabel", "") }} |
{{ s.get("description", "") }} |
{{ s.did[:20] }}… |
diff --git a/src/atdata_app/ingestion/jetstream.py b/src/atdata_app/ingestion/jetstream.py
index 9b765d2..f09711b 100644
--- a/src/atdata_app/ingestion/jetstream.py
+++ b/src/atdata_app/ingestion/jetstream.py
@@ -52,7 +52,9 @@ async def jetstream_consumer(app: FastAPI) -> None:
if event.get("kind") != "commit":
continue
- await process_commit(pool, event)
+ await process_commit(
+ pool, event, getattr(app.state, "change_stream", None)
+ )
last_time_us = event.get("time_us")
msg_count += 1
diff --git a/src/atdata_app/ingestion/processor.py b/src/atdata_app/ingestion/processor.py
index fa34c96..aab3f2b 100644
--- a/src/atdata_app/ingestion/processor.py
+++ b/src/atdata_app/ingestion/processor.py
@@ -8,11 +8,16 @@
import asyncpg
from atdata_app import database as db
+from atdata_app.changestream import ChangeStream, make_change_event
logger = logging.getLogger(__name__)
-async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None:
+async def process_commit(
+ pool: asyncpg.Pool,
+ event: dict[str, Any],
+ change_stream: ChangeStream | None = None,
+) -> None:
"""Process a Jetstream commit event.
Expected event format::
@@ -44,18 +49,31 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None:
if operation == "delete":
await db.delete_record(pool, table, did, rkey)
logger.debug("Deleted %s %s/%s", collection, did, rkey)
+ if change_stream is not None:
+ change_stream.publish(
+ make_change_event(
+ event_type="delete",
+ collection=collection,
+ did=did,
+ rkey=rkey,
+ )
+ )
elif operation in ("create", "update"):
record = commit["record"]
cid = commit.get("cid")
try:
- if table == "schemas":
- await db.upsert_schema(pool, did, rkey, cid, record)
- elif table == "entries":
- await db.upsert_entry(pool, did, rkey, cid, record)
- elif table == "labels":
- await db.upsert_label(pool, did, rkey, cid, record)
- elif table == "lenses":
- await db.upsert_lens(pool, did, rkey, cid, record)
+ await db.UPSERT_FNS[table](pool, did, rkey, cid, record)
logger.debug("Upserted %s %s/%s", collection, did, rkey)
+ if change_stream is not None:
+ change_stream.publish(
+ make_change_event(
+ event_type=operation,
+ collection=collection,
+ did=did,
+ rkey=rkey,
+ record=record,
+ cid=cid,
+ )
+ )
except Exception:
logger.exception("Failed to upsert %s %s/%s", collection, did, rkey)
diff --git a/src/atdata_app/main.py b/src/atdata_app/main.py
index b50bba3..087fde6 100644
--- a/src/atdata_app/main.py
+++ b/src/atdata_app/main.py
@@ -10,6 +10,7 @@
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
+from atdata_app.changestream import ChangeStream
from atdata_app.config import AppConfig
from atdata_app.database import create_pool, run_migrations
from atdata_app.frontend import router as frontend_router
@@ -30,6 +31,9 @@ async def lifespan(app: FastAPI):
config: AppConfig = app.state.config
logger.info("Starting atdata-app (DID: %s)", config.service_did)
+ # Change stream (must be created before background tasks)
+ app.state.change_stream = ChangeStream()
+
# Database
pool = await create_pool(config.database_url)
app.state.db_pool = pool
diff --git a/src/atdata_app/mcp_server.py b/src/atdata_app/mcp_server.py
index 53991eb..a5d2e69 100644
--- a/src/atdata_app/mcp_server.py
+++ b/src/atdata_app/mcp_server.py
@@ -57,7 +57,10 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[ServerContext]:
"ATProto AppView for the science.alt.dataset namespace. "
"Use these tools to discover and query scientific datasets, "
"schemas, and lenses (bidirectional schema transforms) published "
- "on the AT Protocol network."
+ "on the AT Protocol network. "
+ "Schemas may specify an arrayFormat (numpyBytes, parquetBytes, "
+ "sparseBytes, structuredBytes, arrowTensor, safetensors) and "
+ "ndarray annotations (dtype, shape, dimensionNames)."
),
lifespan=server_lifespan,
)
@@ -127,7 +130,8 @@ async def get_schema(ctx: Ctx, uri: str) -> dict[str, Any]:
uri: AT-URI of the schema (e.g. at://did:plc:abc/science.alt.dataset.schema/my.schema@1.0.0).
Returns:
- Full schema record including name, version, type, schema body, and description.
+ Full schema record including name, version, type, schema body, description,
+ and (when present) arrayFormat, dtype, shape, and dimensionNames.
"""
sc = _get_ctx(ctx)
did, _collection, rkey = parse_at_uri(uri)
diff --git a/src/atdata_app/models.py b/src/atdata_app/models.py
index 1cde2f2..8793590 100644
--- a/src/atdata_app/models.py
+++ b/src/atdata_app/models.py
@@ -9,6 +9,32 @@
from pydantic import BaseModel
+# ---------------------------------------------------------------------------
+# Known array format tokens (atdata-lexicon#21)
+# ---------------------------------------------------------------------------
+
+KNOWN_ARRAY_FORMATS: set[str] = {
+ # Original formats
+ "numpyBytes",
+ "parquetBytes",
+ # New formats
+ "sparseBytes",
+ "structuredBytes",
+ "arrowTensor",
+ "safetensors",
+}
+
+#: Human-friendly display names for array format tokens.
+ARRAY_FORMAT_LABELS: dict[str, str] = {
+ "numpyBytes": "NumPy ndarray",
+ "parquetBytes": "Parquet",
+ "sparseBytes": "Sparse matrix (CSR/CSC/COO)",
+ "structuredBytes": "NumPy structured array",
+ "arrowTensor": "Arrow tensor IPC",
+ "safetensors": "Safetensors",
+}
+
+
# ---------------------------------------------------------------------------
# AT-URI parsing
# ---------------------------------------------------------------------------
@@ -119,6 +145,19 @@ def row_to_schema(row) -> dict[str, Any]:
}
if row["description"]:
d["description"] = row["description"]
+
+ # Surface array format and ndarray v1.1.0 annotation fields for display
+ array_format = schema_body.get("arrayFormat")
+ if array_format:
+ d["arrayFormat"] = array_format
+ d["arrayFormatLabel"] = ARRAY_FORMAT_LABELS.get(array_format, array_format)
+ if schema_body.get("dtype"):
+ d["dtype"] = schema_body["dtype"]
+ if schema_body.get("shape"):
+ d["shape"] = schema_body["shape"]
+ if schema_body.get("dimensionNames"):
+ d["dimensionNames"] = schema_body["dimensionNames"]
+
return d
@@ -127,6 +166,7 @@ def row_to_label(row) -> dict[str, Any]:
d: dict[str, Any] = {
"uri": uri,
"cid": row["cid"],
+ "did": row["did"],
"name": row["name"],
"datasetUri": row["dataset_uri"],
"createdAt": row["created_at"],
@@ -138,6 +178,21 @@ def row_to_label(row) -> dict[str, Any]:
return d
+def row_to_index_provider(row) -> dict[str, Any]:
+ uri = make_at_uri(row["did"], "science.alt.dataset.index", row["rkey"])
+ d: dict[str, Any] = {
+ "uri": uri,
+ "cid": row["cid"],
+ "did": row["did"],
+ "name": row["name"],
+ "endpointUrl": row["endpoint_url"],
+ "createdAt": row["created_at"],
+ }
+ if row["description"]:
+ d["description"] = row["description"]
+ return d
+
+
def row_to_lens(row) -> dict[str, Any]:
uri = make_at_uri(row["did"], "science.alt.dataset.lens", row["rkey"])
getter_code = row["getter_code"]
@@ -150,6 +205,7 @@ def row_to_lens(row) -> dict[str, Any]:
d: dict[str, Any] = {
"uri": uri,
"cid": row["cid"],
+ "did": row["did"],
"name": row["name"],
"sourceSchema": row["source_schema"],
"targetSchema": row["target_schema"],
@@ -236,4 +292,22 @@ class GetAnalyticsResponse(BaseModel):
class GetEntryStatsResponse(BaseModel):
views: int
searchAppearances: int
+ downloads: int = 0
+ citations: int = 0
+ derivatives: int = 0
period: str
+
+
+class ListIndexesResponse(BaseModel):
+ indexes: list[dict[str, Any]]
+ cursor: str | None = None
+
+
+class IndexSkeletonResponse(BaseModel):
+ items: list[dict[str, Any]]
+ cursor: str | None = None
+
+
+class IndexResponse(BaseModel):
+ items: list[dict[str, Any]]
+ cursor: str | None = None
diff --git a/src/atdata_app/sql/schema.sql b/src/atdata_app/sql/schema.sql
index 17f1729..422b379 100644
--- a/src/atdata_app/sql/schema.sql
+++ b/src/atdata_app/sql/schema.sql
@@ -144,6 +144,22 @@ CREATE INDEX IF NOT EXISTS idx_analytics_events_type_created
CREATE INDEX IF NOT EXISTS idx_analytics_events_target
ON analytics_events (target_did, target_rkey, event_type);
+-- Index providers (science.alt.dataset.index)
+CREATE TABLE IF NOT EXISTS index_providers (
+ did TEXT NOT NULL,
+ rkey TEXT NOT NULL,
+ cid TEXT,
+ name TEXT NOT NULL,
+ description TEXT,
+ endpoint_url TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ indexed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ PRIMARY KEY (did, rkey)
+);
+
+CREATE INDEX IF NOT EXISTS idx_index_providers_did ON index_providers (did);
+CREATE INDEX IF NOT EXISTS idx_index_providers_indexed_at ON index_providers (indexed_at DESC);
+
-- Pre-aggregated analytics counters (avoids expensive COUNT on events table)
CREATE TABLE IF NOT EXISTS analytics_counters (
target_did TEXT NOT NULL,
diff --git a/src/atdata_app/xrpc/procedures.py b/src/atdata_app/xrpc/procedures.py
index c066987..7a418d0 100644
--- a/src/atdata_app/xrpc/procedures.py
+++ b/src/atdata_app/xrpc/procedures.py
@@ -8,6 +8,7 @@
import logging
from typing import Any
+from urllib.parse import urlparse
import httpx
from fastapi import APIRouter, HTTPException, Request
@@ -15,6 +16,7 @@
from atdata_app import get_resolver
from atdata_app.auth import verify_service_auth
from atdata_app.database import (
+ fire_analytics_event,
query_get_entry,
query_get_schema,
query_record_exists,
@@ -268,3 +270,113 @@ async def publish_lens(request: Request) -> dict[str, Any]:
pds, pds_token, auth.iss, "science.alt.dataset.lens", record, rkey
)
return {"uri": result.get("uri"), "cid": result.get("cid")}
+
+
+# ---------------------------------------------------------------------------
+# publishIndex
+# ---------------------------------------------------------------------------
+
+
+@router.post("/science.alt.dataset.publishIndex")
+async def publish_index(request: Request) -> dict[str, Any]:
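+    """Validate and proxy-create a ``science.alt.dataset.index`` record on the caller's PDS."""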
+ auth = await verify_service_auth(request, "science.alt.dataset.publishIndex")
+ pds_token = _require_pds_token(request)
+
+ body = await request.json()
+ record = body.get("record", {})
+ rkey = body.get("rkey")
+
+ record_type = record.get("$type", "")
+ if record_type and record_type != "science.alt.dataset.index":
+ raise HTTPException(status_code=400, detail="Invalid $type for index")
+
+ for field in ("name", "endpointUrl", "createdAt"):
+ if field not in record:
+ raise HTTPException(status_code=400, detail=f"Missing required field: {field}")
+
+ # Validate endpoint URL is HTTPS with no credentials or fragments
+ parsed = urlparse(record["endpointUrl"])
+ if parsed.scheme != "https" or not parsed.hostname:
+ raise HTTPException(
+ status_code=400, detail="endpointUrl must be a valid HTTPS URL"
+ )
+ if parsed.username or parsed.password:
+ raise HTTPException(
+ status_code=400, detail="endpointUrl must not contain credentials"
+ )
+ if parsed.fragment:
+ raise HTTPException(
+ status_code=400, detail="endpointUrl must not contain a fragment"
+ )
+
+ record["$type"] = "science.alt.dataset.index"
+
+ pds = await _resolve_pds(auth.iss)
+ result = await _proxy_create_record(
+ pds, pds_token, auth.iss, "science.alt.dataset.index", record, rkey
+ )
+ return {"uri": result.get("uri"), "cid": result.get("cid")}
+
+
+# ---------------------------------------------------------------------------
+# sendInteractions
+# ---------------------------------------------------------------------------
+
+_VALID_INTERACTION_TYPES = frozenset({"download", "citation", "derivative"})
+_MAX_INTERACTIONS_BATCH = 100
+
+
+@router.post("/science.alt.dataset.sendInteractions")
+async def send_interactions(request: Request) -> dict[str, Any]:
+ """Record dataset interaction events (anonymous, fire-and-forget)."""
+ await verify_service_auth(request, "science.alt.dataset.sendInteractions")
+
+ pool = request.app.state.db_pool
+
+ body = await request.json()
+ interactions = body.get("interactions")
+ if not isinstance(interactions, list):
+ raise HTTPException(status_code=400, detail="interactions must be an array")
+
+ if len(interactions) > _MAX_INTERACTIONS_BATCH:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Batch size exceeds maximum of {_MAX_INTERACTIONS_BATCH}",
+ )
+
+ for i, item in enumerate(interactions):
+ if not isinstance(item, dict):
+ raise HTTPException(status_code=400, detail=f"interactions[{i}]: must be an object")
+
+ itype = item.get("type")
+ if itype not in _VALID_INTERACTION_TYPES:
+ raise HTTPException(
+ status_code=400,
+ detail=f"interactions[{i}]: invalid type '{itype}', "
+ f"must be one of: {', '.join(sorted(_VALID_INTERACTION_TYPES))}",
+ )
+
+ dataset_uri = item.get("datasetUri")
+ if not isinstance(dataset_uri, str):
+ raise HTTPException(
+ status_code=400, detail=f"interactions[{i}]: datasetUri is required"
+ )
+ try:
+ _did, _col, _rkey = parse_at_uri(dataset_uri)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"interactions[{i}]: invalid AT-URI: {dataset_uri}",
+ )
+ if _col != "science.alt.dataset.entry":
+ raise HTTPException(
+ status_code=400,
+ detail=f"interactions[{i}]: datasetUri must reference a dataset entry",
+ )
+
+ # All valid — fire analytics events
+ for item in interactions:
+ did, _collection, rkey = parse_at_uri(item["datasetUri"])
+ fire_analytics_event(pool, item["type"], target_did=did, target_rkey=rkey)
+
+ return {}
diff --git a/src/atdata_app/xrpc/queries.py b/src/atdata_app/xrpc/queries.py
index df5d15c..ff22e1d 100644
--- a/src/atdata_app/xrpc/queries.py
+++ b/src/atdata_app/xrpc/queries.py
@@ -7,6 +7,12 @@
from fastapi import APIRouter, HTTPException, Query, Request
+import ipaddress
+import socket
+from urllib.parse import urlparse
+
+import httpx
+
from atdata_app import get_resolver
from atdata_app.database import (
COLLECTION_TABLE_MAP,
@@ -16,8 +22,10 @@
query_entry_stats,
query_get_entries,
query_get_entry,
+ query_get_index_provider,
query_get_schema,
query_list_entries,
+ query_list_index_providers,
query_list_lenses,
query_list_schemas,
query_record_counts,
@@ -32,7 +40,10 @@
GetEntriesResponse,
GetEntryResponse,
GetEntryStatsResponse,
+ IndexResponse,
+ IndexSkeletonResponse,
ListEntriesResponse,
+ ListIndexesResponse,
ListLensesResponse,
ListSchemasResponse,
ResolveBlobsResponse,
@@ -44,6 +55,7 @@
parse_at_uri,
parse_cursor,
row_to_entry,
+ row_to_index_provider,
row_to_label,
row_to_lens,
row_to_schema,
@@ -340,6 +352,203 @@ async def search_lenses(
)
+# ---------------------------------------------------------------------------
+# listIndexes
+# ---------------------------------------------------------------------------
+
+
+@router.get("/science.alt.dataset.listIndexes")
+async def list_indexes(
+ request: Request,
+ repo: str | None = Query(None),
+ limit: int = Query(50, ge=1, le=100),
+ cursor: str | None = Query(None),
+) -> ListIndexesResponse:
+ pool = request.app.state.db_pool
+ c_at, c_did, c_rkey = parse_cursor(cursor)
+ rows = await query_list_index_providers(pool, repo, limit, c_did, c_rkey, c_at)
+ return ListIndexesResponse(
+ indexes=[row_to_index_provider(r) for r in rows],
+ cursor=maybe_cursor(rows, limit),
+ )
+
+
+# ---------------------------------------------------------------------------
+# getIndexSkeleton
+# ---------------------------------------------------------------------------
+
+
+def _validate_endpoint_url(url: str) -> None:
+ """Reject URLs that could cause SSRF (private IPs, non-HTTPS, credentials).
+
+    Raises ``HTTPException`` (400) if the URL is unsafe, or (502) if the hostname does not resolve.
+ """
+ try:
+ parsed = urlparse(url)
+ except ValueError:
+ raise HTTPException(status_code=400, detail="Malformed endpoint URL")
+
+ if parsed.scheme != "https":
+ raise HTTPException(status_code=400, detail="Endpoint URL must use HTTPS")
+ if not parsed.hostname:
+ raise HTTPException(status_code=400, detail="Endpoint URL missing hostname")
+ if parsed.username or parsed.password:
+ raise HTTPException(status_code=400, detail="Endpoint URL must not contain credentials")
+
+ # Resolve hostname and block private/reserved IP ranges
+ try:
+ infos = socket.getaddrinfo(parsed.hostname, None, proto=socket.IPPROTO_TCP)
+ except socket.gaierror:
+ raise HTTPException(status_code=502, detail="Index provider hostname unresolvable")
+
+ for _family, _type, _proto, _canonname, sockaddr in infos:
+ ip = ipaddress.ip_address(sockaddr[0])
+ if ip.is_private or ip.is_reserved or ip.is_loopback or ip.is_link_local:
+ raise HTTPException(
+ status_code=400,
+ detail="Endpoint URL must not resolve to a private/reserved IP address",
+ )
+
+
+_MAX_SKELETON_RESPONSE_BYTES = 1_048_576 # 1 MiB
+_MAX_SKELETON_CURSOR_LEN = 512
+
+
+async def _fetch_skeleton(
+ endpoint_url: str,
+ cursor: str | None,
+ limit: int,
+) -> dict[str, Any]:
+ """Fetch skeleton from an upstream index provider."""
+ _validate_endpoint_url(endpoint_url)
+
+ if cursor is not None:
+ if len(cursor) > _MAX_SKELETON_CURSOR_LEN or "\x00" in cursor:
+ raise HTTPException(status_code=400, detail="Invalid cursor value")
+
+ params: dict[str, Any] = {"limit": limit}
+ if cursor:
+ params["cursor"] = cursor
+ async with httpx.AsyncClient(timeout=10.0) as http:
+ try:
+ resp = await http.get(endpoint_url, params=params)
+ except (httpx.HTTPError, ValueError) as e:
+ raise HTTPException(
+ status_code=502, detail="Index provider unreachable"
+ ) from e
+ if resp.status_code != 200:
+ raise HTTPException(
+ status_code=502,
+ detail=f"Index provider returned {resp.status_code}",
+ )
+ # Reject oversized responses to prevent memory exhaustion
+ content_length = resp.headers.get("content-length")
+ if content_length and int(content_length) > _MAX_SKELETON_RESPONSE_BYTES:
+ raise HTTPException(
+ status_code=502, detail="Index provider response too large"
+ )
+ if len(resp.content) > _MAX_SKELETON_RESPONSE_BYTES:
+ raise HTTPException(
+ status_code=502, detail="Index provider response too large"
+ )
+ try:
+ data = resp.json()
+ except (ValueError, KeyError) as e:
+ raise HTTPException(
+ status_code=502, detail="Invalid response from index provider"
+ ) from e
+ if not isinstance(data.get("items"), list):
+ raise HTTPException(
+ status_code=502, detail="Index provider response missing 'items' array"
+ )
+ # Cap items to the requested limit regardless of what upstream returns
+ data["items"] = data["items"][:limit]
+    # Whitelist item fields: only pass through string 'uri' values to prevent injection
+    data["items"] = [
+        {"uri": item["uri"]}
+        for item in data["items"]
+        if isinstance(item, dict) and isinstance(item.get("uri"), str)
+    ]
+ # Validate upstream cursor
+ upstream_cursor = data.get("cursor")
+ if upstream_cursor is not None:
+ if (
+ not isinstance(upstream_cursor, str)
+ or len(upstream_cursor) > _MAX_SKELETON_CURSOR_LEN
+ or "\x00" in upstream_cursor
+ ):
+ data["cursor"] = None
+ return data
+
+
+@router.get("/science.alt.dataset.getIndexSkeleton")
+async def get_index_skeleton(
+ request: Request,
+ index: str = Query(...),
+ cursor: str | None = Query(None),
+ limit: int = Query(50, ge=1, le=100),
+) -> IndexSkeletonResponse:
+ pool = request.app.state.db_pool
+ try:
+ did, _, rkey = parse_at_uri(index)
+ except ValueError:
+ raise HTTPException(status_code=400, detail="Invalid AT-URI for index")
+ provider = await query_get_index_provider(pool, did, rkey)
+ if not provider:
+ raise HTTPException(status_code=404, detail="Index provider not found")
+
+ data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit)
+ return IndexSkeletonResponse(
+ items=data["items"],
+ cursor=data.get("cursor"),
+ )
+
+
+# ---------------------------------------------------------------------------
+# getIndex (hydrated)
+# ---------------------------------------------------------------------------
+
+
+@router.get("/science.alt.dataset.getIndex")
+async def get_index(
+ request: Request,
+ index: str = Query(...),
+ cursor: str | None = Query(None),
+ limit: int = Query(50, ge=1, le=100),
+) -> IndexResponse:
+ pool = request.app.state.db_pool
+ try:
+ did, _, rkey = parse_at_uri(index)
+ except ValueError:
+ raise HTTPException(status_code=400, detail="Invalid AT-URI for index")
+ provider = await query_get_index_provider(pool, did, rkey)
+ if not provider:
+ raise HTTPException(status_code=404, detail="Index provider not found")
+
+ data = await _fetch_skeleton(provider["endpoint_url"], cursor, limit)
+
+ # Parse URIs from skeleton items and hydrate
+ keys: list[tuple[str, str]] = []
+ for item in data["items"]:
+ uri = item.get("uri", "")
+ try:
+ entry_did, _, entry_rkey = parse_at_uri(uri)
+ keys.append((entry_did, entry_rkey))
+ except ValueError:
+ continue # skip malformed URIs
+
+ rows = await query_get_entries(pool, keys)
+
+ # Build a lookup map to preserve skeleton ordering
+ row_map = {(r["did"], r["rkey"]): r for r in rows}
+ hydrated = []
+ for entry_did, entry_rkey in keys:
+ row = row_map.get((entry_did, entry_rkey))
+ if row:
+ hydrated.append(row_to_entry(row))
+
+ return IndexResponse(
+ items=hydrated,
+ cursor=data.get("cursor"),
+ )
+
+
# ---------------------------------------------------------------------------
# describeService
# ---------------------------------------------------------------------------
diff --git a/src/atdata_app/xrpc/router.py b/src/atdata_app/xrpc/router.py
index 627b382..d929478 100644
--- a/src/atdata_app/xrpc/router.py
+++ b/src/atdata_app/xrpc/router.py
@@ -1,10 +1,12 @@
-"""Combined XRPC router mounting all query and procedure endpoints."""
+"""Combined XRPC router mounting all query, procedure, and subscription endpoints."""
from fastapi import APIRouter
from atdata_app.xrpc.procedures import router as procedures_router
from atdata_app.xrpc.queries import router as queries_router
+from atdata_app.xrpc.subscriptions import router as subscriptions_router
router = APIRouter(prefix="/xrpc")
router.include_router(queries_router)
router.include_router(procedures_router)
+router.include_router(subscriptions_router)
diff --git a/src/atdata_app/xrpc/subscriptions.py b/src/atdata_app/xrpc/subscriptions.py
new file mode 100644
index 0000000..80bfff4
--- /dev/null
+++ b/src/atdata_app/xrpc/subscriptions.py
@@ -0,0 +1,77 @@
+"""WebSocket subscription endpoints for real-time change streaming."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect
+
+from atdata_app.changestream import ChangeStream
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+@router.websocket("/science.alt.dataset.subscribeChanges")
+async def subscribe_changes(websocket: WebSocket) -> None:
+ """Stream real-time change events over WebSocket.
+
+ Query parameters:
+ cursor: Optional sequence number to replay from.
+ """
+ change_stream: ChangeStream = websocket.app.state.change_stream
+
+ await websocket.accept()
+
+ cursor_param = websocket.query_params.get("cursor")
+
+ try:
+ sub_id, queue = change_stream.subscribe()
+ except RuntimeError:
+ await websocket.close(code=1013, reason="Too many subscribers")
+ return
+
+ try:
+ # Replay buffered events if cursor provided, tracking last seq
+ # to deduplicate against events that also landed in the live queue.
+ last_replayed_seq = 0
+ if cursor_param is not None:
+ try:
+ cursor = int(cursor_param)
+ except (ValueError, TypeError):
+ await websocket.close(code=1008, reason="Invalid cursor value")
+ return
+ missed = change_stream.replay_from(cursor)
+ for event in missed:
+ await websocket.send_text(json.dumps(event.to_dict()))
+ last_replayed_seq = event.seq
+
+ # Stream live events with periodic keepalive on idle
+ while True:
+ try:
+ event = await asyncio.wait_for(queue.get(), timeout=30.0)
+ except asyncio.TimeoutError:
+ # No events for 30s — send keepalive to detect dead connections
+ await websocket.send_text(json.dumps({"type": "keepalive"}))
+ continue
+
+ # Deduplicate events already sent during replay
+ if event.seq <= last_replayed_seq:
+ continue
+ await websocket.send_text(json.dumps(event.to_dict()))
+ # Check if we were marked as dropped due to backpressure
+ if change_stream.is_dropped(sub_id):
+ await websocket.close(
+ code=4000, reason="Backpressure: events were dropped"
+ )
+ return
+
+ except WebSocketDisconnect:
+ logger.debug("Subscriber %d disconnected", sub_id)
+ except Exception:
+ logger.exception("Error in subscriber %d", sub_id)
+ finally:
+ change_stream.unsubscribe(sub_id)
diff --git a/tests/test_analytics.py b/tests/test_analytics.py
index 6be9707..1ed9d04 100644
--- a/tests/test_analytics.py
+++ b/tests/test_analytics.py
@@ -8,7 +8,6 @@
import pytest
from httpx import ASGITransport, AsyncClient
-from atdata_app.config import AppConfig
from atdata_app.database import (
fire_analytics_event,
record_analytics_event,
@@ -28,17 +27,12 @@
# ---------------------------------------------------------------------------
-@pytest.fixture
-def config() -> AppConfig:
- return AppConfig(dev_mode=True, hostname="localhost", port=8000)
-
-
@pytest.fixture
def pool() -> AsyncMock:
return AsyncMock()
-def _mock_app(config: AppConfig, pool: AsyncMock):
+def _mock_app(config, pool):
"""Create a FastAPI app with mocked lifespan (no real DB)."""
app = create_app(config)
app.state.db_pool = pool
@@ -374,3 +368,303 @@ def test_describe_service_response_without_analytics():
recordCount={"science.alt.dataset.entry": 10},
)
assert resp.analytics is None
+
+
+def test_get_entry_stats_response_with_interactions():
+ resp = GetEntryStatsResponse(
+ views=10, searchAppearances=3, downloads=5, citations=2, derivatives=1, period="week"
+ )
+ assert resp.downloads == 5
+ assert resp.citations == 2
+ assert resp.derivatives == 1
+
+
+def test_get_entry_stats_response_defaults_interactions():
+ resp = GetEntryStatsResponse(views=10, searchAppearances=3, period="week")
+ assert resp.downloads == 0
+ assert resp.citations == 0
+ assert resp.derivatives == 0
+
+
+# ---------------------------------------------------------------------------
+# sendInteractions endpoint
+# ---------------------------------------------------------------------------
+
+_PROC = "atdata_app.xrpc.procedures"
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_valid_batch(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {
+ "type": "download",
+ "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz",
+ },
+ {
+ "type": "citation",
+ "datasetUri": "at://did:plc:def/science.alt.dataset.entry/4abc",
+ "timestamp": "2025-06-01T12:00:00Z",
+ },
+ ]
+ },
+ )
+
+ assert resp.status_code == 200
+ assert resp.json() == {}
+ assert mock_fire.call_count == 2
+ mock_fire.assert_any_call(pool, "download", target_did="did:plc:abc", target_rkey="3xyz")
+ mock_fire.assert_any_call(pool, "citation", target_did="did:plc:def", target_rkey="4abc")
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_empty_array(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"interactions": []},
+ )
+
+ assert resp.status_code == 200
+ assert resp.json() == {}
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_invalid_uri(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {"type": "download", "datasetUri": "https://not-an-at-uri"},
+ ]
+ },
+ )
+
+ assert resp.status_code == 400
+ assert "invalid AT-URI" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_invalid_type(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {
+ "type": "bookmark",
+ "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz",
+ },
+ ]
+ },
+ )
+
+ assert resp.status_code == 400
+ assert "invalid type" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_batch_size_exceeded(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ interactions = [
+ {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"}
+ for _ in range(101)
+ ]
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"interactions": interactions},
+ )
+
+ assert resp.status_code == 400
+ assert "Batch size exceeds maximum" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_ignores_timestamp(mock_fire, mock_auth, config, pool):
+ """Timestamp field is accepted but not validated (informational only)."""
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {
+ "type": "download",
+ "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz",
+ "timestamp": "not-a-date",
+ },
+ ]
+ },
+ )
+
+ assert resp.status_code == 200
+ assert mock_fire.call_count == 1
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_missing_dataset_uri(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {"type": "download"},
+ ]
+ },
+ )
+
+ assert resp.status_code == 400
+ assert "datasetUri is required" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_not_an_array(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"interactions": "not-an-array"},
+ )
+
+ assert resp.status_code == 400
+ assert "interactions must be an array" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_all_three_types(mock_fire, mock_auth, config, pool):
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={
+ "interactions": [
+ {
+ "type": "download",
+ "datasetUri": "at://did:plc:a/science.alt.dataset.entry/1",
+ },
+ {
+ "type": "citation",
+ "datasetUri": "at://did:plc:b/science.alt.dataset.entry/2",
+ },
+ {
+ "type": "derivative",
+ "datasetUri": "at://did:plc:c/science.alt.dataset.entry/3",
+ },
+ ]
+ },
+ )
+
+ assert resp.status_code == 200
+ assert mock_fire.call_count == 3
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_missing_key(mock_fire, mock_auth, config, pool):
+ """Body without 'interactions' key should return 400."""
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"data": []},
+ )
+
+ assert resp.status_code == 400
+ assert "interactions must be an array" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_non_dict_item(mock_fire, mock_auth, config, pool):
+ """Non-object items in the interactions array should return 400."""
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"interactions": ["not-a-dict"]},
+ )
+
+ assert resp.status_code == 400
+ assert "must be an object" in resp.json()["detail"]
+ mock_fire.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch(f"{_PROC}.verify_service_auth", new_callable=AsyncMock)
+@patch(f"{_PROC}.fire_analytics_event")
+async def test_send_interactions_boundary_at_max(mock_fire, mock_auth, config, pool):
+ """Exactly 100 interactions (the maximum) should succeed."""
+ mock_auth.return_value = MagicMock(iss="did:plc:caller")
+ app = _mock_app(config, pool)
+ transport = ASGITransport(app=app)
+ interactions = [
+ {"type": "download", "datasetUri": "at://did:plc:abc/science.alt.dataset.entry/3xyz"}
+ for _ in range(100)
+ ]
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.sendInteractions",
+ json={"interactions": interactions},
+ )
+
+ assert resp.status_code == 200
+ assert mock_fire.call_count == 100
diff --git a/tests/test_changestream.py b/tests/test_changestream.py
new file mode 100644
index 0000000..48e4318
--- /dev/null
+++ b/tests/test_changestream.py
@@ -0,0 +1,442 @@
+"""Tests for the change stream event bus and WebSocket subscription endpoint."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from fastapi import FastAPI
+from starlette.testclient import TestClient
+
+from atdata_app.changestream import ChangeEvent, ChangeStream, make_change_event
+from atdata_app.ingestion.processor import process_commit
+from atdata_app.xrpc.subscriptions import router as subscriptions_router
+
+
+# ---------------------------------------------------------------------------
+# ChangeStream unit tests
+# ---------------------------------------------------------------------------
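+# The tests below assume roughly this in-memory bus surface (a sketch inferred
+# from the assertions, not the implementation itself):
+#
+#     cs = ChangeStream(buffer_size=..., subscriber_queue_size=...)
+#     sub_id, queue = cs.subscribe()       # per-subscriber FIFO queue
+#     cs.publish(make_change_event(...))   # assigns the next monotonic seq (1, 2, ...)
+#     cs.replay_from(cursor)               # buffered events with seq > cursor
+#     cs.unsubscribe(sub_id)               # stops delivery to that queue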
+
+
+class TestChangeStream:
+ def test_publish_assigns_monotonic_seq(self):
+ cs = ChangeStream()
+ ev1 = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ )
+ ev2 = make_change_event(
+ event_type="update",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="def",
+ )
+ cs.publish(ev1)
+ cs.publish(ev2)
+ assert ev1.seq == 1
+ assert ev2.seq == 2
+ assert cs.current_seq == 2
+
+ def test_publish_delivers_to_subscribers(self):
+ cs = ChangeStream()
+ sub_id, queue = cs.subscribe()
+
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ )
+ cs.publish(ev)
+
+ assert not queue.empty()
+ received = queue.get_nowait()
+ assert received.seq == 1
+ assert received.did == "did:plc:test"
+
+ def test_multiple_subscribers_receive_events(self):
+ cs = ChangeStream()
+ _, q1 = cs.subscribe()
+ _, q2 = cs.subscribe()
+
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ )
+ cs.publish(ev)
+
+ assert not q1.empty()
+ assert not q2.empty()
+ assert q1.get_nowait().seq == 1
+ assert q2.get_nowait().seq == 1
+
+ def test_unsubscribe_removes_subscriber(self):
+ cs = ChangeStream()
+ sub_id, queue = cs.subscribe()
+ assert cs.subscriber_count == 1
+
+ cs.unsubscribe(sub_id)
+ assert cs.subscriber_count == 0
+
+ # Publishing after unsubscribe should not deliver
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ )
+ cs.publish(ev)
+ assert queue.empty()
+
+ def test_full_queue_drops_event(self):
+ cs = ChangeStream(subscriber_queue_size=1)
+ _, queue = cs.subscribe()
+
+ ev1 = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="a",
+ )
+ ev2 = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="b",
+ )
+ cs.publish(ev1)
+ cs.publish(ev2) # Should be dropped (queue full)
+
+ assert queue.qsize() == 1
+ assert queue.get_nowait().rkey == "a"
+
+ def test_replay_from_cursor(self):
+ cs = ChangeStream(buffer_size=10)
+ for i in range(5):
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey=str(i),
+ )
+ cs.publish(ev)
+
+ # Replay from seq 3 — should get events 4 and 5
+ replayed = cs.replay_from(3)
+ assert len(replayed) == 2
+ assert replayed[0].seq == 4
+ assert replayed[1].seq == 5
+
+ def test_replay_from_zero_returns_all(self):
+ cs = ChangeStream(buffer_size=10)
+ for i in range(3):
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey=str(i),
+ )
+ cs.publish(ev)
+
+ replayed = cs.replay_from(0)
+ assert len(replayed) == 3
+
+ def test_replay_cursor_too_old(self):
+ cs = ChangeStream(buffer_size=3)
+ for i in range(5):
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey=str(i),
+ )
+ cs.publish(ev)
+
+ # Buffer only holds seq 3, 4, 5 — cursor 1 is too old
+ replayed = cs.replay_from(1)
+ assert len(replayed) == 0
+
+ def test_replay_empty_buffer(self):
+ cs = ChangeStream()
+ assert cs.replay_from(0) == []
+
+ def test_bounded_buffer(self):
+ cs = ChangeStream(buffer_size=3)
+ for i in range(10):
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey=str(i),
+ )
+ cs.publish(ev)
+
+ assert len(cs._buffer) == 3
+ assert cs._buffer[0].seq == 8
+ assert cs._buffer[-1].seq == 10
+
+
+class TestChangeEvent:
+ def test_to_dict_create(self):
+ ev = ChangeEvent(
+ seq=1,
+ type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ timestamp="2026-01-01T00:00:00Z",
+ record={"name": "test"},
+ cid="bafytest",
+ )
+ d = ev.to_dict()
+ assert d["seq"] == 1
+ assert d["type"] == "create"
+ assert d["record"] == {"name": "test"}
+ assert d["cid"] == "bafytest"
+
+ def test_to_dict_delete_omits_record_and_cid(self):
+ ev = ChangeEvent(
+ seq=2,
+ type="delete",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ timestamp="2026-01-01T00:00:00Z",
+ )
+ d = ev.to_dict()
+ assert "record" not in d
+ assert "cid" not in d
+
+
+class TestMakeChangeEvent:
+ def test_creates_event_with_timestamp(self):
+ ev = make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ record={"name": "test"},
+ cid="bafytest",
+ )
+ assert ev.seq == 0 # Not yet assigned
+ assert ev.type == "create"
+ assert ev.timestamp # Should have a timestamp
+
+
+# ---------------------------------------------------------------------------
+# Processor integration tests
+# ---------------------------------------------------------------------------
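+# _make_event builds a Jetstream-style commit envelope (did / time_us / kind /
+# commit) so process_commit can be driven without a live firehose connection;
+# each test patches the relevant UPSERT_FNS entry instead of a real database.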
+
+_DB = "atdata_app.database"
+
+
+def _make_event(
+ did: str = "did:plc:test123",
+ collection: str = "science.alt.dataset.entry",
+ operation: str = "create",
+ rkey: str = "3xyz",
+ record: dict | None = None,
+ cid: str = "bafytest",
+) -> dict:
+ commit: dict = {
+ "rev": "rev1",
+ "operation": operation,
+ "collection": collection,
+ "rkey": rkey,
+ }
+ if operation != "delete":
+ commit["record"] = record or {
+ "$type": collection,
+ "name": "test-dataset",
+ "schemaRef": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0",
+ "storage": {"$type": "science.alt.dataset.storageHttp", "shards": []},
+ "createdAt": "2025-01-01T00:00:00Z",
+ }
+ commit["cid"] = cid
+
+ return {
+ "did": did,
+ "time_us": 1725911162329308,
+ "kind": "commit",
+ "commit": commit,
+ }
+
+
+@pytest.mark.asyncio
+async def test_processor_publishes_create_event():
+ mock_upsert = AsyncMock()
+ pool = AsyncMock()
+ cs = ChangeStream()
+ _, queue = cs.subscribe()
+
+ event = _make_event(operation="create")
+ with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}):
+ await process_commit(pool, event, change_stream=cs)
+
+ assert not queue.empty()
+ change_event = queue.get_nowait()
+ assert change_event.type == "create"
+ assert change_event.collection == "science.alt.dataset.entry"
+ assert change_event.did == "did:plc:test123"
+ assert change_event.rkey == "3xyz"
+ assert change_event.record is not None
+ assert change_event.cid == "bafytest"
+
+
+@pytest.mark.asyncio
+@patch(f"{_DB}.delete_record", new_callable=AsyncMock)
+async def test_processor_publishes_delete_event(mock_delete):
+ pool = AsyncMock()
+ cs = ChangeStream()
+ _, queue = cs.subscribe()
+
+ event = _make_event(operation="delete")
+ await process_commit(pool, event, change_stream=cs)
+
+ assert not queue.empty()
+ change_event = queue.get_nowait()
+ assert change_event.type == "delete"
+ assert change_event.collection == "science.alt.dataset.entry"
+ assert change_event.record is None
+ assert change_event.cid is None
+
+
+@pytest.mark.asyncio
+async def test_processor_no_event_on_upsert_failure():
+ mock_upsert = AsyncMock(side_effect=Exception("db error"))
+ pool = AsyncMock()
+ cs = ChangeStream()
+ _, queue = cs.subscribe()
+
+ event = _make_event(operation="create")
+ with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}):
+ await process_commit(pool, event, change_stream=cs)
+
+ assert queue.empty()
+
+
+@pytest.mark.asyncio
+async def test_processor_works_without_change_stream():
+ """Backward compat: process_commit works when change_stream is None."""
+ mock_upsert = AsyncMock()
+ pool = AsyncMock()
+ event = _make_event(operation="create")
+ with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock_upsert}):
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# WebSocket endpoint tests
+# ---------------------------------------------------------------------------
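+# Expected handler behavior, as exercised below (a behavioral sketch, not the
+# handler code): replay buffered events when a ?cursor= query param is given,
+# then stream live events from the subscriber queue as JSON frames, and
+# unsubscribe when the client disconnects.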
+
+
+@pytest.fixture
+def app_with_changestream():
+ """Minimal app with just the subscriptions router — no DB lifespan."""
+ app = FastAPI()
+ app.state.change_stream = ChangeStream()
+ app.include_router(subscriptions_router, prefix="/xrpc")
+ return app
+
+
+def test_websocket_subscribe_and_receive(app_with_changestream):
+ app = app_with_changestream
+ cs: ChangeStream = app.state.change_stream
+
+ with TestClient(app) as client:
+ with client.websocket_connect(
+ "/xrpc/science.alt.dataset.subscribeChanges"
+ ) as ws:
+ # Publish an event from the test thread while the WebSocket is open
+ cs.publish(
+ make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ record={"name": "test"},
+ cid="bafytest",
+ )
+ )
+
+ data = ws.receive_json()
+ assert data["seq"] == 1
+ assert data["type"] == "create"
+ assert data["collection"] == "science.alt.dataset.entry"
+ assert data["did"] == "did:plc:test"
+ assert data["record"] == {"name": "test"}
+
+
+def test_websocket_cursor_replay(app_with_changestream):
+ app = app_with_changestream
+ cs: ChangeStream = app.state.change_stream
+
+ # Pre-populate buffer
+ for i in range(5):
+ cs.publish(
+ make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey=str(i),
+ )
+ )
+
+ with TestClient(app) as client:
+ with client.websocket_connect(
+ "/xrpc/science.alt.dataset.subscribeChanges?cursor=3"
+ ) as ws:
+ # Should replay events 4 and 5
+ msg1 = ws.receive_json()
+ assert msg1["seq"] == 4
+
+ msg2 = ws.receive_json()
+ assert msg2["seq"] == 5
+
+
+def test_websocket_disconnect_cleanup(app_with_changestream):
+ app = app_with_changestream
+ cs: ChangeStream = app.state.change_stream
+
+ with TestClient(app) as client:
+ with client.websocket_connect(
+ "/xrpc/science.alt.dataset.subscribeChanges"
+ ):
+ assert cs.subscriber_count == 1
+
+ # After disconnect, subscriber should be cleaned up
+ assert cs.subscriber_count == 0
+
+
+def test_websocket_multiple_subscribers(app_with_changestream):
+ app = app_with_changestream
+ cs: ChangeStream = app.state.change_stream
+
+ with TestClient(app) as client:
+ with client.websocket_connect(
+ "/xrpc/science.alt.dataset.subscribeChanges"
+ ) as ws1:
+ with client.websocket_connect(
+ "/xrpc/science.alt.dataset.subscribeChanges"
+ ) as ws2:
+ assert cs.subscriber_count == 2
+
+ cs.publish(
+ make_change_event(
+ event_type="create",
+ collection="science.alt.dataset.entry",
+ did="did:plc:test",
+ rkey="abc",
+ )
+ )
+
+ d1 = ws1.receive_json()
+ d2 = ws2.receive_json()
+ assert d1["seq"] == d2["seq"] == 1
+
+ assert cs.subscriber_count == 0
diff --git a/tests/test_frontend.py b/tests/test_frontend.py
index 43f714f..7e3d71e 100644
--- a/tests/test_frontend.py
+++ b/tests/test_frontend.py
@@ -41,6 +41,7 @@ def _make_schema_row(
did: str = "did:plc:test123",
rkey: str = "test@1.0.0",
name: str = "TestSchema",
+ schema_body: str | dict = '{"type": "object"}',
) -> dict:
return {
"did": did,
@@ -49,7 +50,7 @@ def _make_schema_row(
"name": name,
"version": "1.0.0",
"schema_type": "jsonSchema",
- "schema_body": '{"type": "object"}',
+ "schema_body": schema_body,
"description": "A test schema",
"metadata": None,
"created_at": "2025-01-01T00:00:00Z",
@@ -140,10 +141,12 @@ async def test_home_search(mock_search):
@pytest.mark.asyncio
@patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock)
+@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock)
@patch("atdata_app.frontend.routes.query_get_entry", new_callable=AsyncMock)
-async def test_dataset_detail(mock_get, mock_labels):
+async def test_dataset_detail(mock_get, mock_schema, mock_labels):
pool, _conn = _mock_pool()
mock_get.return_value = _make_entry_row()
+ mock_schema.return_value = _make_schema_row()
mock_labels.return_value = [_make_label_row()]
app = _make_app(pool)
transport = ASGITransport(app=app)
@@ -260,6 +263,96 @@ async def test_about(mock_counts):
assert "did:web:localhost%3A8000" in resp.text
+# ---------------------------------------------------------------------------
+# Schema detail — array format & ndarray annotations
+# ---------------------------------------------------------------------------
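+# These tests assume the templates render a human-readable label for known
+# arrayFormat tokens (e.g. "Sparse matrix" for sparseBytes) and surface the
+# ndarray annotations (dtype, shape, dimensionNames) on the schema detail page.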
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock)
+async def test_schema_detail_array_format(mock_get):
+ pool, _conn = _mock_pool()
+ mock_get.return_value = _make_schema_row(
+ schema_body={
+ "arrayFormat": "sparseBytes",
+ "dtype": "float32",
+ "shape": [100, 200],
+ "dimensionNames": ["samples", "features"],
+ },
+ )
+ app = _make_app(pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/schema/did:plc:test123/test@1.0.0")
+ assert resp.status_code == 200
+ assert "Sparse matrix" in resp.text
+ assert "float32" in resp.text
+ assert "100" in resp.text
+ assert "samples" in resp.text
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock)
+async def test_schema_detail_no_array_format(mock_get):
+ """Plain schemas should not show array format rows."""
+ pool, _conn = _mock_pool()
+ mock_get.return_value = _make_schema_row()
+ app = _make_app(pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/schema/did:plc:test123/test@1.0.0")
+ assert resp.status_code == 200
+ assert "Array Format" not in resp.text
+ assert "Data Type" not in resp.text
+
+
+# ---------------------------------------------------------------------------
+# Dataset detail — schema format info
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.frontend.routes.query_labels_for_dataset", new_callable=AsyncMock)
+@patch("atdata_app.frontend.routes.query_get_schema", new_callable=AsyncMock)
+@patch("atdata_app.frontend.routes.query_get_entry", new_callable=AsyncMock)
+async def test_dataset_detail_with_schema_format(mock_entry, mock_schema, mock_labels):
+ pool, _conn = _mock_pool()
+ mock_entry.return_value = _make_entry_row()
+ mock_schema.return_value = _make_schema_row(
+ did="did:plc:test",
+ rkey="test@1.0.0",
+ schema_body={"arrayFormat": "numpyBytes", "dtype": "float64"},
+ )
+ mock_labels.return_value = []
+ app = _make_app(pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/dataset/did:plc:test123/3xyz")
+ assert resp.status_code == 200
+ assert "NumPy ndarray" in resp.text
+ assert "float64" in resp.text
+
+
+# ---------------------------------------------------------------------------
+# Schemas list — format column
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.frontend.routes.query_list_schemas", new_callable=AsyncMock)
+async def test_schemas_list_shows_format(mock_list):
+ pool, _conn = _mock_pool()
+ mock_list.return_value = [
+ _make_schema_row(schema_body={"arrayFormat": "safetensors"}),
+ ]
+ app = _make_app(pool)
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/schemas")
+ assert resp.status_code == 200
+ assert "Safetensors" in resp.text
+
+
# ---------------------------------------------------------------------------
# Static files
# ---------------------------------------------------------------------------
diff --git a/tests/test_index.py b/tests/test_index.py
new file mode 100644
index 0000000..15c84a8
--- /dev/null
+++ b/tests/test_index.py
@@ -0,0 +1,541 @@
+"""Tests for index provider endpoints (skeleton/hydration pattern)."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+from atdata_app.config import AppConfig
+from atdata_app.ingestion.processor import process_commit
+from atdata_app.main import create_app
+
+_DB = "atdata_app.database"
+_QUERIES = "atdata_app.xrpc.queries"
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+def _make_app() -> tuple:
+ config = AppConfig(dev_mode=True, hostname="localhost", port=8000)
+ app = create_app(config)
+ pool = AsyncMock()
+ app.state.db_pool = pool
+ return app, pool
+
+
+def _index_provider_row(
+ did: str = "did:plc:provider1",
+ rkey: str = "3abc",
+ endpoint_url: str = "https://example.com/skeleton",
+ name: str = "Genomics Index",
+ description: str = "Curated genomics datasets",
+) -> dict:
+ """Simulate an asyncpg Record as a dict-like object."""
+ return MagicMock(
+ **{
+ "__getitem__": lambda self, key: {
+ "did": did,
+ "rkey": rkey,
+ "cid": "bafyindex",
+ "name": name,
+ "description": description,
+ "endpoint_url": endpoint_url,
+ "created_at": "2025-01-01T00:00:00Z",
+ "indexed_at": "2025-01-01T00:00:00+00:00",
+ }[key],
+ }
+ )
+
+
+def _entry_row(did: str = "did:plc:author1", rkey: str = "3xyz") -> MagicMock:
+ return MagicMock(
+ **{
+ "__getitem__": lambda self, key: {
+ "did": did,
+ "rkey": rkey,
+ "cid": "bafyentry",
+ "name": "test-dataset",
+ "schema_ref": "at://did:plc:test/science.alt.dataset.schema/test@1.0.0",
+ "storage": '{"$type": "science.alt.dataset.storageHttp", "shards": []}',
+ "description": None,
+ "tags": None,
+ "license": None,
+ "size_samples": None,
+ "size_bytes": None,
+ "size_shards": None,
+ "created_at": "2025-01-01T00:00:00Z",
+ "indexed_at": "2025-01-01T00:00:00+00:00",
+ }[key],
+ }
+ )
+
+
+# ---------------------------------------------------------------------------
+# getIndexSkeleton
+# ---------------------------------------------------------------------------
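+# As exercised here, getIndexSkeleton looks up the provider record, GETs its
+# endpoint_url, and relays {"items": [...], "cursor": ...}; upstream errors,
+# unreachable hosts, and malformed bodies all surface as 502. httpx.AsyncClient
+# is mocked throughout, so no network traffic occurs.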
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_skeleton_success(mock_get_provider):
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+
+ skeleton_response = {
+ "items": [{"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"}],
+ "cursor": "next123",
+ }
+
+ mock_resp = MagicMock()
+ mock_resp.status_code = 200
+ mock_resp.json.return_value = skeleton_response
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_resp
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert len(data["items"]) == 1
+ assert data["items"][0]["uri"] == "at://did:plc:a/science.alt.dataset.entry/3xyz"
+ assert data["cursor"] == "next123"
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_skeleton_not_found(mock_get_provider):
+ app, pool = _make_app()
+ mock_get_provider.return_value = None
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 404
+
+
+@pytest.mark.asyncio
+async def test_get_index_skeleton_invalid_uri():
+ app, pool = _make_app()
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "not-a-uri"},
+ )
+ assert resp.status_code == 400
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_skeleton_upstream_error(mock_get_provider):
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+
+ mock_resp = MagicMock()
+ mock_resp.status_code = 500
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_resp
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 502
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_skeleton_upstream_unreachable(mock_get_provider):
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.side_effect = httpx.ConnectError("Connection refused")
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 502
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_skeleton_invalid_response(mock_get_provider):
+ """Upstream returns JSON without 'items' array."""
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+
+ mock_resp = MagicMock()
+ mock_resp.status_code = 200
+ mock_resp.json.return_value = {"bad": "data"}
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_resp
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndexSkeleton",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 502
+
+
+# ---------------------------------------------------------------------------
+# getIndex (hydrated)
+# ---------------------------------------------------------------------------
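+# getIndex combines the skeleton fetch with hydration: URIs returned by the
+# provider are resolved against the local entries table, and URIs we have not
+# indexed are silently omitted from the response.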
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock)
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_hydrated(mock_get_provider, mock_get_entries):
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+ mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")]
+
+ skeleton_response = {
+ "items": [
+ {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"},
+ ],
+ "cursor": "next456",
+ }
+
+ mock_resp = MagicMock()
+ mock_resp.status_code = 200
+ mock_resp.json.return_value = skeleton_response
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_resp
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndex",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert len(data["items"]) == 1
+ assert data["items"][0]["name"] == "test-dataset"
+ assert data["cursor"] == "next456"
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_entries", new_callable=AsyncMock)
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_omits_missing_entries(mock_get_provider, mock_get_entries):
+ """Entries not in the DB should be silently omitted."""
+ app, pool = _make_app()
+ mock_get_provider.return_value = _index_provider_row()
+ # Return only one of two requested entries
+ mock_get_entries.return_value = [_entry_row("did:plc:a", "3xyz")]
+
+ skeleton_response = {
+ "items": [
+ {"uri": "at://did:plc:a/science.alt.dataset.entry/3xyz"},
+ {"uri": "at://did:plc:b/science.alt.dataset.entry/3deleted"},
+ ],
+ }
+
+ mock_resp = MagicMock()
+ mock_resp.status_code = 200
+ mock_resp.json.return_value = skeleton_response
+
+ with patch("atdata_app.xrpc.queries.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_resp
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client_cls.return_value = mock_client
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndex",
+ params={"index": "at://did:plc:provider1/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert len(data["items"]) == 1
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_get_index_provider", new_callable=AsyncMock)
+async def test_get_index_not_found(mock_get_provider):
+ app, pool = _make_app()
+ mock_get_provider.return_value = None
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.getIndex",
+ params={"index": "at://did:plc:missing/science.alt.dataset.index/3abc"},
+ )
+ assert resp.status_code == 404
+
+
+# ---------------------------------------------------------------------------
+# listIndexes
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock)
+async def test_list_indexes(mock_list):
+ app, pool = _make_app()
+ mock_list.return_value = [_index_provider_row()]
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/xrpc/science.alt.dataset.listIndexes")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert len(data["indexes"]) == 1
+ assert data["indexes"][0]["name"] == "Genomics Index"
+ assert data["indexes"][0]["endpointUrl"] == "https://example.com/skeleton"
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock)
+async def test_list_indexes_empty(mock_list):
+ app, pool = _make_app()
+ mock_list.return_value = []
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get("/xrpc/science.alt.dataset.listIndexes")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["indexes"] == []
+ assert data["cursor"] is None
+
+
+@pytest.mark.asyncio
+@patch(f"{_QUERIES}.query_list_index_providers", new_callable=AsyncMock)
+async def test_list_indexes_with_repo_filter(mock_list):
+ app, pool = _make_app()
+ mock_list.return_value = []
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.get(
+ "/xrpc/science.alt.dataset.listIndexes",
+ params={"repo": "did:plc:provider1"},
+ )
+ assert resp.status_code == 200
+ mock_list.assert_called_once_with(pool, "did:plc:provider1", 50, None, None, None)
+
+
+# ---------------------------------------------------------------------------
+# publishIndex
+# ---------------------------------------------------------------------------
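+# publishIndex (as tested here) validates the record (correct $type, required
+# endpointUrl, HTTPS-only endpoint) and then proxies a createRecord call to the
+# caller's PDS; _resolve_pds and _proxy_create_record are mocked, so only
+# validation and response shaping are covered.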
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock)
+@patch("atdata_app.xrpc.procedures._resolve_pds", new_callable=AsyncMock)
+@patch("atdata_app.xrpc.procedures._proxy_create_record", new_callable=AsyncMock)
+async def test_publish_index(mock_proxy, mock_pds, mock_auth):
+ app, pool = _make_app()
+
+ mock_auth.return_value = MagicMock(iss="did:plc:publisher1")
+ mock_pds.return_value = "https://pds.example.com"
+ mock_proxy.return_value = {
+ "uri": "at://did:plc:publisher1/science.alt.dataset.index/3abc",
+ "cid": "bafynew",
+ }
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.publishIndex",
+ json={
+ "record": {
+ "name": "Genomics Index",
+ "endpointUrl": "https://example.com/skeleton",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ },
+ headers={
+ "Authorization": "Bearer test-token",
+ "X-PDS-Auth": "pds-token",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["uri"] == "at://did:plc:publisher1/science.alt.dataset.index/3abc"
+ assert data["cid"] == "bafynew"
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock)
+async def test_publish_index_missing_field(mock_auth):
+ app, pool = _make_app()
+ mock_auth.return_value = MagicMock(iss="did:plc:publisher1")
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.publishIndex",
+ json={
+ "record": {
+ "name": "Test",
+ "createdAt": "2025-01-01T00:00:00Z",
+ # missing endpointUrl
+ },
+ },
+ headers={
+ "Authorization": "Bearer test-token",
+ "X-PDS-Auth": "pds-token",
+ },
+ )
+ assert resp.status_code == 400
+ assert "endpointUrl" in resp.json()["detail"]
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock)
+async def test_publish_index_http_url_rejected(mock_auth):
+ app, pool = _make_app()
+ mock_auth.return_value = MagicMock(iss="did:plc:publisher1")
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.publishIndex",
+ json={
+ "record": {
+ "name": "Bad Index",
+ "endpointUrl": "http://insecure.example.com/skeleton",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ },
+ headers={
+ "Authorization": "Bearer test-token",
+ "X-PDS-Auth": "pds-token",
+ },
+ )
+ assert resp.status_code == 400
+ assert "HTTPS" in resp.json()["detail"]
+
+
+@pytest.mark.asyncio
+@patch("atdata_app.xrpc.procedures.verify_service_auth", new_callable=AsyncMock)
+async def test_publish_index_invalid_type(mock_auth):
+ app, pool = _make_app()
+ mock_auth.return_value = MagicMock(iss="did:plc:publisher1")
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ resp = await client.post(
+ "/xrpc/science.alt.dataset.publishIndex",
+ json={
+ "record": {
+ "$type": "science.alt.dataset.entry",
+ "name": "Wrong Type",
+ "endpointUrl": "https://example.com/skeleton",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ },
+ headers={
+ "Authorization": "Bearer test-token",
+ "X-PDS-Auth": "pds-token",
+ },
+ )
+ assert resp.status_code == 400
+ assert "Invalid $type" in resp.json()["detail"]
+
+
+# ---------------------------------------------------------------------------
+# Ingestion: index provider records
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_process_commit_index_provider():
+ mock_upsert = AsyncMock()
+ pool = AsyncMock()
+ event = {
+ "did": "did:plc:provider1",
+ "time_us": 1725911162329308,
+ "kind": "commit",
+ "commit": {
+ "rev": "rev1",
+ "operation": "create",
+ "collection": "science.alt.dataset.index",
+ "rkey": "3abc",
+ "record": {
+ "$type": "science.alt.dataset.index",
+ "name": "Genomics Index",
+ "endpointUrl": "https://example.com/skeleton",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ "cid": "bafyindex",
+ },
+ }
+ with patch.dict(f"{_DB}.UPSERT_FNS", {"index_providers": mock_upsert}):
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:provider1", "3abc", "bafyindex", event["commit"]["record"]
+ )
+
+
+@pytest.mark.asyncio
+@patch(f"{_DB}.delete_record", new_callable=AsyncMock)
+async def test_process_commit_delete_index_provider(mock_delete):
+ pool = AsyncMock()
+ event = {
+ "did": "did:plc:provider1",
+ "time_us": 1725911162329308,
+ "kind": "commit",
+ "commit": {
+ "rev": "rev1",
+ "operation": "delete",
+ "collection": "science.alt.dataset.index",
+ "rkey": "3abc",
+ },
+ }
+ await process_commit(pool, event)
+ mock_delete.assert_called_once_with(pool, "index_providers", "did:plc:provider1", "3abc")
diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py
index 2de2b49..fd31238 100644
--- a/tests/test_ingestion.py
+++ b/tests/test_ingestion.py
@@ -8,7 +8,6 @@
from atdata_app.ingestion.processor import process_commit
-# All patches target the `db` module reference used inside processor.py
_DB = "atdata_app.database"
@@ -44,15 +43,22 @@ def _make_event(
}
+def _patch_upsert(table: str):
+ """Patch a single entry in UPSERT_FNS by table name."""
+ mock = AsyncMock()
+ return patch.dict(f"{_DB}.UPSERT_FNS", {table: mock}), mock
+
+
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock)
-async def test_process_commit_create(mock_upsert):
- pool = AsyncMock()
- event = _make_event(operation="create")
- await process_commit(pool, event)
- mock_upsert.assert_called_once_with(
- pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
- )
+async def test_process_commit_create():
+ patcher, mock_upsert = _patch_upsert("entries")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(operation="create")
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
@pytest.mark.asyncio
@@ -72,85 +78,109 @@ async def test_process_commit_delete(mock_delete):
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_schema", new_callable=AsyncMock)
-async def test_process_commit_schema(mock_upsert):
- pool = AsyncMock()
- event = _make_event(
- collection="science.alt.dataset.schema",
- record={
- "$type": "science.alt.dataset.schema",
- "name": "TestSchema",
- "version": "1.0.0",
- "schemaType": "jsonSchema",
- "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"},
- "createdAt": "2025-01-01T00:00:00Z",
- },
- )
- await process_commit(pool, event)
- mock_upsert.assert_called_once_with(
- pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
- )
+async def test_process_commit_schema():
+ patcher, mock_upsert = _patch_upsert("schemas")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(
+ collection="science.alt.dataset.schema",
+ record={
+ "$type": "science.alt.dataset.schema",
+ "name": "TestSchema",
+ "version": "1.0.0",
+ "schemaType": "jsonSchema",
+ "schema": {"$type": "science.alt.dataset.schema#jsonSchemaFormat"},
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ )
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_label", new_callable=AsyncMock)
-async def test_process_commit_label(mock_upsert):
- pool = AsyncMock()
- event = _make_event(
- collection="science.alt.dataset.label",
- record={
- "$type": "science.alt.dataset.label",
- "name": "mnist",
- "datasetUri": "at://did:plc:test/science.alt.dataset.entry/3xyz",
- "createdAt": "2025-01-01T00:00:00Z",
- },
- )
- await process_commit(pool, event)
- mock_upsert.assert_called_once_with(
- pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
- )
+async def test_process_commit_label():
+ patcher, mock_upsert = _patch_upsert("labels")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(
+ collection="science.alt.dataset.label",
+ record={
+ "$type": "science.alt.dataset.label",
+ "name": "mnist",
+ "datasetUri": "at://did:plc:test/science.alt.dataset.entry/3xyz",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ )
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_lens", new_callable=AsyncMock)
-async def test_process_commit_lens(mock_upsert):
- pool = AsyncMock()
- event = _make_event(
- collection="science.alt.dataset.lens",
- record={
- "$type": "science.alt.dataset.lens",
- "name": "test-lens",
- "sourceSchema": "at://did:plc:test/science.alt.dataset.schema/a@1.0.0",
- "targetSchema": "at://did:plc:test/science.alt.dataset.schema/b@1.0.0",
- "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"},
- "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"},
- "createdAt": "2025-01-01T00:00:00Z",
- },
- )
- await process_commit(pool, event)
- mock_upsert.assert_called_once_with(
- pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
- )
+async def test_process_commit_lens():
+ patcher, mock_upsert = _patch_upsert("lenses")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(
+ collection="science.alt.dataset.lens",
+ record={
+ "$type": "science.alt.dataset.lens",
+ "name": "test-lens",
+ "sourceSchema": "at://did:plc:test/science.alt.dataset.schema/a@1.0.0",
+ "targetSchema": "at://did:plc:test/science.alt.dataset.schema/b@1.0.0",
+ "getterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "get.py"},
+ "putterCode": {"repository": "https://github.com/test/repo", "commit": "abc", "path": "put.py"},
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ )
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_process_commit_index_provider():
+ patcher, mock_upsert = _patch_upsert("index_providers")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(
+ collection="science.alt.dataset.index",
+ record={
+ "$type": "science.alt.dataset.index",
+ "name": "Genomics Index",
+ "endpointUrl": "https://example.com/skeleton",
+ "createdAt": "2025-01-01T00:00:00Z",
+ },
+ )
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock)
-async def test_process_commit_update(mock_upsert):
+async def test_process_commit_update():
"""Update operations should route to the same upsert function as create."""
- pool = AsyncMock()
- event = _make_event(operation="update")
- await process_commit(pool, event)
- mock_upsert.assert_called_once_with(
- pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
- )
+ patcher, mock_upsert = _patch_upsert("entries")
+ with patcher:
+ pool = AsyncMock()
+ event = _make_event(operation="update")
+ await process_commit(pool, event)
+ mock_upsert.assert_called_once_with(
+ pool, "did:plc:test123", "3xyz", "bafytest", event["commit"]["record"]
+ )
@pytest.mark.asyncio
-@patch(f"{_DB}.upsert_entry", new_callable=AsyncMock)
-async def test_process_commit_upsert_error_is_caught(mock_upsert):
+async def test_process_commit_upsert_error_is_caught():
"""Upsert failures should be logged, not raised."""
- mock_upsert.side_effect = Exception("db error")
- pool = AsyncMock()
- event = _make_event(operation="create")
- # Should not raise
- await process_commit(pool, event)
+ mock = AsyncMock(side_effect=Exception("db error"))
+ with patch.dict(f"{_DB}.UPSERT_FNS", {"entries": mock}):
+ pool = AsyncMock()
+ event = _make_event(operation="create")
+ # Should not raise
+ await process_commit(pool, event)
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 485ee27..2cccd1f 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -162,6 +162,7 @@ async def test_all_expected_tables_exist(self, db_pool):
"entries",
"labels",
"lenses",
+ "index_providers",
"cursor_state",
"analytics_events",
"analytics_counters",
@@ -196,6 +197,8 @@ async def test_expected_indexes_exist(self, db_pool):
"idx_lenses_did",
"idx_analytics_events_type_created",
"idx_analytics_events_target",
+ "idx_index_providers_did",
+ "idx_index_providers_indexed_at",
}
async with db_pool.acquire() as conn:
rows = await conn.fetch(
diff --git a/tests/test_models.py b/tests/test_models.py
index 23f6046..a09fc1e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -5,11 +5,14 @@
import pytest
from atdata_app.models import (
+ ARRAY_FORMAT_LABELS,
+ KNOWN_ARRAY_FORMATS,
decode_cursor,
encode_cursor,
make_at_uri,
parse_at_uri,
row_to_entry,
+ row_to_index_provider,
row_to_label,
row_to_lens,
row_to_schema,
@@ -174,6 +177,95 @@ def test_row_to_schema_json_string_body():
assert d["schema"] == {"type": "object"}
+def test_row_to_schema_no_array_format_fields_when_absent():
+ """Plain schemas should not gain arrayFormat/ndarray annotation keys."""
+ d = row_to_schema(_SCHEMA_ROW)
+ assert "arrayFormat" not in d
+ assert "arrayFormatLabel" not in d
+ assert "dtype" not in d
+ assert "shape" not in d
+ assert "dimensionNames" not in d
+
+
+# ---------------------------------------------------------------------------
+# row_to_schema — array format types
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("fmt", sorted(KNOWN_ARRAY_FORMATS))
+def test_row_to_schema_known_array_format(fmt):
+ """Each known format token should surface arrayFormat and a human label."""
+ row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": fmt}}
+ d = row_to_schema(row)
+ assert d["arrayFormat"] == fmt
+ assert d["arrayFormatLabel"] == ARRAY_FORMAT_LABELS[fmt]
+
+
+def test_row_to_schema_unknown_array_format_passes_through():
+ """Unknown format tokens are stored and surfaced as-is."""
+ row = {**_SCHEMA_ROW, "schema_body": {"arrayFormat": "futureFormat"}}
+ d = row_to_schema(row)
+ assert d["arrayFormat"] == "futureFormat"
+ assert d["arrayFormatLabel"] == "futureFormat"
+
+
+# ---------------------------------------------------------------------------
+# row_to_schema — ndarray v1.1.0 annotations
+# ---------------------------------------------------------------------------
+
+
+def test_row_to_schema_ndarray_annotations():
+ """ndarray v1.1.0 annotation fields are surfaced at top level."""
+ row = {
+ **_SCHEMA_ROW,
+ "schema_body": {
+ "arrayFormat": "numpyBytes",
+ "dtype": "float32",
+ "shape": [100, 200],
+ "dimensionNames": ["samples", "features"],
+ },
+ }
+ d = row_to_schema(row)
+ assert d["arrayFormat"] == "numpyBytes"
+ assert d["dtype"] == "float32"
+ assert d["shape"] == [100, 200]
+ assert d["dimensionNames"] == ["samples", "features"]
+
+
+def test_row_to_schema_ndarray_partial_annotations():
+ """Only present annotation fields should appear in output."""
+ row = {
+ **_SCHEMA_ROW,
+ "schema_body": {"arrayFormat": "sparseBytes", "dtype": "int64"},
+ }
+ d = row_to_schema(row)
+ assert d["dtype"] == "int64"
+ assert "shape" not in d
+ assert "dimensionNames" not in d
+
+
+# ---------------------------------------------------------------------------
+# KNOWN_ARRAY_FORMATS constant
+# ---------------------------------------------------------------------------
+
+
+def test_known_array_formats_contains_all_expected():
+ expected = {
+ "numpyBytes",
+ "parquetBytes",
+ "sparseBytes",
+ "structuredBytes",
+ "arrowTensor",
+ "safetensors",
+ }
+ assert KNOWN_ARRAY_FORMATS == expected
+
+
+def test_array_format_labels_covers_all_known():
+ """Every known format should have a human-readable label."""
+ assert set(ARRAY_FORMAT_LABELS.keys()) == KNOWN_ARRAY_FORMATS
+
+
# ---------------------------------------------------------------------------
# row_to_label
# ---------------------------------------------------------------------------
@@ -193,6 +285,7 @@ def test_row_to_schema_json_string_body():
def test_row_to_label():
d = row_to_label(_LABEL_ROW)
assert d["uri"] == "at://did:plc:abc/science.alt.dataset.label/3lbl"
+ assert d["did"] == "did:plc:abc"
assert d["datasetUri"] == _LABEL_ROW["dataset_uri"]
assert d["version"] == "1.0.0"
assert d["description"] == "First version"
@@ -227,6 +320,7 @@ def test_row_to_label_omits_optional_fields():
def test_row_to_lens():
d = row_to_lens(_LENS_ROW)
assert d["uri"] == "at://did:plc:abc/science.alt.dataset.lens/3lens"
+ assert d["did"] == "did:plc:abc"
assert d["sourceSchema"] == _LENS_ROW["source_schema"]
assert d["getterCode"] == _LENS_ROW["getter_code"]
assert d["description"] == "Transforms A to B"
@@ -249,3 +343,33 @@ def test_row_to_lens_json_string_code():
d = row_to_lens(row)
assert d["getterCode"] == {"repo": "x"}
assert d["putterCode"] == {"repo": "y"}
+
+
+# ---------------------------------------------------------------------------
+# row_to_index_provider
+# ---------------------------------------------------------------------------
+
+_INDEX_PROVIDER_ROW = {
+ "did": "did:plc:provider1",
+ "rkey": "3idx",
+ "cid": "bafyindex",
+ "name": "Genomics Index",
+ "description": "Curated genomics datasets",
+ "endpoint_url": "https://example.com/skeleton",
+ "created_at": "2025-01-01T00:00:00Z",
+}
+
+
+def test_row_to_index_provider():
+ d = row_to_index_provider(_INDEX_PROVIDER_ROW)
+ assert d["uri"] == "at://did:plc:provider1/science.alt.dataset.index/3idx"
+ assert d["did"] == "did:plc:provider1"
+ assert d["name"] == "Genomics Index"
+ assert d["endpointUrl"] == "https://example.com/skeleton"
+ assert d["description"] == "Curated genomics datasets"
+
+
+def test_row_to_index_provider_omits_null_description():
+ row = {**_INDEX_PROVIDER_ROW, "description": None}
+ d = row_to_index_provider(row)
+ assert "description" not in d
diff --git a/uv.lock b/uv.lock
index 72b071a..2dfa852 100644
--- a/uv.lock
+++ b/uv.lock
@@ -75,7 +75,7 @@ wheels = [
[[package]]
name = "atdata-app"
-version = "0.3.0b1"
+version = "0.4.0b1"
source = { editable = "." }
dependencies = [
{ name = "asyncpg" },