diff --git a/backend/SERVICE_DEPENDENCIES.md b/backend/SERVICE_DEPENDENCIES.md new file mode 100644 index 0000000..0b730b2 --- /dev/null +++ b/backend/SERVICE_DEPENDENCIES.md @@ -0,0 +1,136 @@ +# Backend service dependency map + +**Audience:** contributors changing the FastAPI backend; operators +investigating an incident; auditors tracing data flow. + +**Last updated:** 2026-05-15 + +This doc inventories every service the FastAPI backend depends on, in +the direction of the dependency (who-calls-whom). For each: what it's +used for, when failure is acceptable, and the failure-mode hooks. + +The complementary docs (in the sibling `ndi-cloud-app` repo): +- `apps/web/docs/operations/vendor-dependencies.md` — vendor + BAA + inventory at the higher level +- `apps/web/docs/operations/disaster-recovery.md` — runbooks per + failure mode + +--- + +## Topology + +``` + ┌──────────────────────────────┐ + │ FastAPI backend │ + │ (this repo, on Railway) │ + └─────┬───────────┬────────────┘ + │ │ + │ │ + ▼ ▼ + ┌─────────────┐ ┌──────────────┐ + │ Redis │ │ Postgres │ + │ (Railway) │ │ (Railway) │ + └─────────────┘ └──────────────┘ + │ + │ (rate limits, sessions, table cache) + │ + ▼ + ┌──────────────────────────────────┐ + │ ndi-cloud-node │ + │ (AWS Lambda + API Gateway) │ + └──────┬───────────────────────────┘ + │ + ├── AWS Cognito User Pool (identity) + ├── AWS DocumentDB (metadata) + └── AWS S3 (binary recordings) +``` + +--- + +## Outbound dependencies (what FastAPI calls) + +### Redis (Railway-hosted) + +| Field | Value | +|---|---| +| **Used for** | Session store (Fernet-encrypted access tokens), rate-limit counters, summary-table response cache, CSRF-failure budget | +| **Failure mode** | Sessions: every request returns 401 (forces re-login). Rate limit: middleware fails-open (allows requests) per the swallow-error-and-pass pattern in `csrf.py:_maybe_promote_to_rate_limit`. Cache: every read becomes a miss (slower but correct). | +| **Acceptable downtime?** | Sessions: no — platform unusable. Rate limit + cache: yes, with degraded UX. | +| **Code surface** | `backend/auth/session.py` (sessions), `backend/middleware/rate_limit.py`, `backend/cache/redis_table.py`. | + +### Postgres (Railway-hosted) + +| Field | Value | +|---|---| +| **Used for** | pgvector RAG store for `/ask` semantic search; future `chat_usage_events` table (Stream 3) for per-user cost tracking. | +| **Failure mode** | Semantic search returns soft error; chat falls back to structured catalog tools. | +| **Acceptable downtime?** | Yes — chat works without semantic search via fallback. | +| **Code surface** | The RAG-store schema lives in the sibling `ndi-cloud-app` repo at `apps/web/lib/ai/db/`. The cloud-app side reads pgvector directly via `@vercel/postgres`. FastAPI doesn't currently touch the RAG store; it WILL when Stream 3.2 (`chat_usage_events`) lands. | + +### ndi-cloud-node (AWS Lambda) + +| Field | Value | +|---|---| +| **Used for** | All catalog reads, all auth (Cognito-backed login), all dataset metadata, all NDI Query DSL evaluation, all binary-document downloads (proxied via signed S3 URLs). | +| **Failure mode** | Circuit breaker opens after 5 consecutive failures (default `CLOUD_CIRCUIT_BREAKER_THRESHOLD`); cooldown 30s. While the breaker is open, every FastAPI request that needs the cloud returns `CloudUnreachable` typed error → 503 `cloud_unreachable`. | +| **Acceptable downtime?** | No — platform unusable. AWS SLO is the binding constraint. | +| **Code surface** | `backend/clients/ndi_cloud.py` (HTTP client + circuit breaker), `backend/clients/circuit_breaker.py`. | +| **Auth** | Bearer access-token (Cognito JWT) per-request, no service account; the user's session-stored token is decrypted and forwarded on the request. | + +### AWS S3 (via signed URLs) + +| Field | Value | +|---|---| +| **Used for** | Binary recording downloads. ndi-cloud-node returns a signed S3 URL; FastAPI forwards the URL to the client OR streams the bytes through (depending on size). | +| **Failure mode** | Binary downloads return 502. Catalog reads + metadata are unaffected. | +| **Code surface** | `backend/clients/_url_allowlist.py` enforces an allowlist of S3 hostnames before any FastAPI-side download proxy. The May 2026 audit (`test_download_from_off_allowlist_host_hard_rejects`) verifies the allowlist rejects non-S3 hosts even when ndi-cloud-node returns a redirect to one. | + +### OpenTelemetry collector (optional) + +| Field | Value | +|---|---| +| **Used for** | Trace export when `OTEL_EXPORTER_OTLP_ENDPOINT` is non-empty. Default: empty (tracing disabled). | +| **Failure mode** | Tracing dropped silently. No impact on application requests. | +| **Code surface** | `backend/observability/` (sender), `backend/middleware/request_id.py` (per-request id propagation). | + +--- + +## Inbound dependencies (who calls FastAPI) + +### Vercel-hosted ndi-cloud-app frontend (production + preview) + +| Field | Value | +|---|---| +| **Used for** | Every `/api/*` request from the browser is proxied to FastAPI via Vercel `rewrites()`. Same for RSC-server-side fetches (`INTERNAL_API_URL`). | +| **Auth posture** | Cookie + CSRF — matches the FastAPI middleware contract. | +| **Branch awareness** | The cloud-app's `feat/experimental-ask-chat` branch routes `/api/*` to **this** experimental FastAPI env (`ndb-v2-experimental`) via the branch-aware rewrite. Main branch routes to production FastAPI. See ADR-005 in the cloud-app repo. | + +### vh-lab-chatbot + shrek-lab-chatbot + +| Field | Value | +|---|---| +| **Used for** | These two sibling chatbots historically read the same Postgres RAG index. Today they don't call FastAPI directly — they query their own embedding indices. Listed here for completeness because they share the Voyage API key (incident-prone: see the May 2026 leaked-credentials postmortem in the cloud-app repo). | + +--- + +## Service-startup order + +The FastAPI app's lifespan handler (`backend/app.py:lifespan`) starts services in this order: + +1. **NdiCloudClient.start()** — opens the httpx pool. Lazy DNS, no + eager call to the cloud. +2. **SessionStore** — instantiates with the Fernet key from settings. +3. **RateLimiter** — Redis-backed; lazy on first use. +4. **Ontology cache** — SQLite at `ONTOLOGY_CACHE_DB_PATH`, created if + absent. + +Shutdown is reverse order. If startup fails at any step, the container +crashes before serving the first request — by design (fail-loud). + +--- + +## Update history + +| Date | Change | +|---|---| +| 2026-05-15 | Initial draft (Stream 4.8 deliverable). | diff --git a/backend/app.py b/backend/app.py index 9d4e1bb..f80369b 100644 --- a/backend/app.py +++ b/backend/app.py @@ -36,17 +36,25 @@ from .observability.logging import configure_logging, get_logger, request_id_ctx from .observability.tracing import init_tracing from .routers import ( + aggregate_documents, auth, binary, datasets, documents, health, + image, + ndi_dataset, ontology, + psth, query, signal, + spike_summary, tables, + tabular_query, + treatment_timeline, visualize, ) +from .services.dataset_binding_service import DatasetBindingService from .services.ontology_cache import OntologyCache from .services.ontology_service import OntologyService from .static_files import safe_static_path @@ -255,6 +263,98 @@ async def _facets_warm() -> None: log.info("keepwarm.started", interval_seconds=240) log.info("facets_warm.started", interval_seconds=240) + # NDI-python strict-boot check. + # + # The Phase A integration adds vlt (VHSB), ndicompress, and + # ndi.ontology. When `NDI_PYTHON_REQUIRED=1` (set by the Railway + # Dockerfile), the stack MUST be importable or we hard-fail. + # Unset (dev/test/CI), we log a warning if NDI is missing but + # keep going — every NDI-python call gracefully returns None and + # callers fall through to their legacy paths. + # + # Why an explicit env var rather than guessing from + # `settings.ENVIRONMENT`: the test/CI/local matrix is fuzzy, and + # the only thing that actually matters here is "is this image + # supposed to have NDI-python installed?" The Dockerfile knows; + # nothing else needs to. + import os as _os + if _os.environ.get("NDI_PYTHON_REQUIRED", "").strip() in ("1", "true", "yes"): + from .services import ndi_python_service as _ndi + if not _ndi.is_ndi_available(): + raise RuntimeError( + "ndi_python_service.is_ndi_available() returned False at " + "startup but NDI_PYTHON_REQUIRED=1. The NDI-python stack " + "(vlt, ndicompress, ndi.ontology) failed to import. Check " + "the Dockerfile's pinned git SHAs and the install layer logs." + ) + log.info("ndi_python.boot_ok") + + # Sprint 1.5 dataset-binding service — singleton, lives on app.state. + # Always instantiated (cheap object — empty LRU). The router behind + # ``/api/datasets/{id}/ndi_overview`` calls into it; on internal + # failure (NDI-python missing, cloud unreachable, etc.) the service + # returns None and the router maps that to a 503. Frontend tool + # falls back to ndi_query gracefully. + app.state.dataset_binding_service = DatasetBindingService() + + # Optional pre-warm of the 3 demo datasets. We fire-and-forget per + # dataset so a single failure doesn't block the others. Each task + # is parked on app.state so asyncio doesn't GC the reference + # mid-flight (RUF006). We DO NOT await them — they run in the + # background while the app starts serving requests immediately. + # + # If NDI-python isn't available, the service returns None on the + # first call and we skip the rest — costs essentially nothing. + async def _prewarm_dataset(dataset_id: str) -> None: + try: + log.info("dataset_binding.prewarm_start", dataset_id=dataset_id) + result = await app.state.dataset_binding_service.get_dataset( + dataset_id + ) + if result is not None: + log.info( + "dataset_binding.prewarm_done", + dataset_id=dataset_id, + ) + else: + # Service already logged the reason at WARN — keep this + # at INFO so the boot timeline is one-line-per-dataset. + log.info( + "dataset_binding.prewarm_skipped", + dataset_id=dataset_id, + ) + except _asyncio.CancelledError: + raise + except Exception as exc: + # Truly defensive: get_dataset() is documented to never + # raise, but log loudly if that contract breaks so we know. + log.warning( + "dataset_binding.prewarm_unexpected_raise", + dataset_id=dataset_id, + error=str(exc), + error_type=type(exc).__name__, + ) + + # Three demo datasets surfaced by the experimental /ask chat: + # Dabrowska BNST (EPM behavior), Bhar (chemotaxis), Haley + # (patch-encounter). Order does not matter; tasks run concurrently. + # Pre-warm is gated to production-like environments so dev/test + # boots stay fast. + if settings.ENVIRONMENT in ("production", "preview"): + prewarm_ids = ( + "67f723d574f5f79c6062389d", # Dabrowska BNST + "69bc5ca11d547b1f6d083761", # Bhar + "682e7772cdf3f24938176fac", # Haley + ) + app.state.dataset_binding_prewarm_tasks = [ + _asyncio.create_task(_prewarm_dataset(did)) + for did in prewarm_ids + ] + log.info( + "dataset_binding.prewarm_started", + count=len(prewarm_ids), + ) + log.info("app.startup", environment=settings.ENVIRONMENT) try: yield @@ -272,6 +372,16 @@ async def _facets_warm() -> None: # so it surfaces in logs instead of disappearing. with _contextlib.suppress(_asyncio.CancelledError): await task + # Cancel any in-flight dataset-binding pre-warm tasks. + # downloadDataset is blocking I/O inside asyncio.to_thread — we + # can't actually interrupt it mid-thread, but cancellation + # prevents the post-download cache-write from running after + # teardown. + prewarm_tasks = getattr(app.state, "dataset_binding_prewarm_tasks", None) or [] + for task in prewarm_tasks: + task.cancel() + with _contextlib.suppress(_asyncio.CancelledError, Exception): + await task await cloud_client.close() await ontology_service.close() # `redis.asyncio.Redis.aclose()` is the correct async-context @@ -422,8 +532,16 @@ async def handle_unhandled(request: Request, exc: Exception) -> JSONResponse: app.include_router(tables.router) app.include_router(query.router) app.include_router(query.facets_router) + # Stream 4.9 (2026-05-16) — heavy aggregate runs on Railway, not Vercel. + app.include_router(aggregate_documents.router) app.include_router(binary.router) app.include_router(signal.router) + app.include_router(image.router) + app.include_router(tabular_query.router) + app.include_router(treatment_timeline.router) + app.include_router(psth.router) + app.include_router(spike_summary.router) + app.include_router(ndi_dataset.router) app.include_router(ontology.router) app.include_router(visualize.router) diff --git a/backend/auth/cookie_attrs.py b/backend/auth/cookie_attrs.py index ac7eb21..5599b84 100644 --- a/backend/auth/cookie_attrs.py +++ b/backend/auth/cookie_attrs.py @@ -1,18 +1,81 @@ """Per-environment cookie attribute helper. Centralizes the ``Set-Cookie`` / ``Delete-Cookie`` attribute set used by -the session and CSRF cookies. Production carries -``Domain=.ndi-cloud.com`` so the apex Vercel deployment can read cookies -issued by the Railway backend after the cross-repo unification (Phase -4); dev keeps host-only + insecure for plain-HTTP localhost; everything -else (e.g. staging) is host-only + secure. +the session and CSRF cookies. + +Domain attribute +---------------- + +Production carries ``Domain=.ndi-cloud.com`` ONLY when the request +originates from ``*.ndi-cloud.com`` so the apex Vercel deployment can +read cookies issued by the Railway backend (cross-repo unification, +Phase 4). + +Vercel **preview** deployments at ``*.vercel.app`` get host-only +cookies. A Set-Cookie that carries ``Domain=.ndi-cloud.com`` on a +response served back to a non-``ndi-cloud.com`` host is silently +rejected by the browser — the cookie spec forbids servers from +setting cookies for domains they don't control. That's why +preview-time login was breaking with ``CSRF_INVALID`` errors before +this fix (2026-05-14 tutorial-parity smoke). + +Other attributes +---------------- + +Dev keeps host-only + insecure for plain-HTTP localhost. Staging (and +any other ENVIRONMENT value) is host-only + secure. """ from typing import Any +from urllib.parse import urlparse + +from fastapi import Request from ..config import Settings -def cookie_attrs(settings: Settings) -> dict[str, Any]: +def cookie_attrs(settings: Settings, *, request: Request) -> dict[str, Any]: + """Return the Set-Cookie attribute dict for the current env + request. + + The ``request`` parameter is required: the per-request Origin (or + Referer) is what decides whether the Domain attribute is safe to + attach. Old callers that passed only ``settings`` must be updated — + silently guessing wrong is what broke preview login. + """ if settings.ENVIRONMENT == "production": - return {"secure": True, "domain": ".ndi-cloud.com"} + if _request_from_ndi_cloud(request): + return {"secure": True, "domain": ".ndi-cloud.com"} + # Preview / vercel.app / anything else served by the production + # backend: secure but host-only. The browser will accept these + # because the cookie's implicit Domain matches the response + # origin (the preview hostname). + return {"secure": True} return {"secure": settings.ENVIRONMENT != "development"} + + +def _request_from_ndi_cloud(request: Request) -> bool: + """Was this request issued by a browser tab on ``*.ndi-cloud.com``? + + Reads the Origin header (browsers set this on every cross-site and + every same-origin POST since 2020), with a fallback to Referer for + older clients and the few same-origin GETs that omit Origin. + Returns True only if the URL's hostname is exactly + ``ndi-cloud.com`` or a subdomain of it. + + Returns False when: + - both Origin and Referer are missing or unparseable + - the host doesn't end with ``ndi-cloud.com`` (i.e. preview) + """ + for header_name in ("origin", "referer"): + raw = request.headers.get(header_name) + if not raw: + continue + try: + parts = urlparse(raw) + except ValueError: + continue + if not parts.netloc: + continue + host = parts.netloc.split(":", 1)[0].lower() + if host == "ndi-cloud.com" or host.endswith(".ndi-cloud.com"): + return True + return False diff --git a/backend/auth/dependencies.py b/backend/auth/dependencies.py index 0b24cd1..cf31430 100644 --- a/backend/auth/dependencies.py +++ b/backend/auth/dependencies.py @@ -44,9 +44,13 @@ async def get_current_session( # roam across networks, and hard-rejecting would shred UX. current_ip_hash, current_ua_hash = fingerprint(request) if current_ua_hash != session.user_agent_hash: + # Truncate session id to 8 chars in logs — the full id IS the + # session secret (whoever reads Railway logs could otherwise + # replay it by setting the `session` cookie). Matches the + # other callsites in this module + login.py. log.warning( "session.ua_changed", - session_id=session.session_id, + session_id=session.session_id[:8], stored_ua_hash=session.user_agent_hash, current_ua_hash=current_ua_hash, ) @@ -55,7 +59,7 @@ async def get_current_session( if current_ip_hash != session.ip_addr_hash: log.warning( "session.ip_changed", - session_id=session.session_id, + session_id=session.session_id[:8], stored_ip_hash=session.ip_addr_hash, current_ip_hash=current_ip_hash, ) diff --git a/backend/auth/login.py b/backend/auth/login.py index 6b23fc0..11c7eb1 100644 --- a/backend/auth/login.py +++ b/backend/auth/login.py @@ -104,8 +104,11 @@ async def do_login( # truncate; this success path was the holdout. log.info("auth.login.success", session_id=session.session_id[:8]) - # Session cookie — HttpOnly; Secure + Domain derived from environment. - attrs = cookie_attrs(settings) + # Session cookie — HttpOnly; Secure + Domain derived from + # environment AND the request's Origin (so previews at + # `*.vercel.app` get host-only cookies rather than a Domain the + # browser would reject). + attrs = cookie_attrs(settings, request=request) response.set_cookie( key=SESSION_COOKIE, value=session.session_id, @@ -134,6 +137,7 @@ async def do_login( async def do_logout( *, + request: Request, response: Response, session: SessionData | None, store: SessionStore, @@ -151,17 +155,22 @@ async def do_logout( swallowed so the local teardown completes. """ settings = get_settings() - attrs = cookie_attrs(settings) + # Mirror do_login: per-request Origin decides whether Domain is + # attached, so the delete-cookie attrs match what was set. + attrs = cookie_attrs(settings, request=request) try: if session is not None: try: await cloud.logout(session.access_token) except Exception as e: # Best-effort upstream logout — local teardown continues. + # Truncate session id to 8 chars: the full id IS the + # session secret (anyone with Railway log access could + # otherwise replay it by setting the `session` cookie). log.info( "auth.logout.cloud_failed", reason=type(e).__name__, - session_id=session.session_id, + session_id=session.session_id[:8], ) # Local session teardown must run even if cloud logout raised. await store.delete(session.session_id) diff --git a/backend/cache/redis_table.py b/backend/cache/redis_table.py index a6e2f57..6137180 100644 --- a/backend/cache/redis_table.py +++ b/backend/cache/redis_table.py @@ -62,7 +62,36 @@ def __init__(self, redis: Redis, *, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> N # identifier so two authenticated users can never share a cached entry. # Bumping the schema version ensures in-flight v3 entries are ignored # post-deploy (they TTL out naturally within one hour). - SCHEMA_VERSION = "v4" + # v5: F-1d + F-1e shape change (2026-05-19). + # - Projection dispatch now uses the REQUESTED class instead of + # resolved alias, so element_epoch returns EPOCH_COLUMNS even + # when the alias chain hits epochfiles_ingested / + # daqreader_mfdaq_epochdata_ingested. + # - treatment_drug + treatment_transfer route to TREATMENT_COLUMNS + # (was GENERIC_COLUMNS), with the new auto-detect branches in + # _row_treatment. Bumping the version forces stale GENERIC_COLUMNS + # blobs cached under the old code to be ignored. + # v6: F-1b — subject table now includes dynamic treatment columns + # from server-side broadcast (2026-05-19). The subject summary + # table previously emitted only SUBJECT_COLUMNS (15 cols); + # `_broadcast_treatments_onto_subjects` now appends one + # `Name` + `Ontology` pair per distinct + # treatmentName in the dataset, with per-subject cells + # populated from `stringValue` / `numericValue` / + # `treatmentOntology`. Replaces the cloud-app's frontend + # joinTreatmentsToSubjects pivot at table-shell.tsx. Bumping + # the version invalidates stale v5 subject blobs that lack + # the new columns. + # v7: F-1b follow-up (2026-05-19) — subject enrichment now + # fetches `treatment_drug` + `treatment_transfer` in addition + # to literal `treatment`. Datasets like Bhar (0 literal + + # 24466 drug + 1675 transfer) had cached empty-broadcast v6 + # subject blobs because the enrichment fetcher only ran the + # literal class query (the alias chain only fires for primary- + # class fetches). v7 keys force a re-build that pulls all + # three classes and merges them in `_project_for_class`'s + # subject branch. + SCHEMA_VERSION = "v7" @staticmethod def table_key(dataset_id: str, class_name: str, *, user_scope: str) -> str: diff --git a/backend/config.py b/backend/config.py index b72403f..18962fd 100644 --- a/backend/config.py +++ b/backend/config.py @@ -114,6 +114,22 @@ class Settings(BaseSettings): ), ) + # Stream 3.4 (2026-05-15): per-org access control for the `/ask` + # experimental chat. Comma-separated list of organization IDs + # that have `enable_ask` true. Empty (the default) means EVERY + # authenticated user can use chat — i.e. the experimental + # anonymous-public phase. Once Stream 3.1 moves /ask to + # /my/ask (auth-gated), populate this to gate per-org. Admin + # users (`is_admin=true`) bypass the gate. + ENABLE_ASK_ORG_IDS: str = Field( + default="", + description=( + "Comma-separated list of organization IDs whose users may " + "use the /ask chat. Empty = open to every authenticated " + "user. Admin users always allowed." + ), + ) + # --- Observability --- LOG_LEVEL: str = "INFO" LOG_FORMAT: Literal["json", "console"] = "json" @@ -146,6 +162,27 @@ class Settings(BaseSettings): def cors_origins_list(self) -> list[str]: return [o.strip() for o in self.CORS_ORIGINS.split(",") if o.strip()] + @property + def enable_ask_org_ids_list(self) -> list[str]: + """Parsed allowlist of org IDs that may use /ask. Empty list + means "every authenticated user" (the experimental default).""" + return [ + o.strip() for o in self.ENABLE_ASK_ORG_IDS.split(",") if o.strip() + ] + + def user_can_use_ask(self, *, organization_ids: list[str], is_admin: bool) -> bool: + """Stream 3.4 — verdict gate consumed by `/api/auth/me` and + the `/api/ask` route. Admin always wins; an empty allowlist + means "no gate"; otherwise the user's org set must intersect + the allowlist.""" + if is_admin: + return True + allowlist = self.enable_ask_org_ids_list + if not allowlist: + return True + allowed = set(allowlist) + return any(oid in allowed for oid in organization_ids) + @property def download_host_allowlist_list(self) -> list[str]: """Parsed list of allowlist host patterns (exact or `*.suffix`).""" diff --git a/backend/middleware/csrf.py b/backend/middleware/csrf.py index 7dd1026..11c5c7a 100644 --- a/backend/middleware/csrf.py +++ b/backend/middleware/csrf.py @@ -51,6 +51,15 @@ def verify(signed: str) -> bool: "/api/health", "/api/health/ready", "/metrics", + # Idempotent read-only lookups that happen to be POST-shaped because + # the request body is a small array of CURIEs (request body keeps the + # URL clean and avoids ?term=A&term=B&… repetition for batches up to + # 200 terms). Anonymous visitors hit these every time a /datasets/* + # page renders, before they've had a chance to GET /api/auth/csrf — + # the resulting 403 was falling back to "label-only" display in the + # ontology popovers and surfacing a "1 warning" banner on every + # SummaryTableView (visual-UX audit a395 P0 #3, 2026-05-14). + "/api/ontology/batch-lookup", } # Previously `/api/auth/login` was also exempted on the premise that a # pre-session user couldn't have a CSRF token. That's wrong: the frontend's diff --git a/backend/observability/logging.py b/backend/observability/logging.py index 489d738..489d41d 100644 --- a/backend/observability/logging.py +++ b/backend/observability/logging.py @@ -66,12 +66,25 @@ def configure_logging() -> None: # failure). Leave it out for console mode. renderer = structlog.dev.ConsoleRenderer(colors=True) + # cache_logger_on_first_use=False (was True before Stream 6.6): + # caching binds each ``get_logger(__name__)`` lazy-proxy to its + # first-seen processor chain. In production that's a tiny win, + # but it breaks pytest's structlog.testing.capture_logs() inside + # unit tests that run AFTER an integration test has called + # configure_logging() — the cached proxies stay pinned to the + # integration chain even when the unit test re-configures. The + # symptom: capture_logs returns an empty list while pytest's + # "captured log call" panel shows the WARNING was emitted (caught + # 2026-05-15 against three test_cloud_client + test_dependencies + # flakes). Disabling the cache costs ~1-2 µs per log call in prod + # (negligible vs the network round-trip these logs accompany) and + # makes the test harness behave deterministically. structlog.configure( processors=[*shared, renderer], wrapper_class=structlog.make_filtering_bound_logger(level), context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), - cache_logger_on_first_use=True, + cache_logger_on_first_use=False, ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a4bc8b6..308e3ec 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -19,10 +19,19 @@ dependencies = [ "Pillow>=10.4.0", "numpy>=2.0.0", "scipy>=1.14.0", - # pandas was declared but never imported anywhere in backend/. Removed - # in audit 2026-04-23 (#58) — it was a ~30 MB wheel for no runtime - # use. If a future feature needs it, re-add with a comment pointing - # at the specific caller. + # pandas restored in Phase A of the NDI-python integration + # (2026-05-13) — required because vlt.file and ndi/__init__ both + # eagerly pull pandas via their submodule chains. Without it, the + # NDI stack imports fail at backend startup. + "pandas>=2.0.0", + # Direct runtime deps of the NDI-python stack (installed via + # --no-deps in infra/Dockerfile to skip matplotlib + opencv). + "networkx>=2.6", + "jsonschema>=4.0.0", + "requests>=2.28.0", + "openMINDS>=0.2.0", + "portalocker>=2.0.0", + "h5py>=3.0.0", ] [project.optional-dependencies] @@ -123,6 +132,17 @@ module = [ "scipy.*", "opentelemetry", "opentelemetry.*", + # NDI-python stack — installed via git+https in the Docker image. Locally + # (and in CI before pip install) the modules aren't on PYTHONPATH so mypy + # can't find their stubs. The wrappers in services/ndi_python_service.py + # lazy-import these and surface a typed None on import failure, so the + # "missing-imports" diagnostic is a false positive in our usage shape. + "ndi", + "ndi.*", + "ndicompress", + "ndicompress.*", + "vlt", + "vlt.*", ] ignore_missing_imports = true diff --git a/backend/requirements.txt b/backend/requirements.txt index f80a358..54c82d4 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,8 +8,26 @@ cachetools>=5.4.0 cryptography>=43.0.0 structlog>=24.4.0 prometheus-client>=0.20.0 -python-multipart>=0.0.9 +# python-multipart bumped to >=0.0.27 (was >=0.0.9) to close CVE-2026-42561. +# Surfaced 2026-05-15 via `pip-audit`. The 0.0.9 floor still resolved to a +# vulnerable version when Docker rebuilt because the resolver picks the +# latest compatible — but 0.0.9 itself is vulnerable. +python-multipart>=0.0.27 Pillow>=10.4.0 numpy>=2.0.0 scipy>=1.14.0 -# pandas removed in audit 2026-04-23 (#58) — never imported. +# pandas was previously removed in audit 2026-04-23 (#58). RESTORED in Phase A +# of the NDI-python integration: `vlt.file.__init__` and `ndi/__init__` both +# eagerly import pandas via their submodule chains, so the NDI stack won't +# import without it. Marginal image cost ~30 MB; runtime cost ~0 (lazy paths). +pandas>=2.0.0 +# Direct deps of NDI-python and its git kin, hand-listed so the +# `--no-deps` installs in the Dockerfile have everything they need at +# runtime. Skipping matplotlib (~50-70 MB) and opencv-python-headless +# (~80 MB) which are declared but never imported on our paths. +networkx>=2.6 +jsonschema>=4.0.0 +requests>=2.28.0 +openMINDS>=0.2.0 +portalocker>=2.0.0 +h5py>=3.0.0 diff --git a/backend/routers/_deps.py b/backend/routers/_deps.py index 7afc141..0c5baf3 100644 --- a/backend/routers/_deps.py +++ b/backend/routers/_deps.py @@ -9,7 +9,9 @@ from ..cache.redis_table import RedisTableCache from ..clients.ndi_cloud import NdiCloudClient from ..middleware.rate_limit import Limit, RateLimiter +from ..services.aggregate_documents_service import AggregateDocumentsService from ..services.binary_service import BinaryService +from ..services.dataset_binding_service import DatasetBindingService from ..services.dataset_provenance_service import DatasetProvenanceService from ..services.dataset_service import DatasetService from ..services.dataset_summary_service import ( @@ -19,10 +21,12 @@ from ..services.dependency_graph_service import DependencyGraphService from ..services.document_service import DocumentService from ..services.facet_service import FacetService +from ..services.image_service import ImageService from ..services.ontology_service import OntologyService from ..services.pivot_service import PivotService from ..services.query_service import QueryService from ..services.summary_table_service import SummaryTableService +from ..services.tabular_query_service import TabularQueryService from ..services.visualize_service import VisualizeService @@ -50,6 +54,15 @@ def query_service(request: Request) -> QueryService: return QueryService(cloud(request)) +def aggregate_documents_service(request: Request) -> AggregateDocumentsService: + """Stream 4.9 (2026-05-16) — POST /api/aggregate-documents handler. + + Stateless per-request, mirrors `query_service` shape. The cloud client + is the only collaborator; nothing held on app.state. + """ + return AggregateDocumentsService(cloud(request)) + + def table_cache(request: Request) -> RedisTableCache | None: return getattr(request.app.state, "table_cache", None) @@ -62,6 +75,10 @@ def summary_table_service(request: Request) -> SummaryTableService: return SummaryTableService(cloud(request), cache=table_cache(request)) +def tabular_query_service(request: Request) -> TabularQueryService: + return TabularQueryService(summary_table_service(request)) + + def dataset_summary_cache(request: Request) -> RedisTableCache | None: return getattr(request.app.state, "dataset_summary_cache", None) @@ -117,6 +134,10 @@ def binary_service(request: Request) -> BinaryService: return BinaryService(cloud(request)) +def image_service(request: Request) -> ImageService: + return ImageService(cloud(request)) + + def visualize_service(request: Request) -> VisualizeService: return VisualizeService(cloud(request)) @@ -125,6 +146,18 @@ def ontology_service(request: Request) -> OntologyService: return request.app.state.ontology_service # type: ignore[no-any-return] +def dataset_binding_service(request: Request) -> DatasetBindingService: + """Return the singleton DatasetBindingService held on app.state. + + The service owns an in-memory LRU of materialized ndi.dataset.Dataset + objects + per-id locks for download coalescing — both must persist + across requests, so this MUST resolve to the shared instance, not a + new one per call. Lifespan wires + ``app.state.dataset_binding_service`` at startup. + """ + return request.app.state.dataset_binding_service # type: ignore[no-any-return] + + # --- Rate-limit helpers --- async def _subject( diff --git a/backend/routers/aggregate_documents.py b/backend/routers/aggregate_documents.py new file mode 100644 index 0000000..8cdf27d --- /dev/null +++ b/backend/routers/aggregate_documents.py @@ -0,0 +1,44 @@ +"""Aggregate-documents endpoint — Stream 4.9 (2026-05-16). + +POST /api/aggregate-documents → run an ndi_query and aggregate a numeric +field across the matches. Auth-optional: anonymous requests get the +public-dataset slice; authenticated requests get the user's org reach +(propagated via the inbound session). Rate-limited under +``limit_queries`` (heavier than reads — the cloud may scan up to 50K +docs). + +Closes ADR-001 compliance debt: the old TS handler ran the whole loop +on Vercel; this router moves it to the right runtime. +""" +from __future__ import annotations + +from typing import Annotated, Any + +from fastapi import APIRouter, Depends + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..services.aggregate_documents_service import ( + AggregateDocumentsRequest, + AggregateDocumentsService, +) +from ._deps import aggregate_documents_service, limit_queries + +router = APIRouter( + prefix="/api/aggregate-documents", + tags=["query"], + dependencies=[Depends(limit_queries)], +) + + +@router.post("") +async def aggregate( + body: AggregateDocumentsRequest, + svc: Annotated[ + AggregateDocumentsService, Depends(aggregate_documents_service), + ], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> dict[str, Any]: + return await svc.aggregate( + body, access_token=session.access_token if session else None, + ) diff --git a/backend/routers/auth.py b/backend/routers/auth.py index e4c768e..ae06985 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -85,6 +85,13 @@ class MeResponse(BaseModel): # frontend can render an admin affordance when relevant. organizationIds: list[str] = [] isAdmin: bool = False + # Stream 3.4 (2026-05-15): true when this user is allowed to use + # the /ask chat, given `ENABLE_ASK_ORG_IDS` config + the user's + # org memberships. Admin users always get true. The frontend + # hides /ask nav / surfaces a "request access" affordance when + # this is false. The /api/ask route re-checks server-side so the + # gate isn't bypassable via DOM tampering. + canUseAsk: bool = True class CsrfResponse(BaseModel): @@ -92,7 +99,7 @@ class CsrfResponse(BaseModel): @router.get("/csrf", response_model=CsrfResponse) -async def csrf(response: Response) -> CsrfResponse: +async def csrf(request: Request, response: Response) -> CsrfResponse: raw = generate_token() token = sign(raw) response.set_cookie( @@ -102,7 +109,11 @@ async def csrf(response: Response) -> CsrfResponse: samesite="lax", path="/", max_age=86400, - **cookie_attrs(get_settings()), + # `request` lets cookie_attrs read the Origin header so the + # Domain attribute is only attached when the caller is on + # `*.ndi-cloud.com`. Preview hosts get host-only cookies that + # the browser will actually accept. + **cookie_attrs(get_settings(), request=request), ) return CsrfResponse(csrfToken=token) @@ -134,12 +145,18 @@ async def login( @router.post("/logout") async def logout( + request: Request, response: Response, session: Annotated[SessionData | None, Depends(get_current_session)], store: Annotated[SessionStore, Depends(session_store)], cl: Annotated[NdiCloudClient, Depends(cloud)], ) -> dict[str, bool]: - await do_logout(response=response, session=session, store=store, cloud=cl) + # `request` is threaded through to do_logout so the delete-cookie + # attributes match the set-cookie ones (Domain attribute must agree + # or the browser ignores the clear). + await do_logout( + request=request, response=response, session=session, store=store, cloud=cl, + ) return {"ok": True} @@ -147,6 +164,7 @@ async def logout( async def me( session: Annotated[SessionData, Depends(require_session)], ) -> MeResponse: + settings = get_settings() return MeResponse( userId=session.user_id, email_hash=session.user_email_hash[:16], @@ -155,6 +173,10 @@ async def me( expiresAt=session.access_token_expires_at, organizationIds=list(session.organization_ids), isAdmin=session.is_admin, + canUseAsk=settings.user_can_use_ask( + organization_ids=list(session.organization_ids), + is_admin=session.is_admin, + ), ) diff --git a/backend/routers/documents.py b/backend/routers/documents.py index e3ccc77..ff2352a 100644 --- a/backend/routers/documents.py +++ b/backend/routers/documents.py @@ -1,7 +1,7 @@ """Document list / detail / dependency graph.""" from __future__ import annotations -from typing import Annotated, Any +from typing import Annotated, Any, Literal from fastapi import APIRouter, Depends, Query @@ -55,15 +55,22 @@ async def dependencies( svc: Annotated[DependencyGraphService, Depends(dependency_graph_service)], session: Annotated[SessionData | None, Depends(get_current_session)], max_depth: int = Query(3, ge=1, le=MAX_DEPTH_HARD_CAP, alias="max_depth"), + direction: Literal["both", "upstream", "downstream"] = Query("both"), ) -> dict[str, Any]: - """Walk `depends_on` up to `max_depth` levels in both directions. - Returns `{target_id, target_ndi_id, nodes, edges, node_count, - edge_count, truncated, max_depth}`. See - `services/dependency_graph_service.py` for shape details. + """Walk ``depends_on`` up to ``max_depth`` levels. + + Returns ``{target_id, target_ndi_id, nodes, edges, node_count, + edge_count, truncated, max_depth}``. When ``?direction=upstream`` + or ``?direction=downstream`` is passed, the response also carries + ``direction_filter`` and the returned edges/nodes are restricted + to that walk direction. Default ``both`` preserves the + pre-F-3 behaviour. See ``services/dependency_graph_service.py`` + for full shape details. """ return await svc.get_graph( dataset_id, document_id, max_depth=max_depth, session=session, + direction=direction, ) diff --git a/backend/routers/image.py b/backend/routers/image.py new file mode 100644 index 0000000..10ac9ad --- /dev/null +++ b/backend/routers/image.py @@ -0,0 +1,93 @@ +"""Image endpoint for the experimental /ask chat's ``fetch_image`` tool. + +GET /api/datasets/{dataset_id}/documents/{document_id}/image + ?frame=N (multi-frame TIFF / animated GIF frame index; default 0) + +The route fetches the document's primary image file, decodes it via +Pillow (supports TIFF/PNG/JPEG/GIF auto-detect), converts to a 2D +grayscale float array, and returns the array plus min/max for Plotly's +heatmap colorscale. + +Targets the patch-encounter map / fluorescence image / cell-image use +cases for Haley accept-reject-foraging and Bhar memory datasets. PIs +asking "show me the encounter map" now get an inline heatmap instead of +"that's not currently supported". + +Soft errors (decode failure, missing file, unsupported format) surface +as ``{"error", "errorKind"}`` — the chat tool inspects the envelope and +the LLM tells the user plainly rather than emitting a chart fence. + +This is a NEW additive endpoint. Anonymous-readable. 60s timeout (large +TIFFs from the cloud can be slow to download). +""" +from __future__ import annotations + +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Query + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..services.document_service import DocumentService +from ..services.image_service import ImageService +from ._deps import document_service, image_service, limit_reads +from ._validators import DatasetId, DocumentId + +router = APIRouter( + prefix="/api/datasets/{dataset_id}/documents/{document_id}", + tags=["image"], + dependencies=[Depends(limit_reads)], +) + + +@router.get("/image") +async def get_image( + dataset_id: DatasetId, + document_id: DocumentId, + docs: Annotated[DocumentService, Depends(document_service)], + svc: Annotated[ImageService, Depends(image_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], + frame: Annotated[ + int, + Query( + ge=0, + le=10_000, + description=( + "Frame index for multi-frame containers (TIFF stack, " + "animated GIF). Defaults to 0 (first frame). Out-of-range " + "values clamp to the last frame and log a warning." + ), + ), + ] = 0, +) -> dict[str, Any]: + """Return a 2D image array with provenance. + + Response shape (success):: + + { + "width": int, + "height": int, + "data": [[float, ...], ...], + "min": float, + "max": float, + "format": "tiff" | "png" | "jpeg" | "...", + "downsampled": bool, + "source": { + "dataset_id": str, + "document_id": str, + "doc_class": str | None, + "doc_name": str | None, + "filename": str | None, + } + } + + Response shape (soft error):: + + {"error": "...", "errorKind": "notfound|decode|unsupported"} + """ + document = await docs.detail( + dataset_id, + document_id, + access_token=session.access_token if session else None, + ) + return await svc.fetch_image(document, frame=frame, session=session) diff --git a/backend/routers/ndi_dataset.py b/backend/routers/ndi_dataset.py new file mode 100644 index 0000000..d1929fd --- /dev/null +++ b/backend/routers/ndi_dataset.py @@ -0,0 +1,138 @@ +"""ndi_dataset router — Sprint 1.5 cloud-backed dataset binding endpoint. + +GET /api/datasets/{dataset_id}/ndi_overview + Returns a high-level summary (element / subject / epoch counts + + element listing) computed by traversing a LOCAL + ``ndi.dataset.Dataset`` materialized from the cloud's documents. + +Failure posture (deliberate): when the binding can't produce a value — +NDI-python missing, downloadDataset timed out, cloud unreachable, +anything — the endpoint returns **HTTP 503** with a JSON envelope so +the chat tool can gracefully fall back to its existing ``ndi_query`` +path. Callers should NOT treat 503 as a hard failure. + +Why a separate router rather than folding into ``datasets.py``: +1. The Sprint 1.5 binding is OPTIONAL infrastructure — keeping it in + its own module makes it trivial to disable (just unmount the + router) if the cloud auth / Mongo download path fails in + production. +2. The endpoint has a different latency posture (cold loads up to + 90s) — visible isolation helps with metrics + rate-limit reasoning. +""" +from __future__ import annotations + +import asyncio +from typing import Annotated, Any + +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse + +from ..observability.logging import get_logger +from ..services.dataset_binding_service import DatasetBindingService +from ._deps import dataset_binding_service, limit_reads +from ._validators import DatasetId + +log = get_logger(__name__) + + +# Per-call wall-clock cap. Cold loads can take 10-30s for the demo +# datasets; we allow up to 60s before surfacing a 503 so the chat +# doesn't hang. The service's own ``COLD_LOAD_TIMEOUT_SECONDS`` is 90s +# — that's the BACKGROUND limit (warm/pre-warm tasks). This per- +# request cap is stricter so a user-facing request never blocks the +# router for ~90s. +REQUEST_TIMEOUT_SECONDS = 60.0 + + +router = APIRouter( + prefix="/api/datasets/{dataset_id}", + tags=["ndi_dataset"], + dependencies=[Depends(limit_reads)], +) + + +@router.get("/ndi_overview") +async def ndi_overview( + dataset_id: DatasetId, + svc: Annotated[DatasetBindingService, Depends(dataset_binding_service)], +) -> Any: + """High-level dataset summary computed by NDI-python's SDK. + + Returns a dict shape on success: + + { + element_count: int, + subject_count: int, + epoch_count: int, + elements: [{name, type}], # capped at 50 + elements_truncated: bool, + reference: str, + cache_hit: bool, + cache_age_seconds: float, + } + + Returns 503 with ``{error, reason}`` on any failure (binding + unavailable, cold-load timeout, cloud unreachable). The chat tool + layer translates 503 → graceful fallback prompt. + """ + try: + result = await asyncio.wait_for( + svc.overview(dataset_id), + timeout=REQUEST_TIMEOUT_SECONDS, + ) + except TimeoutError: + log.warning( + "ndi_dataset.overview.request_timeout", + dataset_id=dataset_id, + timeout_seconds=REQUEST_TIMEOUT_SECONDS, + ) + return JSONResponse( + status_code=503, + content={ + "error": "dataset binding unavailable", + "reason": ( + f"overview computation exceeded {REQUEST_TIMEOUT_SECONDS:.0f}s " + "wall clock; try again in a moment" + ), + }, + ) + except Exception as exc: # blind — must not 500 a user request + log.warning( + "ndi_dataset.overview.unexpected_failure", + dataset_id=dataset_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return JSONResponse( + status_code=503, + content={ + "error": "dataset binding unavailable", + "reason": "binding raised an unexpected error", + }, + ) + + if result is None: + # Surface the specific failure code + message captured by the + # service's most-recent cold load. Chat tool falls back to + # ndi_query on any 503; richer diagnostics here help operators + # tell "Phase A missing" from "cloud auth failed" from "/tmp + # full" in the dashboard without tailing logs. + last = svc.last_failure() if hasattr(svc, "last_failure") else None + code = last[0] if last else "binding_unavailable" + reason = ( + last[1] + if last + else ( + "NDI-python dataset materialization failed or is not " + "configured on this server" + ) + ) + return JSONResponse( + status_code=503, + content={ + "error": "dataset binding unavailable", + "code": code, + "reason": reason, + }, + ) + return result diff --git a/backend/routers/psth.py b/backend/routers/psth.py new file mode 100644 index 0000000..15f4eef --- /dev/null +++ b/backend/routers/psth.py @@ -0,0 +1,116 @@ +"""PSTH endpoint for the experimental /ask chat's ``fetch_psth`` tool +and for the data-browser workspace. + +POST /api/datasets/{dataset_id}/psth + Body: PsthRequest (camelCase or snake_case fields accepted) + +Returns a peri-stimulus time histogram for one unit + one stimulus +document. Response shape:: + + { + bin_centers, counts, mean_rate_hz, + n_trials, n_spikes, + bin_size_ms, t0, t1, + unit_name, unit_doc_id, stimulus_doc_id, + per_trial_raster?, # only when include_raster=True + error?, error_kind?, # soft-error envelope + } + +This is a NEW additive endpoint — no schema changes, no existing-route +changes. Read-rate-limited; works for anonymous callers (public +datasets) and logged-in callers (private datasets) via +``get_current_session``. + +Soft errors mirror /signal and /spike-summary: when the unit doc fails +to decode, the stimulus doc carries no event timestamps, or every +window comes back empty, the response is a valid (but zero-filled or +empty) histogram with ``error`` + ``error_kind`` set so the chat tool +can branch on it. + +Cloud-tier hard failures (Railway can't reach ndi-cloud-node, etc.) +translate to a 503 envelope at the HTTP boundary — same pattern as +/spike-summary. +""" +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..errors import CloudInternalError, CloudTimeout, CloudUnreachable +from ..observability.logging import get_logger +from ..services.binary_service import BinaryService +from ..services.document_service import DocumentService +from ..services.psth_service import ( + PsthRequest, + PsthResponse, + compute_psth, +) +from ._deps import binary_service, document_service, limit_reads +from ._validators import DatasetId + +log = get_logger(__name__) + +router = APIRouter( + prefix="/api/datasets/{dataset_id}", + tags=["psth"], + dependencies=[Depends(limit_reads)], +) + + +@router.post("/psth") +async def post_psth( + dataset_id: DatasetId, + body: PsthRequest, + docs: Annotated[DocumentService, Depends(document_service)], + bs: Annotated[BinaryService, Depends(binary_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> PsthResponse: + """Build a peri-stimulus time histogram for one unit + one stimulus. + + The body's ``unit_doc_id`` and ``stimulus_doc_id`` must both be + 24-char Mongo ObjectIds (pydantic enforces min_length only; the + document service resolves ndiId form if the caller passes one + transparently). + """ + try: + return await compute_psth( + body, + document_service=docs, + binary_service=bs, + session=session, + dataset_id=dataset_id, + ) + except (CloudInternalError, CloudUnreachable, CloudTimeout) as exc: + # Translate cloud-layer failures to a typed 503 envelope — + # matches /spike-summary. Without this, the global handler + # returns an opaque 500 and the chat tool can't surface a + # useful error to the LLM. + from fastapi.responses import JSONResponse + log.warning( + "psth.cloud_error", + dataset_id=dataset_id, + error_type=type(exc).__name__, + error=str(exc), + ) + return JSONResponse( # type: ignore[return-value] + status_code=503, + content={ + "bin_centers": [], + "counts": [], + "mean_rate_hz": [], + "n_trials": 0, + "n_spikes": 0, + "bin_size_ms": body.bin_size_ms, + "t0": body.t0, + "t1": body.t1, + "unit_name": "", + "unit_doc_id": body.unit_doc_id, + "stimulus_doc_id": body.stimulus_doc_id, + "per_trial_raster": None, + "error": str(exc) or type(exc).__name__, + "error_kind": "cloud_unavailable", + }, + ) diff --git a/backend/routers/spike_summary.py b/backend/routers/spike_summary.py new file mode 100644 index 0000000..07b2988 --- /dev/null +++ b/backend/routers/spike_summary.py @@ -0,0 +1,100 @@ +"""Spike-summary endpoint for the experimental /ask chat's +``fetch_spike_summary`` tool and for the data-browser workspace. + +POST /api/datasets/{dataset_id}/spike-summary + Body: SpikeSummaryRequest (camelCase or snake_case fields accepted) + +Returns per-unit RAW spike-train data +(``{units: [{name, doc_id, spike_times, isi_intervals}], ...}``). +The TS handler reshapes this into chart_payloads on the chat side; the +workspace consumes raw data directly. + +This is a NEW additive endpoint — no schema changes, no existing-route +changes. Read-rate-limited; works for anonymous callers (public +datasets) and logged-in callers (private datasets) via +``get_current_session``. + +Soft errors mirror the /signal route: a unit whose spike-times array +fails to parse comes back as a unit entry with ``error`` + +``error_kind`` set, so the chat tool can branch on it without +crashing the whole request. +""" +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..clients.ndi_cloud import NdiCloudClient +from ..errors import CloudInternalError, CloudTimeout, CloudUnreachable +from ..observability.logging import get_logger +from ..services.document_service import DocumentService +from ..services.spike_summary_service import ( + SpikeSummaryRequest, + SpikeSummaryResponse, + compute_spike_summary, +) +from ._deps import cloud, document_service, limit_reads +from ._validators import DatasetId + +log = get_logger(__name__) + +router = APIRouter( + prefix="/api/datasets/{dataset_id}", + tags=["spike_summary"], + dependencies=[Depends(limit_reads)], +) + + +@router.post("/spike-summary") +async def post_spike_summary( + dataset_id: DatasetId, + body: SpikeSummaryRequest, + docs: Annotated[DocumentService, Depends(document_service)], + cloud_client: Annotated[NdiCloudClient, Depends(cloud)], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> SpikeSummaryResponse: + """Build a spike-summary response for one or more units. + + The body's ``dataset_id`` (alias ``datasetId``) MUST match the path + parameter — we trust the path for routing and the body for the rest + of the input. When the body's dataset_id differs we override it to + the path value so the URL is the single source of truth. + """ + # URL is source of truth — body might come pre-filled by the TS proxy + # with an out-of-date id and we don't want to surprise the caller + # with a 422 over a mismatch they can't see. Override silently. + if body.dataset_id != dataset_id: + body = body.model_copy(update={"dataset_id": dataset_id}) + + try: + return await compute_spike_summary( + body, + document_service=docs, + cloud=cloud_client, + session=session, + ) + except (CloudInternalError, CloudUnreachable, CloudTimeout) as exc: + # Translate cloud-layer failures to a typed 503 envelope — + # matches /tabular_query. Without this, the global handler + # returns an opaque 500 and the chat tool can't surface a + # useful error to the LLM. + from fastapi.responses import JSONResponse + log.warning( + "spike_summary.cloud_error", + dataset_id=dataset_id, + error_type=type(exc).__name__, + error=str(exc), + ) + return JSONResponse( # type: ignore[return-value] + status_code=503, + content={ + "units": [], + "total_matching": 0, + "kind": body.kind, + "error": str(exc) or type(exc).__name__, + "error_kind": "cloud_unavailable", + }, + ) diff --git a/backend/routers/tables.py b/backend/routers/tables.py index 8a55fc1..2f18dd0 100644 --- a/backend/routers/tables.py +++ b/backend/routers/tables.py @@ -9,7 +9,7 @@ from typing import Annotated, Any -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from ..auth.dependencies import get_current_session from ..auth.session import SessionData @@ -29,6 +29,16 @@ SUPPORTED_CLASSES = { "subject", "probe", "epoch", "element", "element_epoch", "treatment", "openminds", "openminds_subject", "probe_location", + # F-1 (2026-05-19): StimuliPicker projection. + "stimulus", "stimulus_presentation", + # F-1e (2026-05-19): treatment subclass projection direct access. + # Treatment_timeline service already invokes these via summary + # service directly (bypassing the SUPPORTED_CLASSES gate); these + # entries open the public /tables/{class} route for callers that + # want raw treatment_drug / treatment_transfer projections (e.g. + # workspace SubjectsBrowser treatment-broadcast columns, the + # /ask chat's query_documents tool). + "treatment_drug", "treatment_transfer", } @@ -68,10 +78,48 @@ async def single( class_name: str, svc: Annotated[SummaryTableService, Depends(summary_table_service)], session: Annotated[SessionData | None, Depends(get_current_session)], + page: Annotated[int | None, Query(ge=1)] = None, + page_size: Annotated[ + int | None, + Query(ge=1, le=1000, alias="pageSize"), + ] = None, + subject: Annotated[str | None, Query()] = None, ) -> dict[str, Any]: + """Single-class table fetch. + + Pagination (Stream 5.8, 2026-05-16): when ``?page`` and/or ``?pageSize`` + are supplied, the response is sliced server-side and gains the envelope + fields ``{page, pageSize, totalRows, hasMore}``. Defaults: ``page=1``, + ``pageSize=200`` (max 1000). When NEITHER is supplied the response keeps + the legacy unpaged envelope ``{columns, rows, distinct_summary}`` — + backward-compatible with the Document Explorer + cron warm-cache. + + Egress impact: Bhar's ``ontologyTableRow`` is ~5.3k rows x ~15 cols ~= + 6 MB unpaged; with ``pageSize=200`` the first request drops to ~250 KB. + The cache stays keyed by (dataset_id, class_name, user_scope) — full row + set is cached once, every page slices in-memory from the same cached + payload. + + Subject filter (F-2, 2026-05-19): when ``?subject=`` is set, + rows are filtered to those whose ``subjectDocumentIdentifier`` + matches BEFORE pagination is applied. Empowers the workspace's + SessionsBrowser cascade — clicking a subject in the rail + drives ``/tables/element_epoch?subject=`` to fetch ONLY + that subject's sessions instead of pulling all 1604 rows and + filtering client-side. The full table is still cached + (filter happens post-cache), so subsequent unfiltered requests + are still O(1). + """ if class_name not in SUPPORTED_CLASSES and class_name != "combined": raise HTTPException(status_code=400, detail=f"Unsupported table class: {class_name}") return await cancel_on_disconnect( request, - svc.single_class(dataset_id, class_name, session=session), + svc.single_class( + dataset_id, + class_name, + session=session, + page=page, + page_size=page_size, + subject_filter=subject, + ), ) diff --git a/backend/routers/tabular_query.py b/backend/routers/tabular_query.py new file mode 100644 index 0000000..3c1edff --- /dev/null +++ b/backend/routers/tabular_query.py @@ -0,0 +1,268 @@ +"""Tabular-query endpoint for the experimental /ask chat's +``tabular_query`` tool + ``ViolinChart`` component. + +GET /api/datasets/{dataset_id}/tabular_query + ?variableNameContains=ElevatedPlusMaze (required substring) + &groupBy=treatment_group (optional grouping col) + &groupOrder=Saline,CNO (optional CSV order) + +Returns per-group summary stats + raw values for a violin / jitter +plot. See :mod:`backend.services.tabular_query_service` for the +aggregation logic. + +This is a NEW additive endpoint — no schema changes, no existing- +route changes. Anonymous-readable (matches the read posture of the +rest of v2's surface). + +S5.3 (2026-05-18) — POST ``/cross-table-query`` is the sibling +endpoint that pairs two ontologyTableRow columns per subject (or +pairs a measurement with a treatment label). Same anonymous-read +posture, same 503 envelope on cloud failures, body schema lives in +:class:`CrossTableQueryBody`. +""" +from __future__ import annotations + +from typing import Annotated, Any, Literal + +from fastapi import APIRouter, Body, Depends, Query +from pydantic import BaseModel, Field + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..errors import CloudInternalError, CloudTimeout, CloudUnreachable +from ..observability.logging import get_logger +from ..services.tabular_query_service import TabularQueryService +from ._deps import limit_reads, tabular_query_service +from ._validators import DatasetId + +log = get_logger(__name__) + +router = APIRouter( + prefix="/api/datasets/{dataset_id}", + tags=["tabular_query"], + dependencies=[Depends(limit_reads)], +) + + +# F-8 (2026-05-19) — POST body model. Mirrors the GET query param +# shape so the cloud-app's POST wrapper at +# /api/datasets/[id]/tabular-query can forward this verbatim +# without GET-vs-POST translation. Both endpoints share +# `_dispatch` so the underlying behaviour stays identical. +class TabularQueryBody(BaseModel): + variableNameContains: str = Field(min_length=1, max_length=200) + groupBy: str | None = Field(default=None, min_length=1, max_length=80) + groupOrder: str | None = Field(default=None, max_length=400) + + +async def _dispatch( + dataset_id: str, + svc: TabularQueryService, + session: SessionData | None, + *, + variableNameContains: str, + groupBy: str | None, + groupOrder: str | None, +) -> Any: + group_order_list = ( + [g.strip() for g in groupOrder.split(",") if g.strip()] + if groupOrder + else None + ) + try: + return await svc.violin_groups( + dataset_id, + variableNameContains, + group_by=groupBy, + group_order=group_order_list, + session=session, + ) + except (CloudInternalError, CloudUnreachable, CloudTimeout) as exc: + # Translate cloud-layer failures to a typed 503 envelope — + # without this, the global FastAPI handler returns an opaque + # 500 JSON and the chat tool layer can't surface a useful + # error to the LLM. The frontend `fetchJson` helper maps 503 + # to a clean "Upstream returned 503" message that the LLM + # then paraphrases. Matches the discipline of /ndi_overview. + from fastapi.responses import JSONResponse + log.warning( + "tabular_query.cloud_error", + dataset_id=dataset_id, + error_type=type(exc).__name__, + error=str(exc), + ) + return JSONResponse( + status_code=503, + content={ + "error": "tabular_query unavailable", + "errorKind": "cloud_unavailable", + "reason": str(exc) or type(exc).__name__, + }, + ) + + +@router.post("/tabular_query") +async def tabular_query_post( + dataset_id: DatasetId, + body: Annotated[TabularQueryBody, Body()], + svc: Annotated[TabularQueryService, Depends(tabular_query_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> Any: + """POST variant of /tabular_query (F-8, 2026-05-19). + + Same shape + semantics as the GET endpoint; accepts the params + in a JSON body instead of the query string. Lets cloud-app's + workspace wrapper at /api/datasets/[id]/tabular-query forward + its POST body 1:1 without translating to a GET. Both endpoints + share the same handler internals. + """ + return await _dispatch( + dataset_id, + svc, + session, + variableNameContains=body.variableNameContains, + groupBy=body.groupBy, + groupOrder=body.groupOrder, + ) + + +@router.get("/tabular_query") +async def tabular_query( + dataset_id: DatasetId, + svc: Annotated[TabularQueryService, Depends(tabular_query_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], + variableNameContains: Annotated[ + str, + Query( + min_length=1, + max_length=200, + description=( + "Substring matched against the ontologyTableRow's name " + "and column headers. Case-insensitive." + ), + ), + ], + groupBy: Annotated[ + str | None, + Query( + min_length=1, + max_length=80, + description=( + "Optional grouping column (e.g. 'treatment_group', " + "'strain'). When unset, all rows form one group " + "named 'all'." + ), + ), + ] = None, + groupOrder: Annotated[ + str | None, + Query( + max_length=400, + description=( + "Optional CSV of group names defining left-to-right " + "order on the violin plot. Names not present in the " + "data are dropped; data with unlisted groups appears " + "after the listed ones." + ), + ), + ] = None, +) -> dict[str, Any]: + return await _dispatch( + dataset_id, + svc, + session, + variableNameContains=variableNameContains, + groupBy=groupBy, + groupOrder=groupOrder, + ) + + +# --------------------------------------------------------------------------- +# S5.3 (2026-05-18) — POST /cross-table-query +# +# Pairs two ontologyTableRow columns per subject (joinOn=subject) OR +# pairs a measurement column with the subject's treatment label +# (joinOn=treatment). Response shape mirrored on the cloud-app side +# at `apps/web/lib/ndi/tools/cross-table-query.ts`. +# --------------------------------------------------------------------------- + + +class CrossTableQueryBody(BaseModel): + """POST body for ``/cross-table-query``. + + Mirrors the zod schema in + ``apps/web/lib/ndi/tools/cross-table-query.ts::crossTableQueryInput`` + minus the ``datasetId`` (which comes from the path). + """ + + xVariableContains: str = Field(min_length=1, max_length=200) + yVariableContains: str = Field(min_length=1, max_length=200) + joinOn: Literal["subject", "treatment"] + groupBy: str | None = Field(default=None, min_length=1, max_length=80) + # Optional explicit group ordering. List form to mirror the chat + # tool's zod schema; the violin path's ``groupOrder`` is a CSV + # string because it ships as a GET query param. List form here + # avoids the CSV-parsing step. + groupOrder: list[str] | None = Field(default=None, max_length=20) + + +async def _dispatch_cross_table( + dataset_id: str, + svc: TabularQueryService, + session: SessionData | None, + *, + body: CrossTableQueryBody, +) -> Any: + try: + return await svc.cross_table_pairs( + dataset_id, + body.xVariableContains, + body.yVariableContains, + join_on=body.joinOn, + group_by=body.groupBy, + group_order=body.groupOrder, + session=session, + ) + except (CloudInternalError, CloudUnreachable, CloudTimeout) as exc: + # Same 503 envelope as the violin path. Without this, the + # global FastAPI handler returns an opaque 500 JSON and the + # chat tool layer can't surface a useful error to the LLM. + from fastapi.responses import JSONResponse + log.warning( + "cross_table_query.cloud_error", + dataset_id=dataset_id, + error_type=type(exc).__name__, + error=str(exc), + ) + return JSONResponse( + status_code=503, + content={ + "error": "cross_table_query unavailable", + "errorKind": "cloud_unavailable", + "reason": str(exc) or type(exc).__name__, + }, + ) + + +@router.post("/cross-table-query") +async def cross_table_query( + dataset_id: DatasetId, + body: Annotated[CrossTableQueryBody, Body()], + svc: Annotated[TabularQueryService, Depends(tabular_query_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> Any: + """POST /api/datasets/{dataset_id}/cross-table-query. + + Returns a ``{pairs, xLabel, yLabel, groupLabel, joinKind, unjoined, + source?, _meta?}`` envelope. See + :meth:`TabularQueryService.cross_table_pairs` for behavior. + + Cap: response truncated at ``MAX_PAIRS`` (1000) pairs; a + ``_meta.reason`` diagnostic is added when the cap fires. + """ + return await _dispatch_cross_table( + dataset_id, + svc, + session, + body=body, + ) diff --git a/backend/routers/treatment_timeline.py b/backend/routers/treatment_timeline.py new file mode 100644 index 0000000..6853618 --- /dev/null +++ b/backend/routers/treatment_timeline.py @@ -0,0 +1,124 @@ +"""Treatment-timeline endpoint — Gantt-style horizontal projection of +treatment docs for a dataset. + +POST /api/datasets/{dataset_id}/treatment-timeline + +The Next.js chat tool layer used to own this orchestration in +``apps/web/lib/ndi/tools/treatment-timeline.ts``. We're moving the +heart of NDI processing to Railway/Python so the work lives next to +ndi-python; the TS handler shrinks to a thin proxy that forwards +``{datasetId, title, maxSubjects}`` to this endpoint and reshapes the +raw response into the chat-specific ``chart_payload`` envelope. + +Schema compatibility +──────────────────── +The pydantic body accepts BOTH camelCase (``datasetId``, +``maxSubjects``) and snake_case (``dataset_id``, ``max_subjects``) via +field aliases. The TS proxy sends camelCase; future Python callers +(e.g. the workspace) may prefer snake_case. Both flow through the +same model. + +Error posture +───────────── +We DELIBERATELY do not surface cloud-error envelopes (e.g. 503) here. +The service catches its own internal failures and returns an +``empty_hint`` envelope instead, so callers always get a well-typed +response shape even when one of the two backends is degraded. If +both primary AND fallback are zero, ``empty_hint`` is set; the chart +renders an empty state and the chat tells the user plainly. +""" +from __future__ import annotations + +from typing import Annotated, Any + +from fastapi import APIRouter, Depends +from pydantic import BaseModel, ConfigDict, Field + +from ..auth.dependencies import get_current_session +from ..auth.session import SessionData +from ..observability.logging import get_logger +from ..services.summary_table_service import SummaryTableService +from ..services.tabular_query_service import TabularQueryService +from ..services.treatment_timeline_service import ( + DEFAULT_MAX_SUBJECTS, + HARD_CAP_MAX_SUBJECTS, + TreatmentTimelineService, +) +from ._deps import limit_reads, summary_table_service, tabular_query_service +from ._validators import DatasetId + +log = get_logger(__name__) + + +router = APIRouter( + prefix="/api/datasets/{dataset_id}", + tags=["treatment_timeline"], + dependencies=[Depends(limit_reads)], +) + + +class TreatmentTimelineRequest(BaseModel): + """Body for ``POST /api/datasets/{id}/treatment-timeline``. + + Field aliases let the model accept BOTH camelCase (the TS proxy) + AND snake_case (future Python callers) without forcing the caller + to pick a side. + """ + + # ``populate_by_name`` lets us submit either the alias OR the + # underlying name; the response model serializes by alias. + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + title: str | None = Field( + default=None, + max_length=160, + description="Optional chart title; passed through verbatim.", + ) + max_subjects: int = Field( + default=DEFAULT_MAX_SUBJECTS, + alias="maxSubjects", + gt=0, + le=HARD_CAP_MAX_SUBJECTS, + description=( + f"Max distinct subjects in the chart. Default " + f"{DEFAULT_MAX_SUBJECTS}, hard cap {HARD_CAP_MAX_SUBJECTS}. " + "Beyond that the chart becomes a wall of bars." + ), + ) + + +def treatment_timeline_service( + summary: Annotated[SummaryTableService, Depends(summary_table_service)], + tabular: Annotated[TabularQueryService, Depends(tabular_query_service)], +) -> TreatmentTimelineService: + """DI factory — composes the orchestration service from the two + underlying services that already have their own DI graph wired + on ``app.state``. No new app-state caches required. + """ + return TreatmentTimelineService(summary=summary, tabular=tabular) + + +@router.post("/treatment-timeline") +async def treatment_timeline( + dataset_id: DatasetId, + body: TreatmentTimelineRequest, + svc: Annotated[TreatmentTimelineService, Depends(treatment_timeline_service)], + session: Annotated[SessionData | None, Depends(get_current_session)], +) -> dict[str, Any]: + """Compute the treatment timeline for ``dataset_id``. + + Public/anonymous-readable for public datasets; honors session + cookies for private dataset access (matches the rest of v2's + read surface). Rate-limited under the standard ``reads`` bucket. + + On both primary AND fallback being empty, returns the + well-typed response body with ``empty_hint`` set — does NOT + raise an error. Frontend callers render an empty state and the + chat surfaces the reason in prose. + """ + return await svc.compute_timeline( + dataset_id, + title=body.title, + max_subjects=body.max_subjects, + session=session, + ) diff --git a/backend/services/aggregate_documents_service.py b/backend/services/aggregate_documents_service.py new file mode 100644 index 0000000..ba0429a --- /dev/null +++ b/backend/services/aggregate_documents_service.py @@ -0,0 +1,488 @@ +"""Aggregate-documents service — Stream 4.9 (2026-05-16). + +Closes ADR-001 (Heart-on-Railway) compliance debt: the original aggregation +ran on Vercel (TypeScript) and walked up to 50K documents in JS — wrong +runtime for that workload. This service mirrors the TS handler exactly so +the chat tool can be rewritten as a thin client. + +Pipeline: + 1. Validate input (scope / searchstructure / valueField). + 2. Forward the searchstructure to `ndi.cloud /ndiquery` via the + existing `NdiCloudClient.ndiquery` plumbing. + 3. Walk the returned documents, extract numeric values at + ``valueField`` (dotted path under ``data.*``). + 4. Group by ``groupBy`` (dotted path) when set. Drop docs that have a + numeric value but no group label so ``numeric_matches`` stays + honest. + 5. Compute per-group stats (count, mean, median, std, min, max). + 6. Surface granular per-group sample docs + contributing datasets so + the cloud-app TS client can build per-group / per-dataset + References without re-walking the cloud response. + +Cost guardrails: + - ``max_docs`` caps the scan window at 50,000. Default 5,000 matches + the TS handler. + - Reference list capped at 30 (REFERENCE_CAP) — beyond that the + chat's citation panel becomes wall-of-chips noise. Mirrors TS. +""" +from __future__ import annotations + +import asyncio +import math +from typing import Any + +import structlog +from pydantic import BaseModel, Field, field_validator + +from ..clients.ndi_cloud import BULK_FETCH_MAX, NdiCloudClient +from .query_service import QueryRequest + +log = structlog.get_logger(__name__) + +# --------------------------------------------------------------------------- +# Bounds (mirror the TS handler's constants) +# --------------------------------------------------------------------------- + +MAX_DOCS_DEFAULT = 5_000 +MAX_DOCS_CEILING = 50_000 +REFERENCE_CAP = 30 + +# F-7 (2026-05-19) — hydration concurrency. When the cloud returns slim +# {id, datasetId} pairs or bare id strings, we re-hydrate via bulk_fetch +# in batches of BULK_FETCH_MAX (=500). Bound parallelism so a single +# aggregate doesn't fan out 100 concurrent bulk-fetches against the +# cloud. Matches summary_table_service.MAX_CONCURRENT_BULK_FETCH. +MAX_CONCURRENT_BULK_FETCH = 6 + + +# --------------------------------------------------------------------------- +# Request shape +# --------------------------------------------------------------------------- + +class AggregateDocumentsRequest(BaseModel): + """Pydantic schema matching the TS `AggregateDocumentsInput`. + + `searchstructure` and `scope` are re-used from `QueryRequest`'s + validation pattern (same ops allowlist, same scope grammar) — see + `query_service.py` for the canonical contract. + """ + + scope: str = Field(..., min_length=1, max_length=2048) + searchstructure: list[dict[str, Any]] = Field(..., min_length=1, max_length=20) + valueField: str = Field(..., min_length=1, max_length=256) + groupBy: str | None = Field(default=None, min_length=1, max_length=256) + maxDocs: int | None = Field(default=None, ge=1, le=MAX_DOCS_CEILING) + + @field_validator("scope") + @classmethod + def _check_scope(cls, v: str) -> str: + # Delegate to the same validator QueryRequest uses — keeps the + # two endpoints in lockstep on which scopes are valid. + return QueryRequest._check_scope(v) # type: ignore[no-any-return] + + +# --------------------------------------------------------------------------- +# Service +# --------------------------------------------------------------------------- + +class AggregateDocumentsService: + """Stateless per-call aggregator over an `ndi_query` result set.""" + + def __init__(self, cloud: NdiCloudClient) -> None: + self.cloud = cloud + + async def aggregate( + self, + req: AggregateDocumentsRequest, + *, + access_token: str | None, + ) -> dict[str, Any]: + max_docs = req.maxDocs or MAX_DOCS_DEFAULT + + body = await self.cloud.ndiquery( + searchstructure=[_normalize_node(n) for n in req.searchstructure], + scope=req.scope, + access_token=access_token, + ) + + all_docs = body.get("documents") or [] + if not isinstance(all_docs, list): + all_docs = [] + total_items = int(body.get("totalItems") or len(all_docs)) + scanned_raw = all_docs[:max_docs] + truncated = total_items > len(scanned_raw) or len(all_docs) > max_docs + + # F-7 (2026-05-19) — hydrate slim {id, datasetId} or bare id-string + # results via chunked bulk_fetch. The cloud's `/ndiquery` can return + # any of three shapes per page depending on the search scope: + # - full bodies: [{id, datasetId, data, document_class, ...}] + # - slim refs: [{id, datasetId}] (no `data`) + # - id-only: ["abc...", "def...", ...] (rare; CSV scope) + # In the slim/id-only cases the legacy code path silently dropped + # every doc from numeric_matches (no `data.*` to extract from). + # Hydration is a no-op when every entry already carries `data`. + scanned = await self._hydrate_if_needed( + scanned_raw, scope=req.scope, access_token=access_token, + ) + + # Bucket values by group. When groupBy is unset all values fall into + # the 'all' bucket. Per-group sample doc is the FIRST contributing + # document so the chat can build a "one example from each bucket" + # citation chip. + buckets: dict[str, list[float]] = {} + sample_docs: dict[str, dict[str, Any]] = {} + group_order: list[str] = [] + numeric_matches = 0 + + for doc in scanned: + v = _extract_numeric(doc, req.valueField) + if v is None: + continue + + group_key = "all" + if req.groupBy: + g = _extract_string(doc, req.groupBy) + # Numeric value exists but no group label → skip (matches the + # TS handler's behavior so numeric_matches is honest). + if g is None: + continue + group_key = g + + numeric_matches += 1 + if group_key not in buckets: + buckets[group_key] = [] + group_order.append(group_key) + sample_docs[group_key] = doc + buckets[group_key].append(v) + + groups: list[dict[str, Any]] = [] + for name in group_order: + vals = buckets.get(name) or [] + if not vals: + continue + stats = _summary_stats(vals) + sample = sample_docs.get(name) + groups.append({ + "group": name, + **stats, + "sample_doc": _project_sample(sample) if sample else None, + }) + + # Contributing-dataset list (capped) for the TS client to build + # dataset-level References without re-walking the scan window. + datasets_contributing: list[str] = [] + seen: set[str] = set() + for doc in scanned: + ds = _doc_dataset_id(doc) + if not ds or ds in seen: + continue + seen.add(ds) + datasets_contributing.append(ds) + if len(datasets_contributing) >= REFERENCE_CAP: + break + + return { + "total_items": total_items, + "numeric_matches": numeric_matches, + "truncated": truncated, + "valueField": req.valueField, + "scanned_docs": len(scanned), + "groups": groups, + "datasets_contributing": datasets_contributing, + } + + async def _hydrate_if_needed( + self, + docs: list[Any], + *, + scope: str, + access_token: str | None, + ) -> list[dict[str, Any]]: + """F-7 (2026-05-19): hydrate slim ndiquery results via chunked + bulk_fetch. Preserves identical numeric output relative to the + per-doc `get_document` path — both ultimately reach the same + cloud-side document body, just chunked for round-trip savings. + + Behavior: + - All entries are full bodies (carry ``data``) → returned as-is, + no extra cloud calls (no-op fast path; matches the pre-F-7 + behavior on the happy path). + - Mixed / slim / id-only → group by dataset_id, bulk_fetch in + BULK_FETCH_MAX-sized chunks per dataset, then re-assemble. + - id-only with no dataset_id (bare strings) → emit a structured + warning and DROP those entries. Aggregating without a dataset + id is impossible (bulk_fetch is per-dataset). The cloud emits + this shape only for niche CSV scopes — if it happens, the + operator wants to see it in logs. + """ + if not docs: + return [] + + full, needs_hydration, bare_id_count = _classify_docs(docs, scope=scope) + + if bare_id_count > 0: + log.warning( + "aggregate_documents.bare_ids_dropped", + bare_id_count=bare_id_count, + scope=scope, + ) + + if not needs_hydration: + return full + + hydrated = await self._bulk_fetch_hydration( + needs_hydration, access_token=access_token, + ) + + log.info( + "aggregate_documents.hydrated_via_bulk_fetch", + already_full=len(full), + hydrated=len(hydrated), + datasets=len({ds_id for ds_id, _ in needs_hydration}), + ) + + return _reorder_with_hydration(docs, hydrated, scope=scope) + + async def _bulk_fetch_hydration( + self, + needs_hydration: list[tuple[str, str]], + *, + access_token: str | None, + ) -> list[dict[str, Any]]: + """Group (dataset_id, doc_id) pairs by dataset, slice into + BULK_FETCH_MAX-sized chunks, and fan out concurrent bulk_fetch + calls bounded by MAX_CONCURRENT_BULK_FETCH.""" + by_dataset: dict[str, list[str]] = {} + for ds_id, doc_id in needs_hydration: + by_dataset.setdefault(ds_id, []).append(doc_id) + + batches: list[tuple[str, list[str]]] = [] + for ds_id, ids in by_dataset.items(): + for i in range(0, len(ids), BULK_FETCH_MAX): + batches.append((ds_id, ids[i : i + BULK_FETCH_MAX])) + + sem = asyncio.Semaphore(MAX_CONCURRENT_BULK_FETCH) + + async def _fetch_one(ds_id: str, ids: list[str]) -> list[dict[str, Any]]: + async with sem: + return await self.cloud.bulk_fetch( + ds_id, ids, access_token=access_token, + ) + + results = await asyncio.gather( + *[_fetch_one(ds_id, ids) for ds_id, ids in batches], + ) + hydrated: list[dict[str, Any]] = [] + for r in results: + hydrated.extend(r) + return hydrated + + +# --------------------------------------------------------------------------- +# Helpers — ported from apps/web/lib/ndi/tools/aggregate-documents.ts +# --------------------------------------------------------------------------- + +def _normalize_node(n: dict[str, Any]) -> dict[str, Any]: + """Strip None-valued keys so the cloud sees the same compact shape the + TS client used to send.""" + out: dict[str, Any] = {"operation": n.get("operation")} + for k in ("field", "param1", "param2"): + if k in n and n[k] is not None: + out[k] = n[k] + return out + + +def _scope_is_single_dataset_id(scope: str) -> bool: + """Return True iff `scope` is a single MongoDB ObjectId (24 hex chars). + Used by the F-7 hydration path: only when the search is scoped to one + dataset can bare id strings be attributed to a dataset for bulk_fetch. + """ + if not isinstance(scope, str) or len(scope) != 24: + return False + try: + int(scope, 16) + return True + except ValueError: + return False + + +def _classify_docs( + docs: list[Any], + *, + scope: str, +) -> tuple[list[dict[str, Any]], list[tuple[str, str]], int]: + """F-7 helper: partition raw ndiquery results into + (full_bodies, needs_hydration_pairs, bare_id_dropped_count). + + Full bodies carry ``data.*`` and pass through unchanged. Slim refs + contribute a ``(dataset_id, doc_id)`` pair to the hydration queue. + Bare id strings under a single-dataset scope attribute to that + scope; under any other scope they're unattributable and counted + in the third return value for warning emission upstream. + """ + full: list[dict[str, Any]] = [] + needs_hydration: list[tuple[str, str]] = [] + bare_id_count = 0 + is_single_ds = _scope_is_single_dataset_id(scope) + for d in docs: + if isinstance(d, dict) and isinstance(d.get("data"), dict): + full.append(d) + continue + if isinstance(d, dict): + doc_id = d.get("id") or d.get("_id") or d.get("ndiId") + ds_id = d.get("datasetId") or d.get("dataset") + if isinstance(doc_id, str) and isinstance(ds_id, str): + needs_hydration.append((ds_id, doc_id)) + continue + if isinstance(d, str): + if is_single_ds: + needs_hydration.append((scope, d)) + else: + bare_id_count += 1 + return full, needs_hydration, bare_id_count + + +def _reorder_with_hydration( + docs: list[Any], + hydrated: list[dict[str, Any]], + *, + scope: str, +) -> list[dict[str, Any]]: + """F-7 helper: re-walk the original ``docs`` order, substituting + hydrated bodies for the slim refs (so per-group sample_doc + + datasets_contributing preserve cloud-side ordering). Full bodies + pass through; unhydrateable entries drop silently.""" + index: dict[tuple[str, str], dict[str, Any]] = {} + for h in hydrated: + h_ds_id = _doc_dataset_id(h) + h_doc_id = h.get("id") or h.get("_id") or h.get("ndiId") + if isinstance(h_ds_id, str) and isinstance(h_doc_id, str): + index[(h_ds_id, h_doc_id)] = h + + is_single_ds = _scope_is_single_dataset_id(scope) + ordered: list[dict[str, Any]] = [] + for d in docs: + if isinstance(d, dict) and isinstance(d.get("data"), dict): + ordered.append(d) + continue + if isinstance(d, dict): + doc_id = d.get("id") or d.get("_id") or d.get("ndiId") + ds_id = d.get("datasetId") or d.get("dataset") + if isinstance(doc_id, str) and isinstance(ds_id, str): + hit = index.get((ds_id, doc_id)) + if hit is not None: + ordered.append(hit) + continue + if isinstance(d, str) and is_single_ds: + hit = index.get((scope, d)) + if hit is not None: + ordered.append(hit) + return ordered + + +def _lookup_path(obj: Any, path: str) -> Any: + """Walk a dotted path under an arbitrary nested dict. Returns None on + any missing segment or non-dict ancestor.""" + if not path: + return None + cur: Any = obj + for seg in path.split("."): + if cur is None or not isinstance(cur, dict): + return None + cur = cur.get(seg) + return cur + + +def _extract_numeric(doc: dict[str, Any], path: str) -> float | None: + """Pull a finite numeric value at ``path``. Coerces string-encoded + numbers (e.g. ``"3.14"`` → 3.14) the same way the TS helper does. + Returns None when the path is missing OR the value is NaN/Inf.""" + raw = _lookup_path(doc, path) + if isinstance(raw, bool): + # bools are technically int subclasses in Python; the TS code + # accepts numbers only. + return None + if isinstance(raw, (int, float)): + return float(raw) if math.isfinite(float(raw)) else None + if isinstance(raw, str): + try: + parsed = float(raw) + except ValueError: + return None + return parsed if math.isfinite(parsed) else None + return None + + +def _extract_string(doc: dict[str, Any], path: str) -> str | None: + """Pull a non-empty string value at ``path``. Coerces booleans and + numbers to strings (mirrors the TS helper) so groupBy works against + numeric / boolean group labels.""" + raw = _lookup_path(doc, path) + if isinstance(raw, str): + return raw if len(raw) > 0 else None + if isinstance(raw, bool): + return "true" if raw else "false" + if isinstance(raw, (int, float)): + return str(raw) + return None + + +def _summary_stats(values: list[float]) -> dict[str, float]: + """count / mean / median / std / min / max over a non-empty list. + + Uses the sample standard deviation (N-1 denominator) when len >= 2 to + match the TS handler. Median uses the linear-interpolation midpoint + for even-length lists. + """ + n = len(values) + sorted_vals = sorted(values) + mean = sum(sorted_vals) / n + if n % 2 == 1: + median = sorted_vals[(n - 1) // 2] + else: + median = (sorted_vals[n // 2 - 1] + sorted_vals[n // 2]) / 2 + if n >= 2: + sq = sum((v - mean) * (v - mean) for v in sorted_vals) + std = math.sqrt(sq / (n - 1)) + else: + std = 0.0 + return { + "count": n, + "mean": mean, + "median": median, + "std": std, + "min": sorted_vals[0], + "max": sorted_vals[-1], + } + + +def _doc_dataset_id(doc: dict[str, Any]) -> str | None: + """Best-effort dataset id extraction. Cloud responses use either + ``datasetId`` or ``dataset`` depending on age of the doc. + """ + ds = doc.get("datasetId") or doc.get("dataset") + return str(ds) if ds else None + + +def _project_sample(doc: dict[str, Any]) -> dict[str, Any] | None: + """Compact per-group sample doc for the TS client's chip-builder. + + Only carries the three fields the client needs to build a Reference: + doc id, dataset id, class name. Stripping the rest keeps the response + small (the chat is the primary consumer; bigger responses bloat the + token budget for the LLM's tool result). + """ + doc_id = doc.get("id") or doc.get("_id") or doc.get("ndiId") + dataset_id = _doc_dataset_id(doc) + cls = ( + (doc.get("document_class") or {}).get("class_name") + if isinstance(doc.get("document_class"), dict) + else None + ) or "document" + if not doc_id or not dataset_id: + return None + return { + "id": str(doc_id), + "dataset_id": dataset_id, + "class": str(cls), + } diff --git a/backend/services/binary_service.py b/backend/services/binary_service.py index 6e221a8..8208d01 100644 --- a/backend/services/binary_service.py +++ b/backend/services/binary_service.py @@ -145,7 +145,17 @@ async def get_timeseries( # noqa: PLR0911 ) ref = filtered[0] else: - ref = refs[0] + # 2026-05-19 — Smart default for multi-file docs (e.g. + # `daqreader_mfdaq_epochdata_ingested` carries + # `channel_list.bin` first, followed by N `.nbf_#` signal + # files). The pre-fix behavior `refs[0]` returned the + # metadata file and produced "Could not decode + # channel_list.bin" errors. Now: prefer the first ref whose + # filename ends with a known-decodable extension; fall back + # to refs[0] only when no candidate is decodable (so single- + # file docs continue to work and the error message still + # surfaces a real codec failure rather than "no match"). + ref = _pick_default_signal_ref(refs) if not ref.url: return _timeseries_error("no_download_url", "No download URL available for this file.") @@ -156,22 +166,53 @@ async def get_timeseries( # noqa: PLR0911 return _timeseries_error("download", f"Failed to download file: {e}") name = (ref.filename or "").lower() - # VH-Lab's VHSB files use a text metadata header ("This is a VHSB file, - # http://github.com/VH-Lab") followed by typed binary slots. The v1 - # decoder used the DID-python `vlt` library for this. v2 doesn't bundle - # vlt today; we surface the same "vlt library not available" soft error - # the v1 TimeseriesChart already maps to a friendly message. + # Decoder dispatch order matters. We try the cheapest discriminators + # first and fall through: + # 1. NDI-compressed wrappers (gzip magic 0x1f 0x8b) → ndicompress + # 2. VHSB text-tag header ("This is a VHSB file, ...") → vlt + # 3. VHSB binary-magic (synthetic; rare/dead in practice) → inline _parse_vhsb + # 4. Everything else → inline _parse_nbf (the .nbf raw-binary case) + # + # Phase A swap (2026-05-13): paths #1 and #2 are NEW. Before, both + # short-circuited to soft errors because vlt and ndicompress were + # not installed. Paths #3 and #4 are unchanged. The audit at + # `apps/web/scripts/audit-public-api.mjs` exists to prove that + # the NBF (path #4) byte-shape is byte-identical before/after. + from backend.services import ndi_python_service as _ndi + + # Path 1: NDI-compressed binary (.nbf.tgz wrapper) + if _ndi.is_ndi_compressed(payload): + shaped = _ndi.expand_ephys_from_bytes(payload) + if shaped is not None: + return shaped + # Fall through to soft error below if expansion failed. + return _timeseries_error( + "ndi_compressed_failed", + "This file looks NDI-compressed (.nbf.tgz) but the decoder " + "could not expand it. Format may be a non-Ephys variant.", + ) + + # Path 2: VHSB with the ASCII text-tag header head = payload[:5] if len(payload) >= 5 else b"" if head.startswith(b"This "): + shaped = _ndi.read_vhsb_from_bytes(payload) + if shaped is not None: + return shaped + # NDI stack unavailable, or the file is malformed. Preserve the + # legacy soft-error code so the v1 TimeseriesChart's friendly + # message still surfaces if NDI-python isn't loaded. return _timeseries_error( "vlt_library", - "vlt library is not available on this server — full VHSB " - "decoding requires the DID-python `vlt` extension. The raw " - "file is available in the document's Files section.", + "VHSB decoder unavailable on this server. The raw file is " + "still available in the document's Files section.", ) + try: + # Path 3: synthetic binary-magic VHSB. In practice rare; kept + # so test fixtures continue to work. if name.endswith(".vhsb") or (payload[:4] == b"VHSB"): return _parse_vhsb(payload) + # Path 4: raw .nbf. Byte-shape under audit; do not modify. return _parse_nbf(payload) except Exception as e: log.warning("binary.decode_failed", kind="timeseries", error=str(e)) @@ -188,7 +229,14 @@ async def get_image( refs = _file_refs(document) if not refs: raise BinaryNotFound() - payload = await self.cloud.download_file(refs[0].url, access_token=access_token) + # B5 sweep (2026-05-18). Pre-fix `refs[0]` could pick a metadata + # sidecar (e.g. `imageStack_parameters.json`) on multi-file image + # docs, causing PIL's `Image.open` to raise and the Document + # Explorer's image viewer to throw `BinaryDecodeFailed` even + # when a decodable image existed at a different position in the + # file list. Smart pick mirrors `_pick_default_signal_ref`. + ref = _pick_default_image_ref(refs) + payload = await self.cloud.download_file(ref.url, access_token=access_token) try: # Lazy-import PIL (audit #57) — see module docstring. from PIL import Image @@ -496,6 +544,122 @@ def _file_info_to_ref(fi: dict[str, Any]) -> FileRef | None: return FileRef(url=url, content_type=content_type, filename=name) +# Filename extensions / patterns we know the codec dispatch can decode. +# Matched case-insensitively as a suffix on the filename. Used by +# `_pick_default_signal_ref` to pick a timeseries-bearing file from a +# multi-file document instead of the first-listed file (which is often +# metadata like `channel_list.bin`). +_DECODABLE_SIGNAL_EXTENSIONS: tuple[str, ...] = ( + ".nbf", # NumPy float / float64 binary (Phase-A ingest) + ".vhsb", # vhlab Hauser-style binary (vlt) + ".dat", # raw timeseries (some Phase-A vintages) + ".bin", # generic binary — kept last as a fallback (NOT + # `channel_list.bin` though; see explicit skip below) +) + +# Files known to be metadata, NOT signal data. Skipped when picking the +# default ref for a doc with no explicit `filename=` hint. +_KNOWN_METADATA_FILENAMES: frozenset[str] = frozenset({ + "channel_list.bin", + "channel_info.bin", + "channels.json", + "meta.json", + "metadata.json", +}) + + +# Known image filename extensions Pillow can auto-detect. Matched case- +# insensitively as a suffix on the filename. Used by +# `_pick_default_image_ref` (B5 sweep, 2026-05-18) to pick the image +# file from a multi-file image doc instead of grabbing `refs[0]` +# blindly, which could be a sidecar like `imageStack_parameters.json` +# or a metadata file. The list is intentionally Pillow-aligned — raw +# NDI-native formats (.nim, raw imageStack) are NOT decoded by Pillow +# and the existing service surfaces `unsupported` for them by design. +_DECODABLE_IMAGE_EXTENSIONS: tuple[str, ...] = ( + ".tif", + ".tiff", + ".png", + ".jpg", + ".jpeg", + ".gif", +) + + +def _pick_ref_by_extension( + refs: list[FileRef], + extensions: tuple[str, ...], +) -> FileRef: + """Internal: pick the first ref matching a known decodable extension, + falling back through the metadata blocklist and then ``refs[0]``. + + Shared by ``_pick_default_signal_ref`` and + ``_pick_default_image_ref`` so both can use the same metadata-skip + + extension-suffix matching semantics. + + Heuristic (priority order): + 1. First ref whose filename has a known decodable extension AND + is not in the metadata blocklist. Suffix-matched OR matched + with a non-alphanumeric tail char (handles ``.nbf_1`` / + ``.tif.gz`` variants). + 2. First ref not in the metadata blocklist (any extension). + 3. First ref (legacy fallback — single-file docs hit this). + """ + # Step 1: known-decodable extension, not metadata. + for r in refs: + name = (r.filename or "").lower() + if name in _KNOWN_METADATA_FILENAMES: + continue + for ext in extensions: + # Match exact suffix OR suffix followed by a non-alphanumeric + # character (handles `.nbf_1`, `.nbf_#`, `.nbf.gz`, + # `.tif_1`, `.tiff.gz` variants). + if name.endswith(ext): + return r + idx = name.find(ext) + if idx != -1 and idx + len(ext) < len(name): + tail = name[idx + len(ext)] + if not tail.isalnum(): + return r + # Step 2: first non-metadata ref. + for r in refs: + if (r.filename or "").lower() not in _KNOWN_METADATA_FILENAMES: + return r + # Step 3: legacy fallback. + return refs[0] + + +def _pick_default_signal_ref(refs: list[FileRef]) -> FileRef: + """Pick the most likely timeseries-bearing file from a doc's refs. + + See ``_pick_ref_by_extension`` for the priority order. Targets the + multi-file daqreader case (e.g. ``daqreader_mfdaq_epochdata_ingested`` + on Francesconi) where ``channel_list.bin`` is alphabetically first + but is metadata; the signal lives in ``ai_group1_seg.nbf_1`` etc. + """ + return _pick_ref_by_extension(refs, _DECODABLE_SIGNAL_EXTENSIONS) + + +def _pick_default_image_ref(refs: list[FileRef]) -> FileRef: + """Pick the most likely image-bearing file from a doc's refs. + + B5 sweep (2026-05-18). Mirrors ``_pick_default_signal_ref`` but + against the image-extension list and the SAME metadata blocklist + (any of those JSONs / channel_list.bin would also fail Pillow + decoding, so it's correct to skip them here too). + + Multi-file image docs in NDI typically pair a primary image + (``frame_001.tif``) with one or more sidecars + (``imageStack_parameters.json``, ``calibration.json``). Pre-fix + behavior of ``refs[0]`` could grab the sidecar JSON depending on + cloud-side ordering. Pillow then raises ``UnidentifiedImageError`` + and the request returns ``errorKind="unsupported"`` even though + the doc DOES have a decodable image — just at a different + position in the file list. + """ + return _pick_ref_by_extension(refs, _DECODABLE_IMAGE_EXTENSIONS) + + def _class_name(document: dict[str, Any]) -> str: return ( document.get("className") diff --git a/backend/services/class_aliases.py b/backend/services/class_aliases.py new file mode 100644 index 0000000..33db4b4 --- /dev/null +++ b/backend/services/class_aliases.py @@ -0,0 +1,65 @@ +"""Shared canonical NDI class-name aliases. + +Lives in its own module so both ``summary_table_service`` (which uses +the aliases when projecting per-class tables) and ``document_service`` +(which now uses them in ``list_by_class`` for the workspace's +picker queries) can share one source-of-truth without a circular +import. + +See ``summary_table_service._CLASS_ALIASES`` for the original +docstring + smoke-test context. The forward-import is preserved +there as ``_CLASS_ALIASES = CLASS_ALIASES`` for backwards-compat +with any pickled cache keys / external test stubs. + +Pattern: first non-empty alias wins. Callers should query the +literal class first, and only fall through this chain when the +literal returns zero IDs. Logging the alias hit (preferably as +an ``...alias_hit`` log line) keeps observability able to +distinguish literal-class hits from alias resolution. +""" +from __future__ import annotations + +# Aliases for the canonical NDI class names. The cloud's `isa` operator +# walks ``classLineage`` so a query for the BASE class returns docs of +# any subclass — but the cloud (as deployed) does NOT walk the LINEAGE +# in the OTHER direction. Datasets ingested under the modern schema +# emit ``element`` rather than the legacy ``probe`` class name; an +# `isa` query for ``probe`` returns zero. The fallback chain below is +# what summary_table_service.SummaryTableService._build_single_class +# (the original consumer) and document_service.DocumentService.list_by_class +# (the workspace picker consumer added 2026-05-18 for B2) both use. +# +# Smoke-tested 2026-05-14 against Dabrowska BNST (id 67f723d574f5f79c +# 6062389d, 0 probes / 606 elements): ``query_documents(className=probe)`` +# returned 0 rows pre-fix; with the fallback it returns the 606 +# element rows. +# +# 2026-05-19 — F-1d follow-up. Cloud-app's SessionsBrowser calls +# `useSummaryTable('element_epoch')` which previously returned 0 +# rows for Francesconi (`67f723d574f5f79c6062389d`) and any other +# pre-2025 dataset that landed its per-epoch documents under the +# legacy ingestion-class names. Two known legacy spellings: +# - `epochfiles_ingested` (general Phase-A ingest) +# - `daqreader_mfdaq_epochdata_ingested` (mfdaq-specific ingest) +# First non-empty alias wins via the caller's existing loop. +# +# F-1 (2026-05-19): short-form alias so callers can hit +# /tables/stimulus and get the same projection as +# /tables/stimulus_presentation. Mirrors the probe → element +# short-form pattern. +CLASS_ALIASES: dict[str, list[str]] = { + "probe": ["element"], + "epoch": [ + "element_epoch", + "epochfiles_ingested", + "daqreader_mfdaq_epochdata_ingested", + ], + "element_epoch": [ + "epochfiles_ingested", + "daqreader_mfdaq_epochdata_ingested", + ], + "stimulus": ["stimulus_presentation"], +} + + +__all__ = ["CLASS_ALIASES"] diff --git a/backend/services/dataset_binding_service.py b/backend/services/dataset_binding_service.py new file mode 100644 index 0000000..47d407f --- /dev/null +++ b/backend/services/dataset_binding_service.py @@ -0,0 +1,530 @@ +"""dataset_binding_service — Sprint 1.5 cloud-backed ``ndi.dataset.Dataset`` +binding for the experimental ``/ask`` chat. + +The chat already has a structured ``ndi_query`` tool that proxies to the +cloud-node Mongo layer. This service adds ONE more capability: surfacing +SDK-level abstractions (``dataset.elements()``, ``element.epochs()``, +session traversal) over a LOCAL copy of the dataset that NDI-python has +materialized via :func:`ndi.cloud.orchestration.downloadDataset`. + +Why local materialization? + Most useful summary numbers — element count, total epoch count + across elements, list of (name, type) tuples — are computed by the + SDK by walking ``element`` + ``element_epoch`` docs and traversing + dependencies. The cloud-node's ``/ndiquery`` endpoint returns raw + docs but doesn't perform that traversal. Spinning up a real + ``ndi.dataset.Dataset`` once, in-process, lets us answer these + "how many X are there?" questions cheaply. + +Lifecycle +───────── +- First call for a given dataset_id is a cold load (10-30s typical for + the demo datasets). The cold path runs ``downloadDataset`` in a + thread so it doesn't block the asyncio loop. +- Subsequent calls are warm hits: bounded by an in-memory LRU. +- Cache is keyed by dataset_id; eviction at MAX_CACHED_DATASETS. +- Concurrent calls for the SAME dataset coalesce on an + :class:`asyncio.Lock` so we never download twice in parallel. + +Cache target folder +─────────────────── +- Env-var ``NDI_CACHE_DIR`` (default ``/tmp/ndi-cache``). +- Per-dataset subfolder under that root; downloadDataset itself appends + the dataset_id, so the on-disk layout is + ``//.ndi/…``. +- /tmp is ephemeral on Railway (no persistent volume requested for + this task) — that's fine for the demo. Entries get rebuilt after a + redeploy and the pre-warm tasks fan out automatically. + +Failure posture +─────────────── +- Every public method NEVER raises. On any internal failure they log + a warning and return None. Callers (the FastAPI router) treat None + as "binding unavailable" → 503 → frontend tool falls back to + /ndiquery. Safety > completeness. +""" + +from __future__ import annotations + +import asyncio +import os +import time +from collections import OrderedDict +from pathlib import Path +from typing import Any + +from ..observability.logging import get_logger + +log = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Tunables (module-level so tests can monkeypatch them) +# --------------------------------------------------------------------------- + +# Max simultaneously-cached datasets. Each cached dataset holds: +# - the Python ndi_dataset_dir object (~MB-scale heap), +# - its on-disk .ndi store under NDI_CACHE_DIR//. +# 5 is enough to cover the 3 demo datasets + headroom for occasional +# user-driven calls without ballooning memory or disk. +MAX_CACHED_DATASETS = 5 + +# Per-cold-load wall-clock cap. downloadDataset is mostly I/O-bound +# (bulk Mongo fetch + S3 presign rewrites). 90s gives slow networks +# enough rope; longer than that and we'd rather surface the failure to +# the caller than hold the request handler open. +COLD_LOAD_TIMEOUT_SECONDS = 90.0 + +# Soft size budget for the on-disk cache. We log a warning when the +# cache exceeds this; we don't auto-prune because eviction-on-LRU +# already bounds growth in the steady state. The warning is the +# operator hint that something's leaking. +CACHE_DIR_SOFT_LIMIT_BYTES = 5 * 1024 * 1024 * 1024 # 5 GB + +# Max elements surfaced in the overview payload. The LLM token budget +# won't survive a 500-element table; the dataset-detail page already +# shows the full element list, the chat is the wrong place for it. +MAX_ELEMENTS_IN_OVERVIEW = 50 + + +# --------------------------------------------------------------------------- +# Internal cache entry — keeps the dataset object + bookkeeping together +# --------------------------------------------------------------------------- + + +class _CacheEntry: + """Single LRU slot. ``dataset`` is the NDI-python object; ``loaded_at`` + powers the ``cache_age_seconds`` field in the overview response. + + Mutable on purpose: ``DatasetBindingService`` rewrites ``loaded_at`` + on every warm hit so the LRU ordering reflects recency-of-use, not + recency-of-load. + """ + + __slots__ = ("dataset", "first_loaded_at", "loaded_at") + + def __init__(self, dataset: Any) -> None: + now = time.monotonic() + self.dataset = dataset + self.loaded_at = now + self.first_loaded_at = now + + +class DatasetBindingService: + """LRU-cached wrapper around :func:`ndi.cloud.orchestration.downloadDataset`. + + Public surface is two coroutines: :meth:`get_dataset` and + :meth:`overview`. Both swallow all exceptions and return ``None`` + on any failure so the router can map that to a 503 and the chat + falls through to its existing tools. + """ + + def __init__(self, *, cache_dir: str | None = None) -> None: + self._cache: OrderedDict[str, _CacheEntry] = OrderedDict() + # Per-dataset locks coalesce concurrent get_dataset() calls so + # two requests for the same id share a single download. + self._locks: dict[str, asyncio.Lock] = {} + # Global lock guards _locks dict mutation and LRU eviction. + self._mutex = asyncio.Lock() + self._cache_dir = Path( + cache_dir or os.environ.get("NDI_CACHE_DIR", "/tmp/ndi-cache") + ) + # Most-recent cold-load failure: ``(code, message)`` tuple or + # None. Captured by :meth:`_cold_load` so the router's 503 + # envelope can surface a specific reason (rather than the + # generic "not configured" string the chat tool used to see). + # Codes are stable identifiers; the message is human-readable. + self._last_failure: tuple[str, str] | None = None + + def last_failure(self) -> tuple[str, str] | None: + """Return the most-recent cold-load failure tuple ``(code, + message)``, or None if no cold load has failed since boot. + Used by the ``/ndi_overview`` router to enrich its 503 envelope + — the chat tool surfaces ``message`` in its fallback hint. + """ + return self._last_failure + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def get_dataset(self, dataset_id: str) -> Any | None: + """Return the cached ndi.dataset.Dataset for ``dataset_id``. + + Cold path: downloads (in a worker thread) under + ``//`` and caches the result. + Warm path: returns the cached object instantly + updates LRU + position. + + Returns ``None`` on any failure (NDI-python unavailable, + download timeout, exception during construction, etc.) — never + raises. The router translates None → 503; the frontend tool + translates 503 → "binding still warming, try ndi_query". + """ + if not dataset_id: + return None + + async with self._mutex: + existing = self._cache.get(dataset_id) + if existing is not None: + # LRU bump + warm-hit log. + self._cache.move_to_end(dataset_id) + existing.loaded_at = time.monotonic() + log.info( + "dataset_binding.warm_hit", + dataset_id=dataset_id, + cache_size=len(self._cache), + ) + return existing.dataset + # No cache entry. Acquire/create the per-dataset lock. + per_lock = self._locks.setdefault(dataset_id, asyncio.Lock()) + + # Hold the per-dataset lock to deduplicate concurrent cold + # loads. After acquiring, re-check the cache — another caller + # may have populated it while we waited. + async with per_lock: + async with self._mutex: + existing = self._cache.get(dataset_id) + if existing is not None: + self._cache.move_to_end(dataset_id) + log.info( + "dataset_binding.warm_hit_after_wait", + dataset_id=dataset_id, + ) + return existing.dataset + + dataset = await self._cold_load(dataset_id) + if dataset is None: + return None + + async with self._mutex: + self._cache[dataset_id] = _CacheEntry(dataset) + self._cache.move_to_end(dataset_id) + self._evict_lru_if_needed() + return dataset + + async def overview(self, dataset_id: str) -> dict[str, Any] | None: + """High-level summary: element / subject / epoch counts + element + listing. See module docstring for why this matters. + + Returns ``None`` if the binding is unavailable. Callers route + that to a 503. + """ + if not dataset_id: + return None + + # cache_hit reflects whether get_dataset hit a warm slot. + # Capture BEFORE the call so we can tell cold from warm. + async with self._mutex: + had_entry = dataset_id in self._cache + + dataset = await self.get_dataset(dataset_id) + if dataset is None: + return None + + # Pull cache age (now in seconds) — after get_dataset() the + # entry's loaded_at was bumped, so first_loaded_at gives us + # the actual age since cold-load. + async with self._mutex: + entry = self._cache.get(dataset_id) + cache_age_seconds = ( + time.monotonic() - entry.first_loaded_at if entry else 0.0 + ) + + # Compute the actual overview off the event loop. Most of the + # work is pure-Python iteration over the in-memory database, + # but element.epochtable() may trigger file I/O for ingested + # epochs. Threadpool it to be safe. + try: + payload: dict[str, Any] | None = await asyncio.to_thread( + self._compute_overview, dataset + ) + except Exception as exc: # blind — overview must never raise + log.warning( + "dataset_binding.overview_failed", + dataset_id=dataset_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return None + + if payload is None: + return None + + payload["cache_hit"] = had_entry + payload["cache_age_seconds"] = round(cache_age_seconds, 2) + return payload + + # ------------------------------------------------------------------ + # Cold-load + overview computation (run off the event loop) + # ------------------------------------------------------------------ + + async def _cold_load(self, dataset_id: str) -> Any | None: + """Run ``downloadDataset`` in a worker thread with a wall-clock cap. + + Stashes the most-recent failure reason on ``self._last_failure`` + so the router can surface a specific cause in its 503 envelope + (instead of a generic "not configured" string). Reasons: + + - ``"phase_a_unavailable"`` — the Phase A NDI-python stack + (``ndi.ontology`` / ``ndicompress`` / ``vlt.file``) didn't + import on boot. Hits the bulk of the experimental Railway + surface; the chat tool's hint already says "use ndi_query". + - ``"binding_unavailable"`` — Phase A imported but the Sprint + 1.5 binding (``ndi.dataset`` / ``ndi.cloud.orchestration``) + didn't. + - ``"cache_dir_unwritable"`` — Railway's ``/tmp`` not mountable + (rare; would also break other tempfile users). + - ``"cold_load_timeout"`` — ``downloadDataset`` exceeded + ``COLD_LOAD_TIMEOUT_SECONDS``. + - ``"cold_load_failed"`` — ``downloadDataset`` raised. Most + common live cause: cloud-node auth (the Sprint 1.5 binding + needs creds the request handler doesn't have). + """ + from . import ndi_python_service + + if not ndi_python_service.is_ndi_available(): + self._last_failure = ( + "phase_a_unavailable", + "NDI-python Phase A stack not importable on this server", + ) + log.warning( + "dataset_binding.ndi_unavailable", + dataset_id=dataset_id, + ) + return None + + if not ndi_python_service.is_dataset_binding_available(): + self._last_failure = ( + "binding_unavailable", + "ndi.dataset / ndi.cloud.orchestration not importable", + ) + log.warning( + "dataset_binding.binding_module_unavailable", + dataset_id=dataset_id, + ) + return None + + # Ensure the cache root exists before handing it to + # downloadDataset (which mkdirs its own per-dataset subfolder + # but assumes the parent is writable). + try: + self._cache_dir.mkdir(parents=True, exist_ok=True) + except OSError as exc: + self._last_failure = ( + "cache_dir_unwritable", + f"cache dir {self._cache_dir} is not writable", + ) + log.warning( + "dataset_binding.cache_dir_unwritable", + cache_dir=str(self._cache_dir), + error=str(exc), + ) + return None + + log.info( + "dataset_binding.cold_load_start", + dataset_id=dataset_id, + cache_dir=str(self._cache_dir), + ) + start = time.monotonic() + try: + dataset = await asyncio.wait_for( + asyncio.to_thread(self._download_blocking, dataset_id), + timeout=COLD_LOAD_TIMEOUT_SECONDS, + ) + except TimeoutError: + self._last_failure = ( + "cold_load_timeout", + f"downloadDataset exceeded {COLD_LOAD_TIMEOUT_SECONDS:.0f}s", + ) + log.warning( + "dataset_binding.cold_load_timeout", + dataset_id=dataset_id, + timeout_seconds=COLD_LOAD_TIMEOUT_SECONDS, + ) + return None + except Exception as exc: # blind — cold load must never raise + self._last_failure = ( + "cold_load_failed", + # Truncate long stack-trace-like messages to a single + # line so the 503 envelope stays readable. Most useful + # info (the exception type + first line of message) is + # in the first ~120 chars. + f"{type(exc).__name__}: {str(exc).splitlines()[0][:200]}" + if str(exc) else type(exc).__name__, + ) + log.warning( + "dataset_binding.cold_load_failed", + dataset_id=dataset_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return None + + duration_seconds = time.monotonic() - start + log.info( + "dataset_binding.cold_load", + dataset_id=dataset_id, + duration_seconds=round(duration_seconds, 2), + ) + + # Best-effort: warn if the on-disk cache has grown past the + # soft budget. Computed lazily so we don't du -sh on every + # call; cheap enough at cold-load granularity. + self._warn_if_cache_oversized() + return dataset + + def _download_blocking(self, dataset_id: str) -> Any: + """Synchronous downloadDataset call. Lives in a thread. + + Lazy-import ndi.cloud here so this module stays cheap to import + even when NDI-python isn't installed (test/CI matrix). + """ + from ndi.cloud.orchestration import downloadDataset + return downloadDataset( + dataset_id, + str(self._cache_dir), + sync_files=False, + ) + + def _compute_overview(self, dataset: Any) -> dict[str, Any] | None: + """Walk the dataset and return the LLM-facing summary dict. + + Runs on a worker thread. Tolerant of partial failures: each + sub-count is wrapped in its own try/except so one missing + traversal doesn't blank the whole payload. + """ + # ------ element listing + count ------ + elements: list[Any] = [] + element_count = 0 + element_listing: list[dict[str, str]] = [] + try: + session = getattr(dataset, "_session", None) + if session is not None and hasattr(session, "getelements"): + elements = list(session.getelements()) or [] + element_count = len(elements) + for elem in elements[:MAX_ELEMENTS_IN_OVERVIEW]: + name = str(getattr(elem, "name", "") or "") + etype = str(getattr(elem, "type", "") or "") + if name or etype: + element_listing.append({"name": name, "type": etype}) + except Exception as exc: + log.warning( + "dataset_binding.element_listing_failed", + error=str(exc), + error_type=type(exc).__name__, + ) + elements = [] + element_count = 0 + element_listing = [] + + # ------ subject count via isa('subject') search ------ + # The ndi_query import is inside the try so a missing SDK + # version on a dev machine downgrades subject_count to 0 + # without blanking the rest of the payload. + subject_count = 0 + try: + from ndi.query import ndi_query + subj_docs = dataset.database_search( + ndi_query("").isa("subject") + ) + subject_count = len(subj_docs) if subj_docs is not None else 0 + except Exception as exc: + log.warning( + "dataset_binding.subject_count_failed", + error=str(exc), + error_type=type(exc).__name__, + ) + + # ------ epoch count via per-element epochtable ------ + # We sum across ALL elements (not just the first + # MAX_ELEMENTS_IN_OVERVIEW) to preserve count fidelity. Each + # element's numepochs() walks the in-memory epoch table; the + # cost is bounded by element count * avg epochs. + epoch_count = 0 + try: + for elem in elements: + try: + if hasattr(elem, "numepochs"): + epoch_count += int(elem.numepochs()) + else: + et, _ = elem.epochtable() + epoch_count += len(et) if et else 0 + except Exception as exc: + # Per-element failure: log but keep counting. + log.debug( + "dataset_binding.element_epoch_count_failed", + element_name=str(getattr(elem, "name", "")), + error=str(exc), + ) + except Exception as exc: + log.warning( + "dataset_binding.epoch_count_failed", + error=str(exc), + error_type=type(exc).__name__, + ) + + # ------ dataset reference (for citation snippet) ------ + reference = "" + try: + reference = str(getattr(dataset, "reference", "") or "") + except Exception: + reference = "" + + return { + "element_count": element_count, + "subject_count": subject_count, + "epoch_count": epoch_count, + "elements": element_listing, + "elements_truncated": element_count > len(element_listing), + "reference": reference, + } + + # ------------------------------------------------------------------ + # LRU eviction + disk-usage guard + # ------------------------------------------------------------------ + + def _evict_lru_if_needed(self) -> None: + """Drop the least-recently-used entry when the cache is full. + + Called under self._mutex. We don't unlink the on-disk folder + of the evicted dataset — leaving it lets a later cold-load + skip the network entirely (downloadDataset reuses an existing + target folder if the JSONs are already there). + """ + while len(self._cache) > MAX_CACHED_DATASETS: + oldest_id, _ = self._cache.popitem(last=False) + self._locks.pop(oldest_id, None) + log.info( + "dataset_binding.evicted", + dataset_id=oldest_id, + cache_size=len(self._cache), + ) + + def _warn_if_cache_oversized(self) -> None: + """Best-effort disk-usage check. Walks the cache dir once per + cold load; cheap relative to a download but still bounded. + """ + try: + if not self._cache_dir.exists(): + return + total = 0 + for path in self._cache_dir.rglob("*"): + try: + if path.is_file(): + total += path.stat().st_size + except OSError: + continue + if total > CACHE_DIR_SOFT_LIMIT_BYTES: + log.warning( + "dataset_binding.cache_dir_oversized", + cache_dir=str(self._cache_dir), + size_bytes=total, + limit_bytes=CACHE_DIR_SOFT_LIMIT_BYTES, + ) + except Exception as exc: + log.debug( + "dataset_binding.cache_dir_size_check_failed", + error=str(exc), + ) diff --git a/backend/services/dataset_summary_service.py b/backend/services/dataset_summary_service.py index 72a17be..5804c7e 100644 --- a/backend/services/dataset_summary_service.py +++ b/backend/services/dataset_summary_service.py @@ -97,7 +97,28 @@ # behavior. SUMMARY_CACHE_TTL_SECONDS = SUMMARY_CACHE_TTL_DEGRADED_SECONDS -SUMMARY_KEY_PREFIX = "summary:v1" +# Bumped v1 → v2 on 2026-05-18 to invalidate stale `counts.sessions` +# entries cached before B6's parent-session filter shipped. +# +# Bumped v2 → v3 on 2026-05-18 (same day) to invalidate the v2 +# entries written by the depends_on-only B6 filter. The newer +# prefix-suffix fallback (Haley case) needs a fresh build to shift +# `counts.sessions` from raw → filtered for datasets that don't +# encode session identity via depends_on graph edges. +# +# Bumped v7 → v8 on 2026-05-19 — F-1c. `counts.probes` now mirrors +# `counts.elements` when the literal `probe` class is zero. The +# Python runtime treats `probe` as an alias for `element` (per +# `services/class_aliases.py`), so datasets like Francesconi +# (`67f723d574f5f79c6062389d`, 0 probe + 606 element + 3 probeType +# facets) previously rendered "Probes: 0" on the snapshot tile. +# Stale v7 entries cached the wrong value and must be invalidated. +# +# Response SHAPE remains identical across all bumps — only the +# `counts.sessions` value can shift. The model's `schemaVersion` +# literal stays `summary:v1` (clients consuming that field don't +# need to recompile). Only the cache key namespace changes. +SUMMARY_KEY_PREFIX = "summary:v8" # Audit 2026-04-23 (#60): bumped 3 → 6 to actually match # ``summary_table_service.MAX_CONCURRENT_BULK_FETCH`` (the prior comment # claimed alignment but the values differed). Catalog list-with-summary @@ -140,6 +161,16 @@ # warmer Mongo cache. PER_CLASS_FETCH_TIMEOUT_SECONDS = 25.0 +# B6 — parent/aggregate-session filter walk cap. If a dataset reports more +# than this many sessions, skip the per-session reverse-dependency check +# entirely and use the raw count. Real datasets with many sessions +# (multi-day recording series) virtually always have downstream refs on +# every leaf, so the filter adds zero value at high N; and the cost is +# O(N) indexed ndiquery calls which can stack into the seconds at large N. +# Sized so Haley (3) + Bhar (2) + Francesconi (1) + every other published +# dataset we've seen falls well under. +_MAX_SESSIONS_FILTER_WALK = 50 + # Stage-1 deadline for the cheap, always-needed cloud calls # (``GET /datasets/:id`` + ``GET /datasets/:id/document-class-counts``). # These typically resolve in 50-200ms warm and 1-5s cold, BUT smoke-test @@ -359,7 +390,7 @@ async def _build_and_serialize( # Cache writes go through JSON, so serialize to a plain dict here. return summary.model_dump(mode="json") - async def _build( + async def _build( # noqa: PLR0912, PLR0915 — single summary orchestrator; splitting would obscure the gather-frontier flow self, dataset_id: str, *, access_token: str | None, ) -> DatasetSummary: t0 = time.perf_counter() @@ -445,8 +476,24 @@ async def _build( # brainRegions; element (primary=probe-like) → probeTypes. These all # share the dataset scope so we parallelize via asyncio.gather with # a shared semaphore bounding bulk-fetch concurrency. + # + # B6 (2026-05-18): also parallelize the parent-session filter walk + # alongside the structured-facts fetches so it adds zero wall-clock + # latency on the happy path (Haley's 3-session walk runs in ~200ms + # against the indexed depends_on path; openminds_subject + probe + # fanouts dominate at multi-second scale). subjects_present = counts.subjects > 0 probe_present = counts.probes > 0 or counts.elements > 0 + # B6 — only filter sessions when (a) there's >1 to potentially filter + # AND (b) the dataset has at least one non-session doc that could + # carry a depends_on ref pointing at a session. Pure-session datasets + # (e.g. test fixtures, newly-published catalogs awaiting their + # element_epoch ingestion) would always report 0 real sessions → + # fail-open kicks in and returns the raw count → wasted ndiquery + # calls. Skip up front. + sessions_filter_warranted = ( + counts.sessions > 1 and counts.totalDocuments > counts.sessions + ) if subjects_present: om_task = self._fetch_class_bounded( @@ -469,6 +516,15 @@ async def _build( pl_task = _empty_list() element_task = _empty_list() + if sessions_filter_warranted: + sessions_filter_task = self._count_real_sessions( + dataset_id, counts.sessions, + access_token=access_token, sem=sem, + warnings=warnings, + ) + else: + sessions_filter_task = _identity_int(counts.sessions) + # Shield each leg with return_exceptions so one flaky class doesn't # torpedo the whole summary — we surface a warning instead. The # `_fetch_class_bounded` wrapper raises ``TimeoutError`` on @@ -477,10 +533,18 @@ async def _build( # ``extractionWarnings`` and return ``[]`` so the per-class fact # extraction below cleanly degrades to ``None`` (subjects→species # null, probe_location→brainRegions null, element→probeTypes null). - results = await asyncio.gather(om_task, pl_task, element_task, return_exceptions=True) + results = await asyncio.gather( + om_task, pl_task, element_task, sessions_filter_task, + return_exceptions=True, + ) openminds_docs = _result_or_warn(results[0], "openminds_subject", warnings) probe_location_docs = _result_or_warn(results[1], "probe_location", warnings) element_docs = _result_or_warn(results[2], "element", warnings) + filtered_sessions = _filtered_sessions_or_warn( + results[3], raw=counts.sessions, warnings=warnings, + ) + if filtered_sessions != counts.sessions: + counts = counts.model_copy(update={"sessions": filtered_sessions}) # Structured facts. species = _extract_om_terms( @@ -497,6 +561,15 @@ async def _build( ) if probe_present else None probe_types = _extract_probe_types(element_docs) if probe_present else None + # Stream 5.6 diagnostic (2026-05-15) — see helper docstring. + _maybe_log_species_empty( + dataset_id=dataset_id, + subjects_present=subjects_present, + species=species, + subject_count=counts.subjects, + openminds_docs=openminds_docs, + ) + # Ontology resolution — delegate label enrichment. Dedupe by # ontologyId so we don't look up the same term twice. await self._enrich_ontology_labels( @@ -617,6 +690,201 @@ async def _one(batch: list[str]) -> list[dict[str, Any]]: flat.extend(c) return flat + # --- Session filtering (B6) ----------------------------------------- + + async def _count_real_sessions( # noqa: PLR0911 — explicit skip-paths read more clearly than a single accumulator + self, + dataset_id: str, + raw_session_count: int, + *, + access_token: str | None, + sem: asyncio.Semaphore, + warnings: list[str], + ) -> int: + """B6 — filter parent/aggregate session docs from the session count. + + A "real" session is one with ≥1 other document carrying + ``depends_on.value`` pointing at its ``ndiId``. Parent/aggregate + sessions (e.g. Haley's ``haley_2025`` parent, ingested 10h after + the ``_Celegans`` + ``_Ecoli`` leaves) have no downstream refs + because they're administrative containers that no + ``element_epoch`` / ``subject`` doc derives from. + + Returns the filtered count. **Fail-open**: on any error or + unexpected zero result, returns ``raw_session_count`` and emits + a structured log line so operators can audit. + + Cost: ``_fetch_class_bounded("session")`` is one ndiquery + + small bulk_fetch (sessions are tiny per-dataset, typically + 1-10). Then one indexed ``depends_on * [ndiId]`` query per + session — ≤10 queries on the indexed downstream path is + ~hundreds of ms total. Runs alongside the existing structured- + facts gather so it adds no extra wall-clock latency on the + hot path. + """ + # Skip when there's nothing to filter. + if raw_session_count <= 1: + return raw_session_count + # Safety cap. Datasets with O(50+) sessions are either a + # genuine multi-day recording series (where every session + # legitimately has downstream refs) or an ingestion anomaly + # — in either case skip the per-session walk to avoid + # quadratic cost. The raw count is "wrong" only by the + # number of parent/aggregate session docs, which is bounded + # by ~1-2 per dataset in practice. + if raw_session_count > _MAX_SESSIONS_FILTER_WALK: + log.info( + "dataset_summary.session_filter_skipped_too_many", + dataset_id=dataset_id, + raw_count=raw_session_count, + ) + return raw_session_count + + try: + session_docs = await self._fetch_class_bounded( + dataset_id, "session", + access_token=access_token, sem=sem, + ) + except Exception as e: + warnings.append(f"session filter: session class fetch failed: {e}") + return raw_session_count + + if len(session_docs) <= 1: + # Cloud may report a different count than what's actually + # fetchable. Use what we can see; fail-open isn't useful + # when the empirical truth disagrees with the count. + return len(session_docs) + + async def _has_downstream(doc: dict[str, Any]) -> bool: + # Try every observed cloud shape for the ndiId: + # 1. Top-level `ndiId` (list_documents_by_dataset / bulk-fetch + # summary projection) + # 2. `data.base.id` (canonical NDI doc body — the one + # `_depends_on_edges` reads in dependency_graph_service) + # 3. Top-level `base.id` (legacy / paranoid) + ndi_id: str | None = None + top_ndi = doc.get("ndiId") + if isinstance(top_ndi, str) and top_ndi: + ndi_id = top_ndi + else: + data = doc.get("data") + if isinstance(data, dict): + base = data.get("base") + if isinstance(base, dict): + v = base.get("id") + if isinstance(v, str) and v: + ndi_id = v + if ndi_id is None: + base = doc.get("base") + if isinstance(base, dict): + v = base.get("id") + if isinstance(v, str) and v: + ndi_id = v + if not ndi_id: + # Can't reverse-query a session with no ndiId — assume + # real so we don't filter it spuriously. + return True + try: + body = await self.cloud.ndiquery( + searchstructure=[ + {"operation": "depends_on", "param1": "*", "param2": [ndi_id]}, + ], + scope=dataset_id, + access_token=access_token, + page_size=1, + fetch_all=False, + ) + except Exception as e: + log.warning( + "dataset_summary.session_downstream_failed", + dataset_id=dataset_id, ndi_id=ndi_id, error=str(e), + ) + # Fail-open: lookup failure → assume real so we never + # silently drop a session because the cloud was flaky. + return True + total = int(body.get("totalItems") or body.get("number_matches") or 0) + return total > 0 + + results = await asyncio.gather( + *[_has_downstream(doc) for doc in session_docs], + return_exceptions=True, + ) + real_count_via_deps = sum(1 for r in results if r is True) + + # Both heuristics are computed up-front so we can compare and + # pick the most informative answer. + # + # Lessons from the Haley live-verification rollout (2026-05-18): + # + # (a) The depends_on heuristic alone is too PERMISSIVE for + # datasets whose parent session has admin docs pointing + # at it (e.g. Haley's `dataset_session_info` doc + # depends_on the parent `haley_2025` session ndiId, so the + # parent looks "referenced" via depends_on even though + # it's an aggregate container with no experimental data). + # Pure depends_on returns 3 for Haley → no filtering. + # + # (b) The reference-prefix heuristic alone is too AGGRESSIVE + # for datasets where sibling sessions happen to share a + # naming prefix (e.g. `cohort_2025` + `cohort_2025_part1` + # are both real but the prefix heuristic would mark + # `cohort_2025` as a parent). + # + # (c) Their composition is reliable: a session is a parent + # iff its `session.reference` has a SIBLING that extends + # it AND removing it from the count makes biological + # sense (i.e. there are downstream refs for the leaves, + # so the lab's experimental data lives in the leaves). + # + # The chosen policy: + # 1. If prefix heuristic returns a conclusive count (not + # None) and that count is strictly less than raw, use it. + # The structural signal "session B's name extends session + # A's name" is hard to fake; this is the safest filter + # whenever it applies, regardless of depends_on result. + # 2. Else, use the depends_on count if > 0 (canonical signal + # for labs that use the dependency graph for session + # identity). + # 3. Else (both 0 / inconclusive), fail-open with raw count. + filtered_via_refs = _filter_by_reference_prefix(session_docs) + + if filtered_via_refs is not None and 0 < filtered_via_refs < raw_session_count: + log.info( + "dataset_summary.session_filter", + dataset_id=dataset_id, + raw_count=raw_session_count, + filtered_count=filtered_via_refs, + parent_or_aggregate_sessions=raw_session_count - filtered_via_refs, + depends_on_count=real_count_via_deps, + via="reference_prefix", + ) + return filtered_via_refs + + if real_count_via_deps > 0: + if real_count_via_deps != raw_session_count: + log.info( + "dataset_summary.session_filter", + dataset_id=dataset_id, + raw_count=raw_session_count, + filtered_count=real_count_via_deps, + parent_or_aggregate_sessions=raw_session_count - real_count_via_deps, + via="depends_on", + ) + return real_count_via_deps + + # Both inconclusive. Fail-open with raw count rather than + # reporting 0 or a wrong value. Most likely cause: newly- + # published dataset awaiting element_epoch ingestion, OR a + # cloud-side reverse-dep outage, OR a lab schema we haven't + # learned yet. + log.warning( + "dataset_summary.session_filter_all_zero", + dataset_id=dataset_id, + raw_count=raw_session_count, + fetched_session_docs=len(session_docs), + ) + return raw_session_count + # --- Ontology resolution -------------------------------------------- async def _enrich_ontology_labels( @@ -704,6 +972,92 @@ async def _empty_list() -> list[dict[str, Any]]: return [] +def _session_reference(doc: dict[str, Any]) -> str | None: + """Extract ``data.session.reference`` from a session doc body. + + Falls back through observed cloud shapes — sometimes the field is + ``reference``, sometimes ``session_reference``. Returns ``None`` if + no reference string is available so callers can skip ambiguous + sessions cleanly. + """ + data = doc.get("data") + if not isinstance(data, dict): + return None + sess = data.get("session") + if not isinstance(sess, dict): + return None + for key in ("reference", "session_reference", "name"): + v = sess.get(key) + if isinstance(v, str) and v: + return v + return None + + +def _filter_by_reference_prefix( + session_docs: list[dict[str, Any]], +) -> int | None: + """B6 fallback heuristic — return the count of "real" sessions when + each session's ``reference`` is structurally a prefix or a leaf in + the dataset's reference tree. + + A session is treated as a **parent** (not real) iff: + * its reference is a strict prefix (with ``_`` separator) of some + OTHER session's reference in the same dataset. + + Equivalently, a session is **real** iff no other session extends + its reference. Concretely on Haley: + + * `haley_2025_Celegans` — no other ref starts with + `haley_2025_Celegans_` → real (leaf). + * `haley_2025_Ecoli` — same → real (leaf). + * `haley_2025` — `haley_2025_Celegans` starts with + `haley_2025_` → parent (filtered). + + Returns ``None`` if the heuristic can't be applied (any session + has no reference field, or all sessions share the exact same + reference string — ambiguous). Returns 0 if EVERY session looks + like a parent (also ambiguous, treat as inconclusive). Otherwise + returns the leaf count. + """ + refs = [_session_reference(doc) for doc in session_docs] + # Need a reference string for every session to make a confident + # determination. Missing fields → bail. + if any(r is None for r in refs): + return None + + refs_set = {r for r in refs if r is not None} + if len(refs_set) <= 1: + # All sessions share the same reference (or there's only one). + # No way to identify a parent — bail. + return None + + leaf_count = 0 + for i, ref_i in enumerate(refs): + if ref_i is None: + # Already bailed above, but keep mypy happy. + return None + is_parent = False + prefix = ref_i + "_" + for j, ref_j in enumerate(refs): + if i == j or ref_j is None: + continue + if ref_j.startswith(prefix): + is_parent = True + break + if not is_parent: + leaf_count += 1 + + return leaf_count + + +async def _identity_int(value: int) -> int: + """Awaitable identity-int. Used by the session-filter gather leg so the + skip-path (counts.sessions ≤ 1) returns the raw count without firing the + per-session reverse-dep walk, while still exposing a coroutine to + asyncio.gather alongside the structured-facts fetches.""" + return value + + def _result_or_warn( result: Any, what: str, warnings: list[str], ) -> list[dict[str, Any]]: @@ -713,6 +1067,28 @@ def _result_or_warn( return cast(list[dict[str, Any]], result) +def _filtered_sessions_or_warn( + result: Any, *, raw: int, warnings: list[str], +) -> int: + """B6 session-filter leg result unwrap. On exception, surfaces a + typed warning into ``extractionWarnings`` and returns the raw + pre-filter count so the summary stays consistent with what the + cloud reported. On success, expects an int (the filtered count + from ``_count_real_sessions``) and returns it unchanged.""" + if isinstance(result, BaseException): + warnings.append(f"session filter failed: {result!s}") + return raw + if isinstance(result, int): + return result + # Defensive: unexpected return type (shouldn't happen — the only + # source is ``_count_real_sessions`` which returns int). Treat as + # fail-open. + warnings.append( + f"session filter returned unexpected type {type(result).__name__}", + ) + return raw + + def _counts_from_raw(raw: dict[str, Any]) -> DatasetSummaryCounts: """``/document-class-counts`` returns ``{datasetId, totalDocuments, classCounts: {class_name: n}}``. We map the canonical classes; any @@ -722,22 +1098,142 @@ def _counts_from_raw(raw: dict[str, Any]) -> DatasetSummaryCounts: # Sessions and probes: the cloud reports whichever class name the # dataset used. Fall back across `probe` / `element` and `session` / # `session_in_a_dataset` so older and newer datasets both reconcile. + # + # Epoch counting: NDI has at least four canonical epoch-bearing + # document classes depending on ingestion path and dataset vintage. + # The fallback chain (tried in priority order, first hit wins) is: + # + # 1. ``element_epoch`` — explicit per-element epoch documents + # (newer NDI datasets, native MATLAB ingest) + # 2. ``epoch`` — legacy plain epoch documents + # 3. ``epochfiles_ingested`` — Phase-A ingest path emits one of + # these per epoch file. Francesconi (BNST patch-clamp) uses + # this exclusively — without this fallback the EPOCHS chip + # reads 0 on the workspace even though the dataset has + # thousands of recorded epochs (caught live during the + # 2026-05-14 tutorial-parity smoke). + # 4. ``daqreader_mfdaq_epochdata_ingested`` — alternate Phase-A + # class for multi-function-DAQ ingest (covers some Van Hooser + # lab datasets); 1:1 with epochs in datasets where it appears. + # + # First-non-zero-wins (not summed). When multiple classes are + # present (e.g. ``epochfiles_ingested`` and the mfdaq variant for + # the same epochs) the chain picks the most authoritative class + # and avoids double-counting. + epoch_classes = ( + "element_epoch", + "epoch", + "epochfiles_ingested", + "daqreader_mfdaq_epochdata_ingested", + ) + epochs = 0 + for cls in epoch_classes: + n = int(class_counts.get(cls) or 0) + if n > 0: + epochs = n + break + + sessions = int( + class_counts.get("session") + or class_counts.get("session_in_a_dataset") + or 0, + ) + subject_count = int(class_counts.get("subject") or 0) + element_count = int(class_counts.get("element") or 0) + # 2026-05-19 — F-1c follow-up. `probe` is a Python runtime alias + # for `element` (per `_CLASS_ALIASES` in `summary_table_service.py`). + # Many datasets emit no literal `probe` documents — Francesconi + # (`67f723d574f5f79c6062389d`) has 0 `probe` + 606 `element`, + # but the snapshot tile previously rendered "Probes: 0" which + # contradicts the catalog's probeTypes facet (3 types). Fall back + # to element_count when literal probe is zero. Logged so we can + # tell the difference between "really zero probes" and "alias + # resolved" in observability. + literal_probe_count = int(class_counts.get("probe") or 0) + if literal_probe_count == 0 and element_count > 0: + log.info( + "dataset_summary.probes_alias_resolved", + raw_probe_count=literal_probe_count, + aliased_probe_count=element_count, + element_count=element_count, + total_documents=int(raw.get("totalDocuments") or 0), + ) + probe_count = element_count + else: + probe_count = literal_probe_count + # Stream 5.5 diagnostic (2026-05-15): some datasets land with + # elements + subjects but zero session-class documents (Mukherjee + # `6546c509…` on 2026-05-15: 1 subject + 7 elements + sessions=0). + # Per NDI's data model you can't have elements without a recording + # session — so a true zero here usually means either (a) ingest is + # mid-pipeline and the session docs haven't landed yet, or (b) the + # dataset uses a session-class spelling we don't yet recognize. + # Emit a structured log so operators can grep Railway logs to find + # the offending datasets + see what session-shaped class names + # they actually emit. NOT a user-visible change. + if sessions == 0 and (element_count > 0 or subject_count > 0): + session_shaped_keys = sorted( + k for k in class_counts if "session" in k.lower() + ) + log.info( + "summary.sessions_zero_with_elements", + element_count=element_count, + subject_count=subject_count, + total_documents=int(raw.get("totalDocuments") or 0), + session_shaped_class_keys=session_shaped_keys, + ) return DatasetSummaryCounts( - sessions=int( - class_counts.get("session") - or class_counts.get("session_in_a_dataset") - or 0, - ), - subjects=int(class_counts.get("subject") or 0), - probes=int(class_counts.get("probe") or 0), - elements=int(class_counts.get("element") or 0), - epochs=int( - class_counts.get("element_epoch") or class_counts.get("epoch") or 0, - ), + sessions=sessions, + subjects=subject_count, + probes=probe_count, + elements=element_count, + epochs=epochs, totalDocuments=int(raw.get("totalDocuments") or 0), ) +def _maybe_log_species_empty( + *, + dataset_id: str, + subjects_present: bool, + species: list[OntologyTerm] | None, + subject_count: int, + openminds_docs: list[dict[str, Any]], +) -> None: + """Stream 5.6 (2026-05-15) — diagnostic for the species-empty + anomaly: some published datasets (Reikersdorfer carbon-fiber, Van + Hooser tree-shrew, Mukherjee gustatory per the 2026-05-15 + cross-dataset smoke) land subjects-present + zero species. The + openminds_subject path requires each subject to have a Species + enrichment doc whose ``openminds_type`` URI ends in ``/Species``. + Datasets ingested under an older NDI version or via a non-canonical + pipeline may emit openminds_subject docs without the species + companion, OR with a ``openminds_type`` URI that doesn't follow + the canonical terminator convention. + + Emit a structured log when subjects exist but species came back + empty, including the set of ``openminds_type`` suffixes actually + present in the dataset's openminds_subject docs. This lets an + operator grep Railway logs to find affected datasets and SEE + what type suffix names the dataset uses — which then drives the + future ``_openminds_type_suffix`` alias map without guessing. + """ + if not subjects_present or species is None or len(species) > 0: + return + observed_suffixes: set[str] = set() + for om_doc in openminds_docs: + suffix = _openminds_type_suffix(om_doc) + if suffix: + observed_suffixes.add(suffix) + log.info( + "dataset_summary.species_empty_with_subjects", + dataset_id=dataset_id, + subjects=subject_count, + openminds_subject_doc_count=len(openminds_docs), + openminds_type_suffixes=sorted(observed_suffixes), + ) + + def _extract_om_terms( openminds_docs: list[dict[str, Any]], type_suffix: str, diff --git a/backend/services/dependency_graph_service.py b/backend/services/dependency_graph_service.py index 783490b..cc174b8 100644 --- a/backend/services/dependency_graph_service.py +++ b/backend/services/dependency_graph_service.py @@ -40,6 +40,10 @@ MAX_NODES_HARD_CAP = 500 # defensive guard for pathological graphs MAX_CONCURRENT_RESOLUTIONS = 8 +# F-3 (2026-05-19) — valid values for the `direction` query param. +# "both" preserves the legacy behaviour (full bidirectional walk). +VALID_DEPENDENCY_DIRECTIONS = frozenset({"both", "upstream", "downstream"}) + class DependencyGraphService: def __init__( @@ -58,22 +62,39 @@ async def get_graph( *, max_depth: int = 3, session: SessionData | None, + direction: str = "both", ) -> dict[str, Any]: + """Walk depends_on; optionally filter result to one direction. + + F-3 (2026-05-19) — added the ``direction`` filter. Accepted + values: ``"upstream"`` (only edges pointing toward the + target's ancestors — "what produced this"), ``"downstream"`` + (only edges pointing toward consumers — "what was produced + from this"), or ``"both"`` (default, preserves the pre-F-3 + full-bidirectional response). + + Cache stays keyed on (dataset_id, document_id, depth) — the + FULL graph is cached, the direction filter is applied to the + cached result. That way upstream- and downstream- queries on + the same target share one cold-compute. + """ depth = max(1, min(MAX_DEPTH_HARD_CAP, int(max_depth or 1))) access_token = session.access_token if session else None if self.cache is not None: key = _dep_graph_key( dataset_id, document_id, depth, user_scope=user_scope_for(session), ) - return await self.cache.get_or_compute( + full = await self.cache.get_or_compute( key, lambda: self._build_graph( dataset_id, document_id, depth, access_token=access_token, ), ) - return await self._build_graph( - dataset_id, document_id, depth, access_token=access_token, - ) + else: + full = await self._build_graph( + dataset_id, document_id, depth, access_token=access_token, + ) + return _filter_graph_by_direction(full, direction) async def _build_graph( # noqa: PLR0912, PLR0915 — single BFS orchestrator; splitting would obscure the frontier bookkeeping self, @@ -440,6 +461,59 @@ def _deduplicate_edges(edges: list[dict[str, Any]]) -> list[dict[str, Any]]: return out +def _filter_graph_by_direction( + full: dict[str, Any], direction: str, +) -> dict[str, Any]: + """Apply F-3 direction filter post-walk. + + ``direction='both'`` returns the full graph untouched. The other + two values keep only the edges whose ``direction`` field matches, + and prune nodes that no longer connect to the target via the kept + edges. The target node itself is always retained. + + Unknown direction values default to 'both' (defensive — the + router pydantic layer already constrains to the literal set). + """ + if direction == "both" or direction not in VALID_DEPENDENCY_DIRECTIONS: + return full + + target_ndi = full.get("target_ndi_id") + raw_nodes = full.get("nodes") or [] + raw_edges = full.get("edges") or [] + + kept_edges = [ + e for e in raw_edges + if isinstance(e, dict) and e.get("direction") == direction + ] + + # Find nodes reachable from target via kept edges. + if isinstance(target_ndi, str) and target_ndi: + reachable: set[str] = {target_ndi} + else: + reachable = set() + for edge in kept_edges: + s = edge.get("source") + t = edge.get("target") + if isinstance(s, str): + reachable.add(s) + if isinstance(t, str): + reachable.add(t) + + kept_nodes = [ + n for n in raw_nodes + if isinstance(n, dict) and n.get("ndi_id") in reachable + ] + + return { + **full, + "nodes": kept_nodes, + "edges": kept_edges, + "node_count": len(kept_nodes), + "edge_count": len(kept_edges), + "direction_filter": direction, + } + + def _empty_graph(document_id: str, *, reason: str) -> dict[str, Any]: return { "target_id": document_id, diff --git a/backend/services/document_service.py b/backend/services/document_service.py index 921c9c8..33f4bb3 100644 --- a/backend/services/document_service.py +++ b/backend/services/document_service.py @@ -7,6 +7,7 @@ from ..clients.ndi_cloud import NdiCloudClient from ..errors import NotFound from ..observability.logging import get_logger +from .class_aliases import CLASS_ALIASES log = get_logger(__name__) @@ -183,6 +184,62 @@ async def list_by_class( body.get("number_matches") or body.get("totalItems") or len(slice_ids), ) + # Class-alias resolution (B2 — 2026-05-18). When the caller + # asked for a specific class and the literal `isa` returned + # zero rows, follow the canonical alias chain (see + # `backend.services.class_aliases.CLASS_ALIASES`). This is + # what `summary_table_service._build_single_class` does + # already; mirroring it here makes the workspace's + # picker queries (`/documents?class=probe` for Haley etc.) + # return the right docs without each panel having to know + # which legacy class name a dataset uses. + # + # Behaviour: + # - Walks the chain ONCE per request; the first non-empty + # alias wins (matches summary_table_service semantics). + # - Total comes from the cloud's ndiquery body — never + # synthesized from the page length, so paginated callers + # see the canonical total. + # - Emits a structured `documents.alias_hit` log line so + # observability can distinguish literal-class hits from + # alias resolution (mirrors `table.single.alias_hit`). + # - When the chain is exhausted with zero rows, we return + # the original empty result — the inline/Mongo fallbacks + # below still don't fire (they're keyed on + # `not class_name`, by design — see PR-#96 / the + # `_inline_id_fallback` docstring). + if class_name and not slice_ids: + for alias in CLASS_ALIASES.get(class_name, []): + alt_body = await self.cloud.ndiquery( + searchstructure=[{"operation": "isa", "param1": alias}], + scope=dataset_id, + access_token=access_token, + page=page, + page_size=page_size, + fetch_all=False, + ) + alt_docs = alt_body.get("documents", []) + alt_ids: list[str] = [ + d.get("id") or d.get("ndiId") + for d in alt_docs + if d.get("id") or d.get("ndiId") + ] + if alt_ids: + log.info( + "documents.alias_hit", + dataset_id=dataset_id, + requested_class=class_name, + resolved_class=alias, + ids=len(alt_ids), + ) + slice_ids = alt_ids + total = int( + alt_body.get("number_matches") + or alt_body.get("totalItems") + or len(alt_ids), + ) + break + # Anonymous-fallback: ndiquery returned nothing AND the user # didn't ask for a class filter. Pull the inline document-id # array from the dataset detail (which works anonymously for diff --git a/backend/services/facet_service.py b/backend/services/facet_service.py index 504cc25..2e6d480 100644 --- a/backend/services/facet_service.py +++ b/backend/services/facet_service.py @@ -538,16 +538,27 @@ class _DedupedTermBucket: be mutated in-place when a higher-count casing wins) and a counter per cased-label-string. The most-frequently-seen casing is the surviving displayed label; ties are broken by first-seen. + + ``ontology_id`` is the merged bucket's authoritative ontologyId + (may be ``None`` until promoted by an incoming labeled term). It + governs label-alias merge eligibility: a label-keyed alias only + matches an incoming term when at most one side has an ontologyId. + Distinct ontologyIds with the same label intentionally stay + distinct (different upstream providers can legitimately catalog + the same name as different concepts). """ - __slots__ = ("counts", "index", "winning_label") + __slots__ = ("counts", "index", "ontology_id", "winning_label") - def __init__(self, *, index: int, label: str) -> None: + def __init__( + self, *, index: int, label: str, ontology_id: str | None, + ) -> None: self.index = index # cased label → seen count. First-seen entry inserted with count # 1; subsequent matches bump. self.counts: dict[str, int] = {label: 1} self.winning_label = label + self.ontology_id = ontology_id def record(self, label: str) -> str | None: """Bump the seen counter for ``label`` and, if the new count @@ -566,6 +577,61 @@ def record(self, label: str) -> str | None: return None +def _term_keys( + label: str, + ontology_id: str | None, + *, + use_paren_abbrev: bool, +) -> tuple[str | None, str | None, str]: + """Compute the (oid, abbrev, norm) dedupe keys for one term. + + All three are returned together so the caller can both look up an + existing bucket and register the term's aliases without recomputing. + ``oid`` and ``abbrev`` may be ``None`` when not applicable; ``norm`` + is always populated (every term has a normalizable label). + """ + oid_key = f"oid::{ontology_id}" if ontology_id else None + abbrev_key: str | None = None + if use_paren_abbrev: + abbrev = _extract_parenthesized_abbrev(label) + if abbrev: + abbrev_key = f"abbrev::{abbrev}" + norm_key = f"norm::{_normalize_label_key(label)}" + return oid_key, abbrev_key, norm_key + + +def _find_bucket( + seen: dict[str, _DedupedTermBucket], + oid_key: str | None, + abbrev_key: str | None, + norm_key: str, + ontology_id: str | None, +) -> _DedupedTermBucket | None: + """Locate an existing bucket for the incoming term, if any. + + Lookup priority: ``oid::`` > ``abbrev::`` > ``norm::``. A direct + ontologyId match always wins (same provider id == same concept). + Label-keyed matches honour the asymmetric merge guard: distinct + ontologyIds with the same label stay distinct, so a bucket carrying + its own ontologyId is skipped when the incoming term carries a + different one. + """ + if oid_key is not None and oid_key in seen: + return seen[oid_key] + for k in (abbrev_key, norm_key): + if k is None or k not in seen: + continue + candidate = seen[k] + if ( + candidate.ontology_id is not None + and ontology_id is not None + and candidate.ontology_id != ontology_id + ): + continue + return candidate + return None + + def _add_ontology_term( term: Any, seen: dict[str, _DedupedTermBucket], @@ -573,20 +639,33 @@ def _add_ontology_term( *, use_paren_abbrev: bool = False, ) -> bool: - """Append ``term`` to ``out`` if its dedupe key (ontologyId, then - parenthesized abbreviation when ``use_paren_abbrev`` is set, then - normalized label) has not been seen yet. Returns True iff something - new was added. - - On a collision: bump the bucket's per-cased-label counter and, if the - new casing now has the highest count, swap the displayed label of the - already-emitted :class:`OntologyTerm` in-place. Pre-fix: collisions - were silently ignored; the first-seen casing was the only one ever - surfaced. + """Append ``term`` to ``out`` if no existing entry matches it under + any of its dedupe keys (ontologyId, parenthesized abbreviation for + brain regions, normalized label). Returns True iff something new + was added. + + Bucket lookup walks the incoming term's candidate keys in priority + order — ``oid::`` first (most authoritative), then ``abbrev::`` for + brain regions, then ``norm::``. A label-keyed match is only accepted + when at most one of (existing bucket, incoming term) carries an + ontologyId: two distinct ontologyIds with the same label stay + distinct because different upstream providers can legitimately + catalog the same name as different concepts (see + ``test_ontology_id_still_takes_priority_over_label_normalization``). + + Each emitted entry registers all its candidate keys as aliases + pointing to the same bucket. Pre-fix the keyspace was disjoint: + an ontologyId-keyed entry and a label-keyed entry for the same + species — ``Caenorhabditis elegans`` with ``NCBITaxon:6239`` from + one dataset, label-only from another — would surface as two + distinct chips. Post-fix: the asymmetric label-alias merge + collapses them. ``ontologyId`` is preserved across merges; a + later-arriving labeled term promotes a label-only entry by + contributing its ontologyId. Tolerates both :class:`OntologyTerm` instances and serialized dicts - (facet builder has both on hand since full summaries come through as - ``model_dump``). + (facet builder has both on hand since full summaries come through + as ``model_dump``). """ if isinstance(term, OntologyTerm): raw_label: Any = term.label @@ -603,37 +682,42 @@ def _add_ontology_term( raw_ontology_id if isinstance(raw_ontology_id, str) and raw_ontology_id else None ) - # Dedupe key resolution, in priority order: - # 1. ontologyId (most authoritative — same provider id wins). - # 2. Parenthesized abbreviation (brain-region only) — collapses - # "Bed nucleus of the stria terminalis (BNST)" with - # "Bed nucleus of stria terminalis (BNST)". - # 3. Normalized label (lowercase + collapse-whitespace + strip) — - # collapses case-identical and trivial-whitespace duplicates. - key: str - if ontology_id: - key = f"oid::{ontology_id}" - else: - abbrev = ( - _extract_parenthesized_abbrev(label) if use_paren_abbrev else None - ) - key = ( - f"abbrev::{abbrev}" if abbrev else f"norm::{_normalize_label_key(label)}" - ) + oid_key, abbrev_key, norm_key = _term_keys( + label, ontology_id, use_paren_abbrev=use_paren_abbrev, + ) + bucket = _find_bucket(seen, oid_key, abbrev_key, norm_key, ontology_id) - bucket = seen.get(key) if bucket is not None: + # Register any candidate keys this term carries that the bucket + # didn't already have. Promotion below may also extend + # ``ontology_id``, which we then register as a new alias so a + # later same-ontologyId visit can match directly. + for k in (oid_key, abbrev_key, norm_key): + if k is not None: + seen.setdefault(k, bucket) new_winner = bucket.record(label) - if new_winner is not None: - # In-place label swap on the already-emitted term so the - # output list reflects the most-frequently-seen casing - # without us having to re-emit or sort. - existing = out[bucket.index] + existing = out[bucket.index] + # Promotion: if the incoming term carries an ontologyId and the + # bucket doesn't, adopt it. The label stays the winning-casing + # decision from ``bucket.record``. + promoted_ontology_id = existing.ontologyId or ontology_id + if bucket.ontology_id is None and ontology_id is not None: + bucket.ontology_id = ontology_id + # Register the newly-adopted oid as an alias so future + # same-oid arrivals find this bucket directly. + seen.setdefault(f"oid::{ontology_id}", bucket) + if new_winner is not None or promoted_ontology_id != existing.ontologyId: out[bucket.index] = OntologyTerm( - label=new_winner, ontologyId=existing.ontologyId, + label=new_winner or existing.label, + ontologyId=promoted_ontology_id, ) return False - seen[key] = _DedupedTermBucket(index=len(out), label=label) + new_bucket = _DedupedTermBucket( + index=len(out), label=label, ontology_id=ontology_id, + ) + for k in (oid_key, abbrev_key, norm_key): + if k is not None: + seen[k] = new_bucket out.append(OntologyTerm(label=label, ontologyId=ontology_id)) return True diff --git a/backend/services/image_service.py b/backend/services/image_service.py new file mode 100644 index 0000000..f903a5b --- /dev/null +++ b/backend/services/image_service.py @@ -0,0 +1,320 @@ +"""image_service — fetch + decode 2D image arrays from NDI binary documents. + +Used by the chat's ``fetch_image`` tool to render microscopy / fluorescence / +patch-encounter maps inline as Plotly heatmaps. The PI workflow is: + + "show me the patch encounter map for the Haley accept-reject dataset" + "show me the cell image from this Bhar memory recording" + +Returns a 2D array of floats (one row = one image row), plus min/max for +colorscale anchoring and a source provenance block for citation. + +Why a separate service (not a method on BinaryService)? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``BinaryService.get_image`` already exists but returns a base64-encoded +PNG/JPEG datauri for the Document Explorer's image viewer. The chat +needs the actual pixel array so Plotly can render it as a heatmap with +its own colorscale, tooltips, and axis-scaling — a datauri is opaque to +Plotly. Keeping the two paths separate avoids cross-coupling: the +viewer endpoint can keep its base64 contract; the chat endpoint gets +a clean float-array shape. + +NDI-native raw image formats (``.nim`` and friends) are NOT yet handled +here — Pillow handles TIFF/PNG/JPEG/GIF which covers the demo datasets. +A future enhancement will route raw-uint8 imageStack files through the +existing ``imageStack_parameters`` sidecar pattern (same shape +``BinaryService.get_raw`` already supports for the Document Explorer). +For now those return ``errorKind="unsupported"``. + +Returned dict shape on success:: + + { + "width": int, + "height": int, + "data": [[float, ...], ...], # height x width + "min": float, + "max": float, + "format": "tiff" | "png" | "jpeg" | ..., + "downsampled": bool, # True if thumbnailed to <= 512x512 + "source": { + "dataset_id": str, + "document_id": str, + "doc_class": str | None, + "doc_name": str | None, + "filename": str | None, + }, + } + +Soft-error envelope (no raise):: + + {"error": "...", "errorKind": "decode|notfound|unsupported"} +""" +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Any + +from ..auth.session import SessionData +from ..clients.ndi_cloud import NdiCloudClient +from ..observability.logging import get_logger +from .binary_service import _file_refs, _pick_default_image_ref + +if TYPE_CHECKING: # pragma: no cover + from PIL import Image as _PILImage # noqa: F401 + +log = get_logger(__name__) + +# Downsample threshold — Plotly heatmaps slow noticeably above ~512x512 +# (each pixel becomes a hover target). The chat surface is small anyway +# (~600 px wide in a typical message); a 512px thumbnail is sharper than +# the rendered size with room for retina-class displays. +MAX_DIMENSION = 512 + + +class ImageService: + """Decode 2D image arrays from NDI binary documents for chat rendering. + + Reuses BinaryService's file-ref extraction (handles the three observed + cloud document shapes) plus the cloud client's SSRF-hardened download + path. Pillow does the format dispatch — TIFF, PNG, JPEG, GIF all flow + through ``Image.open`` cleanly. + """ + + def __init__(self, cloud: NdiCloudClient) -> None: + self.cloud = cloud + + async def fetch_image( + self, + document: dict[str, Any], + *, + frame: int = 0, + session: SessionData | None = None, + ) -> dict[str, Any]: + """Fetch + decode the primary image file on ``document``. + + ``frame`` selects which frame to extract from a multi-frame TIFF / + animated GIF. Out-of-range frames clamp to (0, n_frames-1) and a + warning is logged. + + Returns a dict matching the module-docstring shape on success, or + a ``{"error", "errorKind"}`` envelope on a soft failure. The + envelope shape matches BinaryService's ``_timeseries_error`` so + the router can pass it through without re-shaping. + """ + refs = _file_refs(document) + if not refs: + return _image_error( + "notfound", + "No image file associated with this document.", + ) + + # B5 sweep (2026-05-18). Pre-fix behaviour `refs[0]` could pick a + # metadata sidecar (e.g. `imageStack_parameters.json`) on multi- + # file image docs, causing Pillow's `Image.open` to raise and the + # request to return `errorKind="unsupported"` even though the + # doc DID have a decodable image at a different position in the + # file list. The smart pick mirrors `_pick_default_signal_ref`: + # known image extension first, then non-metadata fallback, + # then legacy `refs[0]`. + ref = _pick_default_image_ref(refs) + if not ref.url: + return _image_error( + "notfound", + "No download URL available for this image file.", + ) + + access_token = session.access_token if session else None + try: + payload = await self.cloud.download_file( + ref.url, access_token=access_token, + ) + except Exception as e: + log.warning("image_service.download_failed", error=str(e)) + return _image_error( + "notfound", f"Failed to download image file: {e}", + ) + + return _decode_image( + payload, + frame=frame, + filename=ref.filename, + source=_source_block(document, filename=ref.filename), + ) + + +# --------------------------------------------------------------------------- +# Decode helpers — pure functions, no I/O. Tests exercise these directly +# with fixture bytes so the cloud-download stub stays minimal. +# --------------------------------------------------------------------------- + + +def _decode_image( # noqa: PLR0911, PLR0912 (linear per-failure-mode returns are clearer than a single accumulator; the branch count is one return per failure mode plus the success path) + payload: bytes, + *, + frame: int, + filename: str | None, + source: dict[str, Any], +) -> dict[str, Any]: + """Decode a raw image payload to a 2D float array. + + Pillow handles TIFF / PNG / JPEG / GIF auto-detect. Multi-channel + (RGB / RGBA) images are converted to grayscale via Pillow's ``"L"`` + mode — a heatmap renders a single channel, and Plotly's colorscale + is a more useful visual than three superimposed channels would be + for the typical microscopy / patch-encounter use case. + + For raw NDI-native image formats (.nim, .imageStack) Pillow will + raise — we surface as ``unsupported`` and the caller can prompt the + user to check back later or open the Document Explorer. + """ + if not payload: + return _image_error("notfound", "Image file is empty.") + + try: + # Lazy-import Pillow — matches BinaryService's pattern. Numpy is + # imported the same way (only paid when decoding actually runs). + from PIL import Image + except ImportError as e: + log.warning("image_service.pillow_unavailable", error=str(e)) + return _image_error("decode", f"Pillow import failed: {e}") + + try: + # Pillow's `Image.open()` returns an `ImageFile` subclass; subsequent + # `convert()` calls return the broader `Image.Image` type. We hold a + # widened reference here so mypy is happy with the rebind below. + img: Image.Image = Image.open(io.BytesIO(payload)) + except Exception as e: + log.warning("image_service.pil_open_failed", error=str(e)) + return _image_error( + "unsupported", + f"Image format not recognized by Pillow: {e}. " + "NDI-native raw image formats (.nim, raw imageStack) are not " + "yet supported by the chat heatmap renderer.", + ) + + fmt = (img.format or "").lower() or "raw" + + # Frame selection for multi-frame containers (TIFF stacks, animated + # GIFs). Pillow's `seek` raises on out-of-range; we clamp + log a + # warning rather than failing so the LLM gets a useful fallback. + n_frames = getattr(img, "n_frames", 1) + if frame > 0: + target = min(frame, n_frames - 1) if n_frames > 1 else 0 + if target != frame: + log.info( + "image_service.frame_clamped", + requested=frame, available=n_frames, used=target, + ) + try: + img.seek(target) + except Exception as e: + log.warning("image_service.frame_seek_failed", error=str(e)) + return _image_error( + "decode", + f"Failed to seek to frame {frame} (image has {n_frames} frame(s)): {e}", + ) + + # Convert to single-channel grayscale BEFORE thumbnailing — Pillow's + # mode-aware downscale is faster on `L` than on `RGBA`, and we'd + # discard the chroma anyway for the heatmap output. + if img.mode not in ("L", "I", "I;16", "F"): + img = img.convert("L") + + # Downsample to bound the response payload size. A 4K TIFF is 16M + # cells * ~6 bytes-per-cell JSON = ~100 MB response otherwise; the + # chat surface absolutely cannot ship that. 512x512 keeps the JSON + # under ~1.5 MB and renders crisply at the chat's column width. + downsampled = False + if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION: + img.thumbnail((MAX_DIMENSION, MAX_DIMENSION), Image.Resampling.LANCZOS) + downsampled = True + + try: + import numpy as np + arr = np.asarray(img, dtype=np.float32) + except Exception as e: + log.warning("image_service.numpy_convert_failed", error=str(e)) + return _image_error("decode", f"Failed to convert image to array: {e}") + + if arr.ndim != 2: + # Defensive — convert("L") above should always give a 2D result, + # but the `F` and `I;16` modes Pillow surfaces for scientific + # TIFFs can occasionally come through as something else. Flatten + # the leading dimensions to 2D so the heatmap still renders. + if arr.ndim == 3 and arr.shape[2] in (1, 3, 4): + arr = arr.mean(axis=2) + else: + return _image_error( + "decode", + f"Unexpected array shape after decode: {arr.shape}. " + "Expected 2D (height x width).", + ) + + # Min/max for Plotly's `zmin`/`zmax` colorscale anchoring. Computed + # on the float array (after the optional downscale) so the chart + # matches what Plotly actually renders. Use safe casts to plain + # Python floats — np.float32 isn't JSON-serializable in some + # FastAPI response shapes. + if arr.size == 0: + return _image_error("decode", "Image decoded to an empty array.") + arr_min = float(arr.min()) + arr_max = float(arr.max()) + + # 2D list-of-lists for the JSON response. Each row materializes once; + # `.tolist()` is the cheapest numpy → JSON-able path and Pillow's + # decode already paid the per-pixel cost so this is at most a copy. + data: list[list[float]] = arr.tolist() + + return { + "width": int(arr.shape[1]), + "height": int(arr.shape[0]), + "data": data, + "min": arr_min, + "max": arr_max, + "format": fmt, + "downsampled": downsampled, + "source": source, + } + + +def _image_error(error_kind: str, message: str) -> dict[str, Any]: + """Soft-error envelope — matches the BinaryService ``errorKind`` shape + so the router doesn't need to re-translate. + + `errorKind` is one of: "notfound", "decode", "unsupported". The LLM + is taught to surface these plainly without emitting a chart fence. + """ + return {"error": message, "errorKind": error_kind} + + +def _source_block( + document: dict[str, Any], *, filename: str | None, +) -> dict[str, Any]: + """Build the citation source block. Mirrors signal_service's shape + so the chat-side reference builder works uniformly across tools. + + Defensive against partial document shapes: every field can be None + without crashing the dict assembly. + """ + base = document.get("base", {}) if isinstance(document, dict) else {} + doc_class: str | None = None + if isinstance(document, dict): + cls = document.get("document_class") or {} + if isinstance(cls, dict): + doc_class = cls.get("classname") or cls.get("class_name") + # Bulk-fetch shape buries the class on top-level `className`. + doc_class = doc_class or document.get("className") + doc_name = None + if isinstance(base, dict): + doc_name = base.get("name") + return { + "dataset_id": document.get("datasetId", "") if isinstance(document, dict) else "", + "document_id": ( + document.get("id") or document.get("_id") or "" + if isinstance(document, dict) else "" + ), + "doc_class": doc_class, + "doc_name": doc_name, + "filename": filename, + } diff --git a/backend/services/ndi_python_service.py b/backend/services/ndi_python_service.py new file mode 100644 index 0000000..f4171a6 --- /dev/null +++ b/backend/services/ndi_python_service.py @@ -0,0 +1,385 @@ +"""ndi_python_service — thin wrappers over the three NDI-python entry points +we use in Phase A. + +Why a separate service? Two reasons: + +1. **Centralized lazy imports.** NDI-python (~150 MB resident if everything + loaded eagerly) is gated behind module-level functions that import on first + call. The rest of the backend doesn't pay the import cost until something + actually exercises an NDI path. + +2. **Consistent error envelope.** Every call returns either a typed Python + value (numpy array / dict) on success, or `None` (or a sentinel) on a + recoverable miss. None of these raise on miss — that's the contract our + callers in `binary_service` and `ontology_service` rely on so they can + fall through to their existing inline / external paths. + +The three entry points are documented in: +`docs/plans/2026-05-13-ndi-python-integration.md`. + +Phase B may layer a real `ndi.dataset.Dataset` here. Phase A intentionally +operates only on byte payloads + ID strings, no Dataset object. +""" + +from __future__ import annotations + +import contextlib +import tempfile +from pathlib import Path +from typing import Any, Literal + +from ..observability.logging import get_logger + +log = get_logger(__name__) + +# Lightweight import guard. We don't want the *import* of this module to +# pull in pandas / numpy / etc. — that's deferred to first call. The flag +# below caches the result of the first import attempt so subsequent calls +# pay nothing extra. `None` = not-yet-tried; `True` = imported OK; +# `False` = import failed (NDI stack not available; callers fall back). +_NDI_AVAILABLE: bool | None = None + +# Separate, optional flag for the Sprint 1.5 dataset binding (Dataset +# materialization via `ndi.cloud.orchestration.downloadDataset`). Even +# when the Phase A stack is happy, the dataset binding can fail +# independently (e.g. missing `ndi.dataset`, missing cloud helpers, +# the auto-client decorator can't find creds at module-import time). +# We probe it separately so callers of `is_ndi_available()` don't see +# a flap caused by the optional Sprint 1.5 surface. +_DATASET_BINDING_AVAILABLE: bool | None = None + + +def is_ndi_available() -> bool: + """Best-effort check that the NDI-python stack is importable. Caches + the result so health checks + first-request paths don't pay the import + cost more than once. + + Side-effect: also runs the Sprint 1.5 dataset-binding probe (see + :func:`is_dataset_binding_available`) so the boot log shows ONE entry + summarising both. The dataset binding is treated as a separate, + OPTIONAL capability — its failure must NOT mark the Phase A stack + unavailable. + """ + global _NDI_AVAILABLE # noqa: PLW0603 — module-level cache flag + if _NDI_AVAILABLE is not None: + return _NDI_AVAILABLE + try: + # We probe one module from each git-sourced package to make sure + # the full transitive surface is on PYTHONPATH. Errors here at + # boot time become clear startup failures rather than mysterious + # first-request 500s. + import ndi.ontology # noqa: F401 + import ndicompress # noqa: F401 + import vlt.file.custom_file_formats # noqa: F401 + _NDI_AVAILABLE = True + except ImportError as e: + log.warning("ndi_python_service.import_failed", error=str(e)) + _NDI_AVAILABLE = False + + # Probe the optional dataset binding even on Phase-A success — log + # both findings together for clarity in the boot log. + binding_ok = is_dataset_binding_available() + log.info( + "ndi_python_service.boot_probe", + phase_a=_NDI_AVAILABLE, + dataset_binding=binding_ok, + ) + return _NDI_AVAILABLE + + +def is_dataset_binding_available() -> bool: + """Best-effort check for the Sprint 1.5 cloud-backed dataset binding. + + Probes ``ndi.dataset`` and ``ndi.cloud.orchestration`` — the two + modules :mod:`backend.services.dataset_binding_service` reaches into. + A True result does NOT mean the binding will succeed at runtime + (cloud-node auth + network are required for that); it only means the + imports are wired and the service is safe to wire into the router. + """ + global _DATASET_BINDING_AVAILABLE # noqa: PLW0603 + if _DATASET_BINDING_AVAILABLE is not None: + return _DATASET_BINDING_AVAILABLE + try: + import ndi.cloud.orchestration + import ndi.dataset # noqa: F401 + _DATASET_BINDING_AVAILABLE = True + except ImportError as e: + log.warning( + "ndi_python_service.dataset_binding_import_failed", + error=str(e), + ) + _DATASET_BINDING_AVAILABLE = False + return _DATASET_BINDING_AVAILABLE + + +# --------------------------------------------------------------------------- +# VHSB — vlt.file.custom_file_formats.vhsb_read +# --------------------------------------------------------------------------- +# +# Important contract from the Phase A recon (see plan doc): +# - vhsb_read takes a FILE PATH (str), not bytes / BytesIO +# - It internally reopens the file with `open(filename, 'rb')` +# - There is ONLY ONE VHSB format. It always begins with a 200-byte +# ASCII tag ("This is a VHSB file, http://github.com/VH-Lab\n" zero- +# padded) followed by a 1836-byte binary header, then payload. +# - Returns `(y, x)` — numpy arrays of values and time-axis samples. +# +# We materialize the payload bytes to a NamedTemporaryFile, call +# vhsb_read, then unlink. The 200-byte text tag is what current +# binary_service.py treats as the `vlt_library` early-return — the +# whole point of Phase A is to actually decode it instead. + + +def read_vhsb_from_bytes( + payload: bytes, +) -> dict[str, Any] | None: + """Decode a VHSB binary payload via vlt.file. + + Returns a dict matching `binary_service._ts_shape_single_channel`'s + envelope on success (so callers can drop it directly into a + timeseries response), or `None` on failure so the caller can fall + back to inline parsing or surface a typed error. + + No raise; all failures log + return None. + """ + if not is_ndi_available(): + return None + if not payload or len(payload) < 2036: + # Minimum: 200 byte text-tag + 1836 byte header. Smaller payloads + # cannot possibly be valid VHSB. + return None + + tmp_path = None + try: + # vhsb_read needs a real on-disk path. Suffix matters: the helper + # doesn't sniff the file type from extension, but downstream + # logging is clearer if we keep it. + with tempfile.NamedTemporaryFile( + delete=False, suffix=".vhsb", prefix="ndb_vhsb_" + ) as fh: + fh.write(payload) + tmp_path = fh.name + + # Lazy-import inside the function so the import cost is paid only + # on the first VHSB decode (or never, if no one ever hits this). + import numpy as np + from vlt.file.custom_file_formats import vhsb_read, vhsb_readheader + + header = vhsb_readheader(tmp_path) + n_samples = int(header.get("num_samples", 0)) + if n_samples <= 0: + log.warning("vhsb_read.bad_header", header=header) + return None + + y, x = vhsb_read(tmp_path, 0, n_samples) + if y is None or len(y) == 0: + log.warning("vhsb_read.empty_payload", n_samples=n_samples) + return None + + # Translate to the existing envelope shape. y is the value array + # (possibly multi-dim if Y_dim > 1), x is the time axis. We + # flatten y to a single channel for now — multi-channel VHSB + # support is a future enhancement (binary_service's envelope + # naturally supports it, but the demo datasets are all 1-D). + sample_rate = float(header.get("X_increment", 0.0)) + # X_increment is seconds-per-sample. Convert to Hz, guarding + # against zero. + sample_rate_hz = (1.0 / sample_rate) if sample_rate > 0 else 0.0 + + # Flatten to 1-D if vhsb_read returned (N, 1) or (N,). + values = np.asarray(y).reshape(-1).astype(np.float32, copy=False) + + return { + "channels": {"ch0": _nan_to_none(values.tolist())}, + "timestamps": np.asarray(x).reshape(-1).astype(np.float64, copy=False).tolist(), + "sample_count": int(values.size), + "format": "vhsb", + "sample_rate_hz": sample_rate_hz, + "error": None, + } + except Exception as e: + # vhsb_read raises on type mismatch / bad sizes; treat all as soft + # errors so callers can fall back. + log.warning("vhsb_read.failed", error=str(e), error_type=type(e).__name__) + return None + finally: + if tmp_path is not None: + with contextlib.suppress(OSError): + Path(tmp_path).unlink(missing_ok=True) + + +def _nan_to_none(values: list[float]) -> list[float | None]: + """Replace NaN with None so the frontend's uPlot sees explicit gaps + rather than rendering through NaN-poisoned line segments. Matches the + `_to_nullable_list` convention in binary_service.""" + import math + out: list[float | None] = [] + for v in values: + if isinstance(v, float) and math.isnan(v): + out.append(None) + else: + out.append(float(v)) + return out + + +# --------------------------------------------------------------------------- +# NDI-compressed binaries — ndicompress.expand_* +# --------------------------------------------------------------------------- +# +# Phase A scope: detect + decompress only. Like vhsb_read, ndicompress +# operates on file paths (subprocess-based, wraps platform-specific C +# executables). Magic byte detection: +# - Outer wrapper is gzipped tar (.nbf.tgz) +# - One inner file has the extension `.nbh` and starts with the +# 15-byte ASCII string `b"NDIBINARYHEADER"` +# +# encode_method dispatch (per the recon): +# 1 = Ephys (analog input/output) — most common for us +# 2 = Metadata (JSON-like; rarely shown as timeseries) +# 21 = Digital (uint8 0/1 channels) +# 41 = EventMarkText (sparse markers) +# 61 = Time (time-only data; used as derived axis) +# +# We only auto-handle method 1 (Ephys) on the timeseries path; the +# others surface a typed soft-error and fall through to the existing +# code (which already handles raw .nbf). + + +def is_ndi_compressed(payload: bytes) -> bool: + """Cheap prefix check for NDI's `.nbf.tgz` wrapper. + + Doesn't validate the inner contents — that's the job of the actual + expand call. False positives here are fine because the expand path + will fail gracefully and the caller will fall back to inline parsing. + + A gzipped tar archive begins with two bytes `0x1f 0x8b` (gzip magic). + We don't fingerprint deeper than that — every gzip stream we'd see + in this context is going to be either NDI-compressed or something + legitimately broken; in either case the expand call will tell us. + """ + return len(payload) >= 2 and payload[0] == 0x1F and payload[1] == 0x8B + + +def expand_ephys_from_bytes(payload: bytes) -> dict[str, Any] | None: + """Decode an NDI-compressed Ephys payload (encode_method=1). + + Returns the same envelope shape as `read_vhsb_from_bytes` so it's a + drop-in replacement in `BinaryService.get_timeseries`. Multi-channel + ephys becomes `{"ch0": [...], "ch1": [...], ...}`. + + None on miss / wrong codec / errors. No raise. + """ + if not is_ndi_available(): + return None + if not is_ndi_compressed(payload): + return None + + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + delete=False, suffix=".nbf.tgz", prefix="ndb_ndic_" + ) as fh: + fh.write(payload) + tmp_path = fh.name + + import ndicompress + import numpy as np + + # ndicompress.expand_ephys returns (np.ndarray[S, C], None). + arr, _ = ndicompress.expand_ephys(tmp_path) + if arr is None or arr.size == 0: + return None + + # Shape: (n_samples, n_channels). We don't have a sample rate + # from ndicompress's return (yet — the .nbh header has it but + # the wrapper doesn't surface it). Caller can post-process with + # the document's metadata if needed. + n_samples, n_channels = arr.shape if arr.ndim == 2 else (arr.size, 1) + if arr.ndim == 1: + arr = arr.reshape(-1, 1) + + channels: dict[str, list[float | None]] = {} + for c in range(n_channels): + channels[f"ch{c}"] = _nan_to_none(arr[:, c].astype(np.float32).tolist()) + + return { + "channels": channels, + "timestamps": list(range(n_samples)), # sample-index axis; caller may rescale + "sample_count": int(n_samples), + "format": "nbf_compressed", + "sample_rate_hz": 0.0, # unknown without sidecar metadata + "error": None, + } + except Exception as e: + log.warning( + "ndicompress.expand_ephys.failed", + error=str(e), + error_type=type(e).__name__, + ) + return None + finally: + if tmp_path is not None: + with contextlib.suppress(OSError): + Path(tmp_path).unlink(missing_ok=True) + + +# --------------------------------------------------------------------------- +# Ontology — ndi.ontology.lookup +# --------------------------------------------------------------------------- +# +# Phase A's ontology contribution: when our existing external-provider +# lookup misses, fall back to NDI's. NDI's `lookup` knows lab-specific +# terms (WBStrain, Cre lines, internal NDIC identifiers) that public +# providers don't. +# +# Critical contract: +# - Input is a single CURIE string (`"WBStrain:00000001"`, `"CL:0000540"`) +# - Output is an OntologyResult with truthy-on-hit, falsy-on-miss +# - Never raises (provider errors swallowed internally) +# - Has a small module-level FIFO cache (~100 entries) +# - Most non-NDIC prefixes hit OLS4 (EBI) via `requests.get`, 30s timeout +# +# We re-cache results in our own redis-backed `ontology_cache` so a hit +# survives process restart. NDI's internal cache is per-process only. + + +_OntologyLookupKind = Literal["hit", "miss"] + + +def lookup_ontology(curie: str) -> dict[str, Any] | None: + """Resolve an ontology CURIE via NDI-python's ontology service. + + Returns the OntologyResult's `.to_dict()` on hit, `None` on miss + (incl. malformed input, unknown prefix, provider error — all silent + in NDI's implementation, surfaced as None here). + + Callers in `ontology_service.py` should use this as a FALLBACK after + their existing external-provider lookup misses — NOT as the primary + path (NDI's lookup hits the same OLS4 endpoints for many ontologies, + so duplication would double network traffic for hits we'd see anyway). + """ + if not is_ndi_available(): + return None + if not curie or ":" not in curie: + return None + + try: + from ndi.ontology import lookup + result = lookup(curie) + if not result: # OntologyResult __bool__ returns True only on hit + return None + # to_dict() yields {id, name, prefix, definition, synonyms, short_name}. + # mypy sees `lookup` as `Any` (NDI-python has no stubs), so the cast + # is needed to keep the function's declared return type honest under + # strict mode. + result_dict: dict[str, Any] = dict(result.to_dict()) + return result_dict + except Exception as e: + # NDI's lookup is documented to never raise on misses, but defensive: + log.warning( + "ndi.ontology.lookup.failed", + curie=curie, + error=str(e), + error_type=type(e).__name__, + ) + return None diff --git a/backend/services/ontology_service.py b/backend/services/ontology_service.py index e4f4f02..22e32c2 100644 --- a/backend/services/ontology_service.py +++ b/backend/services/ontology_service.py @@ -2,6 +2,8 @@ from __future__ import annotations import asyncio +import html as _html +import re from typing import Any import httpx @@ -12,6 +14,26 @@ log = get_logger(__name__) +# WormBase page-title pattern. The canonical strain page returns a title +# like `` N2 (strain) - WormBase : Nematode Information Resource`` +# regardless of release (verified on WS294 via a Wayback snapshot; the +# template hasn't changed in years). We anchor on ``(strain)`` rather than +# just stripping the suffix so we don't accidentally pick up the +# "Just a moment..." Cloudflare interstitial as a strain name. +_WB_TITLE_RE = re.compile( + r"]*>\s*([^<(]+?)\s*\(strain\)\s*-\s*WormBase", + re.IGNORECASE | re.DOTALL, +) + +# Secondary parse target: the page-title breadcrumb +# ``

Strain » N2

``. Used when the +# ```` element is missing or mangled (older snapshots, partial +# loads), so the scrape still resolves on a degraded page. +_WB_BREADCRUMB_RE = re.compile( + r"<a[^>]*>\s*Strain\s*</a>\s*»\s*<span[^>]*>\s*([^<]+?)\s*</span>", + re.IGNORECASE | re.DOTALL, +) + class OntologyService: PROVIDERS = { @@ -42,16 +64,93 @@ async def lookup(self, term: str) -> OntologyTerm: if provider is None: raise OntologyLookupFailed("Term must be PROVIDER:ID, e.g. CL:0000540") cached = self.cache.get(provider, term_id) - if cached is not None: + # IMPORTANT — only return cached entries that are REAL hits. + # A "stub" cache entry (label=None AND definition=None) is what + # ``OntologyCache.get`` returns for terms that were previously + # looked up but came back empty. We DO NOT want to return such + # stubs here, because: + # + # 1. Phase A (2026-05-13) wired ``ndi.ontology.lookup`` as a + # fallback for lab-specific prefixes (WBStrain, NDIC, etc.) + # that the legacy providers couldn't resolve. Terms looked + # up BEFORE Phase A were cached as stubs. + # + # 2. ``ONTOLOGY_CACHE_TTL_DAYS`` defaults to 30, so those + # pre-Phase-A stubs live for ~a month — and short-circuit + # the NDI-python fallback every time the term resurfaces. + # + # By treating stubs as cache MISSES we let the lookup pipeline + # retry: existing providers (cheap; the OLS/SciCrunch/etc. + # calls have their own outbound throttling) AND the NDI-python + # fallback. On a successful resolution the new ``self.cache.set`` + # below OVERWRITES the stub — so each stuck stub heals on first + # use rather than waiting for the 30-day TTL to expire. + if cached is not None and (cached.label or cached.definition): return cached + fetched: OntologyTerm | None = None try: fetched = await self._fetch_from_provider(provider, term_id) except Exception as e: log.warning("ontology.fetch_failed", provider=provider, term_id=term_id, error=str(e)) - raise OntologyLookupFailed(f"Could not look up {term}") from e + # Don't raise yet — fall through to the NDI-python fallback, which + # knows lab-specific terms (NDIC, WBStrain, internal Cre lines) + # the existing providers may miss. + + # NDI-python fallback: only fire when existing path didn't yield a + # usable record (stub with no label/definition, OR raised above). + # This is a Phase A addition (2026-05-13) — see plan doc. Wrapped in + # to_thread because ndi.ontology.lookup is sync and uses `requests` + # internally, which would block the event loop if called directly. + if fetched is None or (not fetched.label and not fetched.definition): + ndi_term = await self._try_ndi_fallback(term, provider, term_id) + if ndi_term is not None: + self.cache.set(ndi_term) + return ndi_term + + if fetched is None: + # Both legacy AND NDI-python failed. Cache a stub so we don't + # hammer the upstream providers, but if we had a prior stub + # in cache just return it (avoid a redundant set). + if cached is None: + stub = OntologyTerm( + provider=provider, term_id=term_id, + label=None, definition=None, url=None, + ) + self.cache.set(stub) + return stub + return cached self.cache.set(fetched) return fetched + async def _try_ndi_fallback( + self, term: str, provider: str, term_id: str, + ) -> OntologyTerm | None: + """Probe NDI-python's bundled ontology lookup. Returns None on miss + (incl. NDI stack not installed, malformed input, unknown prefix). + + NDI's lookup hits the same OLS4 endpoints we do for many ontologies, + but it ALSO ships a local CSV for NDIC and has hand-curated providers + for WBStrain and a few others — that's where the additional hits + come from. Net: this fallback rarely fires but catches the long tail.""" + try: + from .ndi_python_service import lookup_ontology + result = await asyncio.to_thread(lookup_ontology, term) + except Exception as e: + log.warning("ontology.ndi_fallback_failed", term=term, error=str(e)) + return None + if result is None: + return None + # NDI's `.to_dict()` shape: {id, name, prefix, definition, synonyms, short_name}. + # Map onto our OntologyTerm. We preserve the original PROVIDER (case + # as it was passed in) so the cache key matches what the caller asked for. + return OntologyTerm( + provider=provider, + term_id=term_id, + label=result.get("name") or None, + definition=result.get("definition") or None, + url=None, + ) + async def batch_lookup(self, terms: list[str]) -> list[OntologyTerm]: unique = list(dict.fromkeys(t for t in terms if t)) results = await asyncio.gather(*[self._safe_lookup(t) for t in unique]) @@ -63,7 +162,21 @@ async def _safe_lookup(self, term: str) -> OntologyTerm | None: except OntologyLookupFailed: return None - _OLS_PROVIDERS = {"CL": "cl", "NCBITaxon": "ncbitaxon", "CHEBI": "chebi", "PATO": "pato", "EFO": "efo"} + # OLS-resolvable providers. UBERON was previously omitted (live + # check showed UBERON:0001870 returning label=null even though + # OLS has it as "frontal cortex"). GO and OBI added for similar + # completeness — these are all OBO ontologies hosted at the same + # EBI OLS4 endpoint with identical query semantics. + _OLS_PROVIDERS = { + "CL": "cl", + "NCBITaxon": "ncbitaxon", + "CHEBI": "chebi", + "PATO": "pato", + "EFO": "efo", + "UBERON": "uberon", + "GO": "go", + "OBI": "obi", + } async def _fetch_from_provider(self, provider: str, term_id: str) -> OntologyTerm: ols = self._OLS_PROVIDERS.get(provider) @@ -121,8 +234,51 @@ async def _fetch_scicrunch(self, rrid: str) -> OntologyTerm: return OntologyTerm(provider="RRID", term_id=rrid, label=None, definition=None, url=url) async def _fetch_wormbase(self, strain_id: str) -> OntologyTerm: + """Resolve a WBStrain CURIE to its human-readable strain name. + + NDI-python's WBStrain provider only returns a URL, not a label, so + we GET the canonical strain page and parse the strain name from + ``<title>`` (primary) or the page-title breadcrumb (secondary). + Any failure — Cloudflare interstitial, timeout, 404, parse miss — + falls through to ``label=None`` so the lookup pipeline degrades + cleanly rather than crashing. Cache layering upstream means each + strain page is hit at most once per TTL. + """ url = f"https://wormbase.org/species/c_elegans/strain/{strain_id}" - return OntologyTerm(provider="WBStrain", term_id=strain_id, label=strain_id, definition=None, url=url) + label = await self._scrape_wormbase_label(url) + return OntologyTerm( + provider="WBStrain", term_id=strain_id, + label=label, definition=None, url=url, + ) + + async def _scrape_wormbase_label(self, url: str) -> str | None: + """Fetch ``url`` and extract the strain name from the HTML. + + Total budget is 5 seconds — WormBase pages are small (~70 KB) but + Cloudflare can interpose. Returns ``None`` on any failure so the + caller can fall through; never raises. + """ + try: + r = await self._http.get(url, timeout=5.0) + except Exception as e: + log.warning("ontology.wormbase.fetch_failed", url=url, error=str(e)) + return None + if r.status_code != 200: + log.warning( + "ontology.wormbase.bad_status", + url=url, status=r.status_code, + ) + return None + body = r.text + m = _WB_TITLE_RE.search(body) or _WB_BREADCRUMB_RE.search(body) + if m is None: + return None + label = _html.unescape(m.group(1)).strip() + # Guard against empty captures and the Cloudflare "Just a moment" + # text leaking through despite the ``(strain)`` anchor. + if not label or label.lower().startswith("just a moment"): + return None + return label async def _fetch_pubchem(self, cid: str) -> OntologyTerm: url = f"https://pubchem.ncbi.nlm.nih.gov/compound/{cid}" diff --git a/backend/services/psth_service.py b/backend/services/psth_service.py new file mode 100644 index 0000000..3e21178 --- /dev/null +++ b/backend/services/psth_service.py @@ -0,0 +1,787 @@ +"""psth_service — peri-stimulus time histogram orchestration. + +PSTH is the canonical sensory-neuroscience visualization: align a unit's +spike train to a series of stimulus events, count spikes per fixed-width +bin in a [t0, t1] window around each event, then average across trials +to get a firing-rate estimate per bin. + +Endpoint strategy +───────────────── +The service does TWO doc fetches: + + 1. ``unit_doc_id`` — the vmspikesummary doc containing the spike + train. Same extraction path as + :mod:`backend.services.spike_summary_service` (probes + ``data.vmspikesummary.spike_times``, ``spiketimes``, + ``sample_times``); also probes for a separate binary file when + the JSON body doesn't carry inlined spike times. + 2. ``stimulus_doc_id`` — a stimulus_presentation OR stimulus_response + doc. Event timestamps live under different paths depending on the + NDI doc class; we try a few canonical locations in order: + · ``data.stimulus_presentation.presentations[*].time_started`` + · ``data.stimulus_response.responses[*].stim_time`` + · ``data.events`` (preprocessed top-level array) + · ``events`` (top-level fallback) + +Binning +─────── +We build the histogram with ``numpy.histogram`` over the merged set of +relative spike times across all trials. The bin layout is +``np.linspace(t0, t1, N_bins + 1)`` so the centers are deterministic +and the user can re-derive them client-side from ``t0``, ``t1``, and +``bin_size_ms``. + +Output caps (hard, server-side): + +* ``bin_size_ms >= 1`` (1 ms is the typical fine-grained PSTH bin) +* ``t1 - t0 <= 10`` seconds (PSTH analysis windows >10 s are unusual) +* ``N_bins <= 1000`` + +These mirror the spike-summary caps in spirit — keep response shapes +bounded so the chart layer doesn't choke and the chat tool can predict +payload size. + +Soft-error envelope +─────────────────── +The service surfaces problems via ``error`` + ``error_kind`` on the +response object rather than raising: + +* ``"decode_failed"`` — unit doc had no parseable spike-times array +* ``"no_events"`` — stimulus doc had no extractable event timestamps +* ``"empty_window"`` — events extracted but every window was empty + (still returns valid zero-counts arrays so the chart renders) + +The router translates the cloud-tier exceptions +(``CloudUnreachable``, ``CloudTimeout``, ``CloudInternalError``) into a +``"cloud_unavailable"`` envelope at the HTTP boundary. +""" +from __future__ import annotations + +import math +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +from ..auth.session import SessionData +from ..observability.logging import get_logger +from .binary_service import BinaryService +from .document_service import DocumentService + +log = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Tunables — module-level so tests can monkeypatch and so the constants +# are reachable from tests without re-importing internals. +# --------------------------------------------------------------------------- + +# Default analysis window (seconds) around each stimulus event. +DEFAULT_T0 = -0.5 +DEFAULT_T1 = 1.5 + +# Default histogram bin width (milliseconds). 20 ms strikes a balance +# between rate-curve smoothness and temporal resolution for typical +# visual / somatosensory stimuli. +DEFAULT_BIN_SIZE_MS = 20.0 + +# Hard caps. ``bin_size_ms`` floor of 1 ms keeps the bin count bounded +# even for the maximum 10 s window (10000 / 1 = 10000 — we cap further +# at MAX_BINS). The 10-second window is enough for any typical +# stimulus response; longer-window analyses should use a different +# tool. +MIN_BIN_SIZE_MS = 1.0 +MAX_WINDOW_SECONDS = 10.0 +MAX_BINS = 1000 + +# Per-trial raster cap. The optional raster underneath the PSTH gets +# one array per trial; we cap the total returned spike count to keep +# the payload bounded. The PSTH histogram itself is computed on the +# UNCAPPED spike set so the rate-curve accuracy is preserved. +MAX_RASTER_SPIKES_TOTAL = 10_000 + + +# --------------------------------------------------------------------------- +# Pydantic request/response models +# --------------------------------------------------------------------------- + + +class PsthRequest(BaseModel): + """Input shape for ``POST /api/datasets/{id}/psth``. + + Aliases let the router accept camelCase from the TS chat proxy + (``unitDocId``, ``stimulusDocId``, etc.) without translation. + """ + + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + unit_doc_id: str = Field(..., alias="unitDocId", min_length=1) + stimulus_doc_id: str = Field(..., alias="stimulusDocId", min_length=1) + t0: float = Field(default=DEFAULT_T0) + t1: float = Field(default=DEFAULT_T1) + bin_size_ms: float = Field(default=DEFAULT_BIN_SIZE_MS, alias="binSizeMs") + include_raster: bool = Field(default=False, alias="includeRaster") + title: str | None = Field(default=None, max_length=160) + + +class PsthResponse(BaseModel): + """Top-level PSTH response. + + ``bin_centers`` and ``counts`` / ``mean_rate_hz`` are parallel + arrays of length N_bins. ``per_trial_raster`` is included only when + the request set ``include_raster=True``; it's a list of N_trials + sublists, each holding the spike times for that trial expressed + relative to its event onset (i.e. ``spike_time - event_time``, + bounded to ``[t0, t1]``). + + ``error`` + ``error_kind`` populated for soft failures; consumers + branch on ``error_kind`` to render a friendly message rather than + a hard error boundary. + """ + + bin_centers: list[float] + counts: list[int] + mean_rate_hz: list[float] + n_trials: int + n_spikes: int + bin_size_ms: float + t0: float + t1: float + unit_name: str + unit_doc_id: str + stimulus_doc_id: str + per_trial_raster: list[list[float]] | None = None + error: str | None = None + error_kind: str | None = None + + +# --------------------------------------------------------------------------- +# Public orchestration entry point +# --------------------------------------------------------------------------- + + +async def compute_psth( + request: PsthRequest, + *, + document_service: DocumentService, + binary_service: BinaryService, + session: SessionData | None, + dataset_id: str, +) -> PsthResponse: + """Build a PSTH response for one unit + one stimulus doc. + + Parameters + ---------- + request: + Validated PSTH input (see :class:`PsthRequest`). + document_service: + Used to fetch the unit + stimulus doc bodies. + binary_service: + Used to decode the unit's binary file when spike times aren't + inlined in the JSON body. The same fallback path the + spike-summary service uses. + session: + Optional session — propagated as ``access_token`` so private + datasets work for logged-in users. + dataset_id: + From the URL path. Source of truth for routing. + """ + access_token = session.access_token if session else None + t0, t1, bin_size_ms, validation_error = _validate_window(request) + + # Bail early on validation failure — return a soft envelope rather + # than raising so the chat tool can surface a friendly explanation. + if validation_error is not None: + return _empty_response( + request, + unit_name="", + error=validation_error, + error_kind="invalid_window", + t0=t0, + t1=t1, + bin_size_ms=bin_size_ms, + ) + + unit_name, spike_times, unit_err = await _resolve_unit( + request, + document_service=document_service, + binary_service=binary_service, + dataset_id=dataset_id, + access_token=access_token, + ) + if unit_err is not None: + return _empty_response( + request, unit_name=unit_name, error=unit_err, + error_kind="decode_failed", + t0=t0, t1=t1, bin_size_ms=bin_size_ms, + ) + + events, events_err = await _resolve_events( + request, + document_service=document_service, + dataset_id=dataset_id, + access_token=access_token, + ) + if events_err is not None: + return _empty_response( + request, unit_name=unit_name, error=events_err, + error_kind="no_events", + t0=t0, t1=t1, bin_size_ms=bin_size_ms, + ) + + # --- Compute the histogram --- + bin_edges, bin_centers = _build_bin_arrays(t0, t1, bin_size_ms) + spike_arr = np.asarray(spike_times, dtype=np.float64) + + all_relative: list[float] = [] + per_trial_raster: list[list[float]] = [] + for event_t in events: + lo = event_t + t0 + hi = event_t + t1 + # Use boolean mask + slice — numpy.searchsorted would also work + # but the mask is clearer and the spike arrays are small enough + # that the extra alloc doesn't matter. + in_window = spike_arr[(spike_arr >= lo) & (spike_arr <= hi)] + relatives = (in_window - event_t).tolist() + all_relative.extend(relatives) + if request.include_raster: + per_trial_raster.append(relatives) + + n_trials = len(events) + bin_size_seconds = bin_size_ms / 1000.0 + + if all_relative: + counts_arr, _ = np.histogram( + np.asarray(all_relative, dtype=np.float64), + bins=bin_edges, + ) + else: + # Window emptied — still return the zero-counts arrays so the + # chart renders a flat trace. n_trials is still meaningful + # (events were found, they just had no spikes near them). + counts_arr = np.zeros(len(bin_centers), dtype=np.int64) + + counts = [int(c) for c in counts_arr.tolist()] + # Normalize: counts / (n_trials * bin_size_seconds) gives Hz. + # Guard against div-by-zero just in case (events list is non-empty + # here but defensive). + norm = n_trials * bin_size_seconds + mean_rate_hz = ( + [c / norm for c in counts] if norm > 0 else [0.0] * len(counts) + ) + + raster_field: list[list[float]] | None = None + if request.include_raster: + raster_field = _cap_raster(per_trial_raster, MAX_RASTER_SPIKES_TOTAL) + + error: str | None = None + error_kind: str | None = None + if not all_relative: + # Soft envelope — chart still renders but caller can surface a hint. + error = ( + f"No spikes fell within the [{t0:.3f}, {t1:.3f}] s window of " + f"any of the {n_trials} stimulus events" + ) + error_kind = "empty_window" + + return PsthResponse( + bin_centers=[float(c) for c in bin_centers.tolist()], + counts=counts, + mean_rate_hz=mean_rate_hz, + n_trials=n_trials, + n_spikes=len(all_relative), + bin_size_ms=bin_size_ms, + t0=t0, + t1=t1, + unit_name=unit_name, + unit_doc_id=request.unit_doc_id, + stimulus_doc_id=request.stimulus_doc_id, + per_trial_raster=raster_field, + error=error, + error_kind=error_kind, + ) + + +# --------------------------------------------------------------------------- +# Unit + event resolution — fetch + extract + soft-error mapping +# --------------------------------------------------------------------------- + + +async def _resolve_unit( + request: PsthRequest, + *, + document_service: DocumentService, + binary_service: BinaryService, + dataset_id: str, + access_token: str | None, +) -> tuple[str, list[float], str | None]: + """Resolve the unit doc + extract the spike-times array. + + Returns ``(unit_name, spike_times, error_message_or_none)``. Empty + string + empty list on hard fetch failure; populated tuple + + error message on extraction failure; empty error_message on + success. + """ + try: + unit_doc = await document_service.detail( + dataset_id, request.unit_doc_id, access_token=access_token, + ) + except Exception as exc: + log.warning( + "psth.unit_doc_fetch_failed", + dataset_id=dataset_id, + unit_doc_id=request.unit_doc_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return "", [], ( + f"Could not fetch unit document {request.unit_doc_id}: {exc}" + ) + + unit_name = _pick_unit_name(unit_doc, request.unit_doc_id) + spike_times = _extract_spike_times_from_doc(unit_doc) + if not spike_times: + # Try the binary-file fallback. Most vmspikesummary docs inline + # the spike-times array in JSON; some have a separate binary + # file. We probe the binary path only when JSON extraction + # returned nothing so the cheap path stays cheap. + spike_times = await _extract_spike_times_from_binary( + unit_doc, binary_service, access_token=access_token, + ) + if not spike_times: + return unit_name, [], ( + "vmspikesummary doc had no parseable spike_times array " + "(checked data.vmspikesummary.{spike_times, spiketimes, " + "sample_times} and binary-file fallback)" + ) + return unit_name, spike_times, None + + +async def _resolve_events( + request: PsthRequest, + *, + document_service: DocumentService, + dataset_id: str, + access_token: str | None, +) -> tuple[list[float], str | None]: + """Resolve the stimulus doc + extract its event-time array. + + Returns ``(events, error_message_or_none)``. Empty list + error + on any soft failure; populated list + None on success. + """ + try: + stim_doc = await document_service.detail( + dataset_id, request.stimulus_doc_id, access_token=access_token, + ) + except Exception as exc: + log.warning( + "psth.stimulus_doc_fetch_failed", + dataset_id=dataset_id, + stimulus_doc_id=request.stimulus_doc_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return [], ( + f"Could not fetch stimulus document {request.stimulus_doc_id}: {exc}" + ) + + events = _extract_event_times(stim_doc) + if not events: + return [], ( + "stimulus document had no extractable event timestamps " + "(checked data.stimulus_presentation.presentations[*].time_started, " + "data.stimulus_response.responses[*].stim_time, " + "data.events, and top-level events)" + ) + return events, None + + +# --------------------------------------------------------------------------- +# Validation + bin layout +# --------------------------------------------------------------------------- + + +def _validate_window( + request: PsthRequest, +) -> tuple[float, float, float, str | None]: + """Validate the [t0, t1] window + bin_size_ms. + + Returns ``(t0, t1, bin_size_ms, error_or_none)``. When the error is + non-None the caller bails with a soft envelope. The values are + returned even on failure so the envelope can echo what the caller + asked for (useful in tests + caller diagnostics). + """ + t0 = float(request.t0) + t1 = float(request.t1) + bin_size_ms = float(request.bin_size_ms) + + if not (np.isfinite(t0) and np.isfinite(t1) and np.isfinite(bin_size_ms)): + return t0, t1, bin_size_ms, ( + "t0, t1, and bin_size_ms must all be finite numbers" + ) + if t1 <= t0: + return t0, t1, bin_size_ms, ( + f"t1 ({t1}) must be greater than t0 ({t0})" + ) + if (t1 - t0) > MAX_WINDOW_SECONDS: + return t0, t1, bin_size_ms, ( + f"Window ({t1 - t0:.3f} s) exceeds the maximum allowed " + f"({MAX_WINDOW_SECONDS} s)" + ) + if bin_size_ms < MIN_BIN_SIZE_MS: + return t0, t1, bin_size_ms, ( + f"bin_size_ms ({bin_size_ms}) is below the minimum " + f"({MIN_BIN_SIZE_MS} ms)" + ) + # Estimate bin count to enforce the MAX_BINS cap. + n_bins_est = round((t1 - t0) * 1000.0 / bin_size_ms) + if n_bins_est > MAX_BINS: + return t0, t1, bin_size_ms, ( + f"Bin count ({n_bins_est}) exceeds the maximum ({MAX_BINS}); " + f"increase bin_size_ms or narrow [t0, t1]" + ) + return t0, t1, bin_size_ms, None + + +def _build_bin_arrays( + t0: float, t1: float, bin_size_ms: float, +) -> tuple[np.ndarray, np.ndarray]: + """Return ``(bin_edges, bin_centers)`` for the histogram. + + The bin count is ``round((t1-t0)*1000 / bin_size_ms)`` so the bin + width matches the request as closely as integer-bin layout allows. + Edges are ``np.linspace(t0, t1, n_bins+1)``; centers are the + midpoints. + """ + n_bins = max(1, round((t1 - t0) * 1000.0 / bin_size_ms)) + edges = np.linspace(t0, t1, n_bins + 1) + centers = 0.5 * (edges[:-1] + edges[1:]) + return edges, centers + + +# --------------------------------------------------------------------------- +# Doc-body extraction — spike times +# --------------------------------------------------------------------------- + + +def _extract_spike_times_from_doc(doc: dict[str, Any]) -> list[float] | None: + """Extract inlined spike times from a vmspikesummary doc's JSON body. + + Mirrors the field-probe order in + :mod:`backend.services.spike_summary_service` so behaviour stays + consistent across the two services. + + Returns None when no array of numbers is found at any candidate + path. Non-numeric entries are skipped silently (matches the TS + handler); a doc with mixed-type entries returns the numeric subset. + """ + if not isinstance(doc, dict): + return None + data = doc.get("data") + if not isinstance(data, dict): + return None + inner = data.get("vmspikesummary") + if not isinstance(inner, dict): + return None + for key in ("spike_times", "spiketimes", "sample_times"): + v = inner.get(key) + if not isinstance(v, list) or not v: + continue + nums = _coerce_numeric_list(v) + if nums: + return nums + return None + + +async def _extract_spike_times_from_binary( + doc: dict[str, Any], + binary_service: BinaryService, + *, + access_token: str | None, +) -> list[float] | None: + """Try the binary-file fallback when JSON extraction returned nothing. + + Some vmspikesummary docs carry their spike data as a separate + binary file; for those, the same :meth:`BinaryService.get_timeseries` + pipeline used by /signal can produce the channel arrays. We treat + the first channel's timestamps as the spike times (a single-channel + binary in this context is canonically a spike-time series). + + Returns None on any soft failure — caller surfaces the + decode_failed envelope. + """ + try: + ts = await binary_service.get_timeseries(doc, access_token=access_token) + except Exception as exc: + log.warning( + "psth.binary_fallback_failed", + error=str(exc), + error_type=type(exc).__name__, + ) + return None + if not isinstance(ts, dict) or ts.get("error"): + return None + timestamps = ts.get("timestamps") + if not isinstance(timestamps, list) or not timestamps: + return None + return _coerce_numeric_list(timestamps) + + +def _coerce_numeric_list(values: list[Any]) -> list[float]: + """Defensive numeric coerce — matches the spike-summary helper.""" + nums: list[float] = [] + for x in values: + if isinstance(x, bool): + # bool is a subclass of int; explicitly skip so True/False + # don't slip through as 1.0/0.0. + continue + if isinstance(x, (int, float)): + fx = float(x) + if _is_finite(fx): + nums.append(fx) + elif isinstance(x, str): + try: + parsed = float(x) + except (TypeError, ValueError): + continue + if _is_finite(parsed): + nums.append(parsed) + return nums + + +def _is_finite(v: float) -> bool: + """True iff ``v`` is a finite float — NaN/inf rejected. Wraps + :func:`math.isfinite` so callers can pass either int or float + without an explicit cast (math.isfinite accepts both). + """ + return math.isfinite(v) + + +def _pick_unit_name(doc: dict[str, Any], doc_id: str) -> str: + """Prefer ``data.vmspikesummary.name``, then top-level ``name``, + then a synthesized name from the doc ID tail. + """ + if isinstance(doc, dict): + data = doc.get("data") + if isinstance(data, dict): + inner = data.get("vmspikesummary") + if isinstance(inner, dict): + n = inner.get("name") + if isinstance(n, str) and n: + return n[:80] + top = doc.get("name") + if isinstance(top, str) and top: + return top[:80] + return f"Unit {doc_id[-6:]}" if doc_id else "Unit" + + +# --------------------------------------------------------------------------- +# Doc-body extraction — stimulus event timestamps +# --------------------------------------------------------------------------- + + +def _extract_event_times(doc: dict[str, Any]) -> list[float]: + """Extract event timestamps from a stimulus document. + + Probe order (canonical NDI doc-class paths first, preprocessed + arrays last): + + 1. ``data.stimulus_presentation.presentations[*].time_started`` + — the standard ``stimulus_presentation`` doc class. Each entry + in ``presentations`` represents one trial; ``time_started`` is + the onset in seconds. + 2. ``data.stimulus_response.responses[*].stim_time`` + — the ``stimulus_response`` doc class. ``stim_time`` is the + per-trial stimulus onset. + 3. ``data.events`` (list of floats / list of dicts with ``time`` + or ``t``) — preprocessed top-level array. + 4. ``events`` (top-level fallback) — same shape as #3 but at the + doc root. + + Returns an empty list when no candidate path yields numeric values; + caller surfaces a ``"no_events"`` envelope. Non-numeric entries are + silently skipped. + """ + if not isinstance(doc, dict): + return [] + data = doc.get("data") + + # Path 1: stimulus_presentation.presentations[*].time_started + if isinstance(data, dict): + sp = data.get("stimulus_presentation") + if isinstance(sp, dict): + presentations = sp.get("presentations") + times = _times_from_event_list(presentations, ("time_started", "time", "t")) + if times: + return times + + # Path 2: stimulus_response.responses[*].stim_time + sr = data.get("stimulus_response") + if isinstance(sr, dict): + responses = sr.get("responses") + times = _times_from_event_list(responses, ("stim_time", "time", "t")) + if times: + return times + + # Path 3: data.events (preprocessed; can be list-of-floats or list-of-dicts) + ev = data.get("events") + times = _times_from_event_list(ev, ("time", "t", "time_started", "stim_time")) + if times: + return times + + # Path 4: top-level events fallback + top_ev = doc.get("events") + times = _times_from_event_list(top_ev, ("time", "t", "time_started", "stim_time")) + if times: + return times + + return [] + + +def _times_from_event_list( + items: Any, + keys: tuple[str, ...], +) -> list[float]: + """Walk an events-style list, extracting numeric timestamps. + + Accepts either: + - ``list[float|int]`` — raw timestamps; coerce + filter finite. + - ``list[dict]`` — each entry contributes the value at the first + present key in ``keys``. + + Returns an empty list when ``items`` is not a list or yields no + numerics. + """ + if not isinstance(items, list) or not items: + return [] + out: list[float] = [] + for entry in items: + if isinstance(entry, dict): + v = _first_numeric_from_dict(entry, keys) + else: + v = _coerce_scalar(entry) + if v is not None: + out.append(v) + return out + + +def _coerce_scalar(entry: Any) -> float | None: + """Coerce a scalar entry to a finite float; return None when not numeric. + + ``bool`` is rejected explicitly (subclass of int in Python). Strings + parseable as floats are accepted so doc bodies that round-trip + through JSON-as-strings still work. + """ + if isinstance(entry, bool): + return None + if isinstance(entry, (int, float)): + fx = float(entry) + return fx if _is_finite(fx) else None + if isinstance(entry, str): + try: + parsed = float(entry) + except (TypeError, ValueError): + return None + return parsed if _is_finite(parsed) else None + return None + + +def _first_numeric_from_dict( + entry: dict[str, Any], keys: tuple[str, ...], +) -> float | None: + """Return the first key in ``keys`` whose value coerces to a finite + float, or None when nothing matched. + """ + for key in keys: + v = _coerce_scalar(entry.get(key)) + if v is not None: + return v + return None + + +# --------------------------------------------------------------------------- +# Response builders +# --------------------------------------------------------------------------- + + +def _empty_response( + request: PsthRequest, + *, + unit_name: str, + error: str, + error_kind: str, + t0: float, + t1: float, + bin_size_ms: float, +) -> PsthResponse: + """Build a soft-error PsthResponse with empty histogram arrays. + + We still return valid (zero-length) bin arrays so the chart layer + can render a clean empty state without branching on response + shape. The ``error`` + ``error_kind`` carry the diagnostic. + """ + return PsthResponse( + bin_centers=[], + counts=[], + mean_rate_hz=[], + n_trials=0, + n_spikes=0, + bin_size_ms=bin_size_ms, + t0=t0, + t1=t1, + unit_name=unit_name, + unit_doc_id=request.unit_doc_id, + stimulus_doc_id=request.stimulus_doc_id, + per_trial_raster=None, + error=error, + error_kind=error_kind, + ) + + +def _cap_raster( + per_trial: list[list[float]], + total_cap: int, +) -> list[list[float]]: + """Cap the total spike count across the per-trial raster. + + If the raw raster is already under the cap we return it verbatim. + Otherwise we stride-sample each trial proportionally so the trial + structure is preserved (callers branch on ``len(per_trial_raster)`` + for the trial count). + + The cap is total spikes across ALL trials, not per-trial. A 50- + trial recording with 1000 spikes/trial = 50k total → over the + default 10k cap → each trial gets stride-sampled to ~200 spikes. + """ + total = sum(len(t) for t in per_trial) + if total <= total_cap: + return per_trial + if total == 0: + return per_trial + ratio = total_cap / total + out: list[list[float]] = [] + for trial in per_trial: + n = len(trial) + if n == 0: + out.append([]) + continue + keep = max(1, int(n * ratio)) + if keep >= n: + out.append(list(trial)) + continue + # Stride-sample preserving first + last so the trial's onset + # + offset spikes survive. + if keep <= 2: + out.append([trial[0], trial[-1]][:keep]) + continue + step = (n - 1) / (keep - 1) + seen: set[int] = set() + picked: list[float] = [] + for i in range(keep): + idx = round(i * step) + if idx in seen: + continue + seen.add(idx) + picked.append(trial[idx]) + out.append(picked) + return out diff --git a/backend/services/spike_summary_service.py b/backend/services/spike_summary_service.py new file mode 100644 index 0000000..2a51322 --- /dev/null +++ b/backend/services/spike_summary_service.py @@ -0,0 +1,499 @@ +"""spike_summary_service — pull per-unit spike trains from +``vmspikesummary`` documents and shape them for spike-raster and/or +ISI histogram rendering. + +This is the Python port of the chat-side TS handler at +``ndi-cloud-app/apps/web/lib/ndi/tools/fetch-spike-summary.ts``. +Moving the orchestration to Railway keeps the heart of NDI processing +next to ndi-python where it belongs; the TS handler shrinks to a thin +proxy after this lands. + +Discovery — three modes, cheapest first: + + 1. ``unit_doc_id`` — direct fetch of a single vmspikesummary doc. + Cheapest path; used when the caller has already resolved which + unit it wants (chained from a query). + 2. ``unit_name_match`` — substring filter against the doc's + ``vmspikesummary.name`` field. Hits ``/ndiquery`` with a + two-clause structured query. + 3. Bare dataset scan — first N vmspikesummary docs in the dataset. + Use for "show me a raster from dataset X". + +Spike-times path +──────────────── +The TS implementation extracts ``spike_times`` directly from the +document's JSON body (``data.vmspikesummary.spike_times`` with +fallbacks to ``spiketimes`` and ``sample_times``). vmspikesummary +docs inline their spike data in the JSON; there is no separate +binary file to open. We preserve that canonical path here. + +Caller-facing differences vs the TS implementation +────────────────────────────────────────────────── +The router returns RAW per-unit data +(``{units: [{name, doc_id, spike_times, isi_intervals}], ...}``) +NOT the chat-specific ``chart_payloads`` wrapper. The TS layer +reshapes raw data into chart_payloads on the chat side; the +workspace consumes raw data directly. This keeps the backend +agnostic to UI framing. + +Soft-error envelope +─────────────────── +When a document is found but its spike-times array is missing or +unparseable, we surface a per-unit ``{error, error_kind: +'decode_failed'}`` rather than crashing the whole request. The +``error_kind`` taxonomy mirrors the existing /signal route so the +chat tool / workspace can branch on it. +""" +from __future__ import annotations + +import math +from typing import Any, Literal + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +from ..auth.session import SessionData +from ..clients.ndi_cloud import NdiCloudClient +from ..observability.logging import get_logger +from .document_service import DocumentService + +log = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Tunables — module-level so tests can monkeypatch and so the constants +# are reachable from tests without re-importing internals. +# --------------------------------------------------------------------------- + +# Server-side cap on per-call unit count. Mirrors the TS handler's +# MAX_UNITS_HARD. The chart components also cap (SpikeRaster at 50) but +# the right place to enforce is here so we never download more than we'll +# render. +MAX_UNITS_HARD = 50 +DEFAULT_MAX_UNITS = 10 + +# Per-unit spike-times cap. Mirrors the TS handler's stride-sample limit +# of 500. Plotly comfortably renders this density and the visual shape is +# preserved for any reasonable spike train. The full spike list is used +# for ISI computation BEFORE this cap is applied so the histogram +# remains statistically accurate. +MAX_SPIKES_PER_UNIT = 5000 + +# Per-unit ISI-intervals cap. The TS handler caps the consolidated +# payload at 5000 (across all units) but our raw-data shape returns +# per-unit arrays, so the cap is applied per-unit. +MAX_ISI_INTERVALS_PER_UNIT = 5000 + + +# --------------------------------------------------------------------------- +# Pydantic request/response models. +# +# Field aliases let the router accept either camelCase (TS proxy passing +# through its existing input) or snake_case body keys without the caller +# having to translate. +# --------------------------------------------------------------------------- + + +SpikeKind = Literal["raster", "isi_histogram", "both"] + + +class SpikeSummaryRequest(BaseModel): + """Input shape mirrors the TS ``fetchSpikeSummaryInput`` schema + so the TS handler can pass its input through verbatim. + """ + + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + dataset_id: str = Field(..., alias="datasetId", min_length=1) + unit_doc_id: str | None = Field(default=None, alias="unitDocId", min_length=1) + unit_name_match: str | None = Field( + default=None, alias="unitNameMatch", min_length=1, + ) + kind: SpikeKind = "both" + t_window: tuple[float, float] | None = Field(default=None, alias="tWindow") + max_units: int | None = Field( + default=None, alias="maxUnits", ge=1, le=MAX_UNITS_HARD, + ) + title: str | None = Field(default=None, max_length=160) + + +class SpikeSummaryUnit(BaseModel): + """One unit's contribution to the response. + + ``spike_times`` is included when ``kind`` is ``raster`` or ``both``. + ``isi_intervals`` is included when ``kind`` is ``isi_histogram`` + or ``both``. Both are absent when the unit's binary decode failed + (``error`` populated instead). + """ + + name: str + doc_id: str + spike_times: list[float] | None = None + isi_intervals: list[float] | None = None + # When set, the unit's spike-times array was unparseable. The unit + # is still included in `units` so callers see a placeholder + the + # decode reason; soft-error envelope matches the /signal route. + error: str | None = None + error_kind: str | None = None + + +class SpikeSummaryResponse(BaseModel): + """Top-level response. ``total_matching`` is the count BEFORE the + ``max_units`` slice — callers can disclose "showed 10 of N" when + truncated. + """ + + units: list[SpikeSummaryUnit] + total_matching: int + kind: SpikeKind + # Diagnostic — populated when no units matched / decoded so the + # caller can explain or retry. Empty-string ``error`` is reserved + # for "no failure"; consumers should check ``len(units)``. + error: str | None = None + error_kind: str | None = None + + +# --------------------------------------------------------------------------- +# Public orchestration entry point +# --------------------------------------------------------------------------- + + +async def compute_spike_summary( + request: SpikeSummaryRequest, + *, + document_service: DocumentService, + cloud: NdiCloudClient, + session: SessionData | None, +) -> SpikeSummaryResponse: + """Orchestrate vmspikesummary discovery + per-unit spike-train + extraction. + + Parameters + ---------- + request: + Validated input (see :class:`SpikeSummaryRequest`). + document_service: + Used for the ``unit_doc_id`` single-doc fetch path. The detail + endpoint handles ndiId-vs-Mongo-id resolution. + cloud: + Used directly for the ``unit_name_match`` + bare-scan + ndiquery calls. We bypass ``QueryService`` here because its + scope-validator enforces a Mongo-ObjectId regex that's + redundant with the path validator and would reject the + free-form dataset IDs the rest of the stack accepts. + session: + Optional session — propagated as ``access_token`` so private + datasets work for logged-in users while public datasets work + anonymously. + """ + access_token = session.access_token if session else None + max_units = min(request.max_units or DEFAULT_MAX_UNITS, MAX_UNITS_HARD) + + docs, total_matching = await _resolve_units( + request, + document_service=document_service, + cloud=cloud, + access_token=access_token, + max_units=max_units, + ) + + if not docs: + return SpikeSummaryResponse( + units=[], + total_matching=0, + kind=request.kind, + error=_empty_reason(request), + error_kind="no_matches", + ) + + units: list[SpikeSummaryUnit] = [] + for doc in docs: + doc_id = _pick_doc_id(doc) + name = _pick_unit_name(doc, doc_id) + raw_spikes = _extract_spike_times(doc) + if raw_spikes is None or len(raw_spikes) == 0: + # Soft error per doc — same envelope as /signal so the + # chat tool can branch on `error_kind`. The doc is kept + # in the response so the caller sees which unit failed. + units.append( + SpikeSummaryUnit( + name=name, + doc_id=doc_id, + error=( + "vmspikesummary doc had no parseable spike_times " + "array (checked data.vmspikesummary.spike_times, " + "spiketimes, sample_times)" + ), + error_kind="decode_failed", + ), + ) + continue + + # t_window filter — done BEFORE the spike-count cap so the cap + # bounds the rendered density, not the unfiltered density. + spikes = _apply_t_window(raw_spikes, request.t_window) + if len(spikes) == 0: + # Window emptied the unit. Skip silently — the unit isn't + # "failed" per se, just outside the requested window. + continue + + spike_times = _build_spike_field(spikes, request.kind) + isi_intervals = _build_isi_field(spikes, request.kind) + units.append( + SpikeSummaryUnit( + name=name, + doc_id=doc_id, + spike_times=spike_times, + isi_intervals=isi_intervals, + ), + ) + + # Stable name-order so the response is deterministic for callers + # iterating in display order. ``unit_doc_id`` (single-doc path) + # produces a one-element list so this is a no-op there. + units.sort(key=lambda u: u.name.lower()) + + return SpikeSummaryResponse( + units=units, + total_matching=total_matching, + kind=request.kind, + ) + + +# --------------------------------------------------------------------------- +# Discovery helpers +# --------------------------------------------------------------------------- + + +async def _resolve_units( + request: SpikeSummaryRequest, + *, + document_service: DocumentService, + cloud: NdiCloudClient, + access_token: str | None, + max_units: int, +) -> tuple[list[dict[str, Any]], int]: + """Return ``(docs, total_matching)``. + + Three modes (mirrors the TS handler): + 1. ``unit_doc_id`` — single-doc fetch (one doc, total=1). + 2. ``unit_name_match`` — ndiquery with ``isa(vmspikesummary)`` + + ``contains_string(vmspikesummary.name, <substr>)``. + 3. Bare scan — ndiquery with just ``isa(vmspikesummary)``. + """ + if request.unit_doc_id: + try: + doc = await document_service.detail( + request.dataset_id, + request.unit_doc_id, + access_token=access_token, + ) + except Exception as exc: + log.warning( + "spike_summary.single_doc_fetch_failed", + dataset_id=request.dataset_id, + doc_id=request.unit_doc_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return ([], 0) + return ([doc], 1) + + searchstructure: list[dict[str, Any]] = [ + {"operation": "isa", "param1": "vmspikesummary"}, + ] + if request.unit_name_match: + searchstructure.append({ + "operation": "contains_string", + "field": "vmspikesummary.name", + "param1": request.unit_name_match, + }) + try: + body = await cloud.ndiquery( + searchstructure=searchstructure, + scope=request.dataset_id, + access_token=access_token, + ) + except Exception as exc: + log.warning( + "spike_summary.query_failed", + dataset_id=request.dataset_id, + unit_name_match=request.unit_name_match, + error=str(exc), + error_type=type(exc).__name__, + ) + return ([], 0) + docs = list(body.get("documents") or []) + total = len(docs) + return (docs[:max_units], total) + + +def _empty_reason(request: SpikeSummaryRequest) -> str: + if request.unit_doc_id: + return ( + f"No vmspikesummary document {request.unit_doc_id} " + f"in dataset {request.dataset_id}" + ) + if request.unit_name_match: + return ( + f"No vmspikesummary documents matched " + f"name~\"{request.unit_name_match}\" in dataset " + f"{request.dataset_id}" + ) + return f"No vmspikesummary documents in dataset {request.dataset_id}" + + +# --------------------------------------------------------------------------- +# Field extraction — field-path probe order mirrors the TS handler so +# behavior stays consistent across the two implementations. +# --------------------------------------------------------------------------- + + +def _extract_spike_times(doc: dict[str, Any]) -> list[float] | None: + """Extract the spike-times array from a vmspikesummary doc body. + + Probe order (most-likely → least-likely): + 1. ``data.vmspikesummary.spike_times`` + 2. ``data.vmspikesummary.spiketimes`` + 3. ``data.vmspikesummary.sample_times`` ← the schema-canonical name + + Returns None when no array of numbers is found at any candidate + path. Caller handles the empty case by surfacing a per-unit soft + error. + + Non-numeric entries are skipped silently (matches the TS handler); + a doc with mixed-type entries returns the numeric subset. + """ + data = doc.get("data") if isinstance(doc, dict) else None + if not isinstance(data, dict): + return None + inner = data.get("vmspikesummary") + if not isinstance(inner, dict): + return None + for key in ("spike_times", "spiketimes", "sample_times"): + v = inner.get(key) + if not isinstance(v, list) or not v: + continue + nums: list[float] = [] + for x in v: + if isinstance(x, (int, float)) and not isinstance(x, bool): + # Guard against NaN/inf which would poison downstream + # math; matches the TS handler's Number.isFinite check. + fx = float(x) + if _is_finite(fx): + nums.append(fx) + elif isinstance(x, str): + try: + parsed = float(x) + except (TypeError, ValueError): + continue + if _is_finite(parsed): + nums.append(parsed) + if nums: + return nums + return None + + +def _is_finite(v: float) -> bool: + return math.isfinite(v) + + +def _pick_doc_id(doc: dict[str, Any]) -> str: + for key in ("id", "_id", "ndiId"): + v = doc.get(key) + if isinstance(v, str) and v: + return v + return "" + + +def _pick_unit_name(doc: dict[str, Any], doc_id: str) -> str: + """Prefer ``data.vmspikesummary.name``, then top-level ``name``, + then a synthesized name from the doc ID tail. + """ + data = doc.get("data") + if isinstance(data, dict): + inner = data.get("vmspikesummary") + if isinstance(inner, dict): + n = inner.get("name") + if isinstance(n, str) and n: + return n[:80] + top = doc.get("name") + if isinstance(top, str) and top: + return top[:80] + return f"Unit {doc_id[-6:]}" if doc_id else "Unit" + + +# --------------------------------------------------------------------------- +# Per-unit computation — t_window filter, stride-sample, ISI compute. +# Pure functions kept module-level so they're trivially unit-testable. +# --------------------------------------------------------------------------- + + +def _apply_t_window( + spikes: list[float], window: tuple[float, float] | None, +) -> list[float]: + if window is None: + return spikes + t0, t1 = window + return [t for t in spikes if t0 <= t <= t1] + + +def _build_spike_field( + spikes: list[float], kind: SpikeKind, +) -> list[float] | None: + """Cap + return the spike-times list when ``kind`` requests it, + None otherwise. ``kind == 'isi_histogram'`` omits the field so the + response stays compact for histogram-only callers. + """ + if kind == "isi_histogram": + return None + return _stride_sample(spikes, MAX_SPIKES_PER_UNIT) + + +def _build_isi_field( + spikes: list[float], kind: SpikeKind, +) -> list[float] | None: + """Compute ISI intervals in MILLISECONDS from the FULL spike-times + list (not the capped one) so the histogram's statistical + accuracy is preserved. Then stride-sample the intervals before + returning to bound wire size. + + Returns None when ``kind == 'raster'`` so raster-only callers get + a compact response. + """ + if kind == "raster": + return None + if len(spikes) < 2: + return [] + sorted_spikes = np.sort(np.asarray(spikes, dtype=np.float64)) + diffs_ms = np.diff(sorted_spikes) * 1000.0 + # Drop non-finite / non-positive intervals — matches the TS + # handler's defensive filter. Spike times sorted ascending means + # diff is always >= 0 but a duplicate timestamp produces 0 which + # is meaningless for an ISI histogram. + intervals = [float(d) for d in diffs_ms.tolist() if _is_finite(d) and d > 0] + return _stride_sample(intervals, MAX_ISI_INTERVALS_PER_UNIT) + + +def _stride_sample(values: list[float], cap: int) -> list[float]: + """Stride-sample down to ``cap`` entries preserving first + last. + + Mirrors :func:`backend.services.tabular_query_service._stride_sample` + (and the TS handler's ``strideSample``). When ``len(values) <= + cap`` returns a copy. + """ + n = len(values) + if n <= cap: + return list(values) + if cap <= 2: + return [values[0], values[-1]][:cap] + step = (n - 1) / (cap - 1) + seen: set[int] = set() + out: list[float] = [] + for i in range(cap): + idx = round(i * step) + if idx in seen: + continue + seen.add(idx) + out.append(values[idx]) + return out diff --git a/backend/services/summary_table_service.py b/backend/services/summary_table_service.py index e451e49..5d5a7db 100644 --- a/backend/services/summary_table_service.py +++ b/backend/services/summary_table_service.py @@ -44,6 +44,7 @@ from ..clients.ndi_cloud import BULK_FETCH_MAX, NdiCloudClient from ..observability.logging import get_logger from ..observability.metrics import table_build_duration_seconds +from .class_aliases import CLASS_ALIASES log = get_logger(__name__) @@ -53,6 +54,46 @@ # 6 comfortably — confirm with Steve before raising further. MAX_CONCURRENT_BULK_FETCH = 6 +# distinct_summary caps — keep work bounded for very large tables. +# Smoke-tested 2026-05-13 (Dabrowska BNST): `query_documents(class=treatment)` +# returned 49 rows, all named "Optogenetic Tetanus Stimulation Target Location" +# — the LLM didn't realize all rows were duplicates and assumed the dataset had +# only optogenetic treatments. distinct_summary surfaces that collapse so the +# model can say "9 distinct strains across 215 subjects" without sampling. +# +# Skip the computation entirely above this row count — the table is too big to +# scan in-memory affordably (>10K rows x ~15 columns = 150K cell reads). +DISTINCT_SUMMARY_MAX_ROWS = 10_000 +# How many top values per column to surface. Beyond ~5 the LLM's context +# bloats without adding signal; "top-5 + counts + distinct_count" is the +# tight summary the prompt is tuned against. +DISTINCT_SUMMARY_TOP_K = 5 + +# Aliases for the canonical NDI class names. The cloud's `isa` operator +# walks ``classLineage`` so a query for the BASE class returns docs of +# any subclass — but the cloud (as deployed) does NOT walk the LINEAGE +# in the OTHER direction. Datasets ingested under the modern schema +# emit ``element`` rather than the legacy ``probe`` class name; an `isa +# probe` query against those datasets returns 0 IDs even though the +# user-facing concept ("the probes") is fully represented. +# +# When the caller requests one of the entries in this map and the +# literal query returns 0 IDs, ``_build_single_class`` retries with the +# fallback. Logged so the alias hit shows up in observability. +# +# Smoke-tested 2026-05-14 against Dabrowska BNST (id 67f723d574f5f79c +# 6062389d, 0 probes / 606 elements): ``query_documents(className=probe)`` +# returned 0 rows pre-fix; with the fallback it returns the 606 +# element rows (which IS what summary.probeTypes computes its values +# from). +# 2026-05-18 — moved to backend/services/class_aliases.py so +# document_service can share the same source-of-truth (B2: workspace +# Probes picker calls /documents?class=probe and needs to follow +# `probe→element` too). The alias name stays prefixed with `_` +# for backwards-compat with any external test stubs that monkey-patched +# it; the canonical name is `CLASS_ALIASES`. +_CLASS_ALIASES: dict[str, list[str]] = CLASS_ALIASES + # Enrichment plan per primary class. Each listed class is fetched dataset- # wide in parallel and its docs indexed by the `depends_on` edge the # projection needs. `subject` and `openminds_subject` are always pulled when @@ -60,11 +101,18 @@ # element-centric rows; treatment is pulled when the row is subject- # attributable (subject, element, element_epoch). _ENRICHMENTS_FOR: dict[str, list[str]] = { - "subject": ["openminds_subject", "treatment"], + # 2026-05-19 (F-1b follow-up) — extended to fetch treatment_drug + + # treatment_transfer in addition to literal `treatment`. Bhar + # publishes treatments via the subclasses only (0 literal + + # 24466 drug + 1675 transfer); without the subclass enrichments + # F-1b's _broadcast_treatments_onto_subjects saw an empty list + # and emitted no broadcast columns. _project_for_class's subject + # branch merges all three lists post-fetch. + "subject": ["openminds_subject", "treatment", "treatment_drug", "treatment_transfer"], "element": ["subject", "openminds_subject", "probe_location"], "probe": ["subject", "openminds_subject", "probe_location"], - "element_epoch": ["element", "subject", "openminds_subject", "probe_location", "treatment"], - "epoch": ["element", "subject", "openminds_subject", "probe_location", "treatment"], + "element_epoch": ["element", "subject", "openminds_subject", "probe_location", "treatment", "treatment_drug", "treatment_transfer"], + "epoch": ["element", "subject", "openminds_subject", "probe_location", "treatment", "treatment_drug", "treatment_transfer"], "treatment": ["subject", "openminds_subject"], "openminds_subject": [], "probe_location": [], @@ -103,21 +151,61 @@ async def single_class( class_name: str, *, session: SessionData | None, + page: int | None = None, + page_size: int | None = None, + subject_filter: str | None = None, ) -> dict[str, Any]: + """Build (or read from cache) the full per-class table, then optionally + slice for pagination. + + Pagination semantics (Stream 5.8, 2026-05-16): + + - When BOTH ``page`` and ``page_size`` are ``None`` the response is the + backward-compatible envelope ``{columns, rows, distinct_summary}`` + carrying every row. Existing callers (Document Explorer's full-set + fetch, the cron warm-cache) stay on this path. + + - When EITHER is provided the response adds ``page``, ``pageSize``, + ``totalRows``, and ``hasMore`` fields, and ``rows`` is sliced + server-side. Default page=1, page_size=200 when only one is given. + + The cache stays keyed by ``(dataset_id, class_name, user_scope)`` — + the FULL row set is cached, never per-page. Slicing happens in-memory + after the cache get/compute, so the warm-cache cron (which fetches + unpaged) still hydrates every page for downstream paged readers. + Egress savings come entirely from the smaller response body; the + cloud-fetch work is unchanged. + """ access_token = session.access_token if session else None if self.cache is not None: key = RedisTableCache.table_key( dataset_id, class_name, user_scope=user_scope_for(session), ) - return await self.cache.get_or_compute( + full = await self.cache.get_or_compute( key, lambda: self._build_single_class( dataset_id, class_name, access_token=access_token, ), ) - return await self._build_single_class( - dataset_id, class_name, access_token=access_token, - ) + else: + full = await self._build_single_class( + dataset_id, class_name, access_token=access_token, + ) + + # 2026-05-19 — F-2: subject filter applied AFTER cache read, + # BEFORE pagination, so the cache stays keyed by the full table + # (one cold-compute amortizes every per-subject query). Matches + # `subjectDocumentIdentifier` exactly. If the column doesn't + # exist on this class the result is an empty rows envelope — + # callers get a clear "no matches" instead of an error. + if subject_filter is not None and subject_filter.strip(): + full = _filter_rows_by_subject(full, subject_filter.strip()) + + if page is None and page_size is None: + # Backward-compatible unpaged envelope. + return full + + return _paginate(full, page=page or 1, page_size=page_size or 200) async def _build_single_class( self, @@ -127,12 +215,37 @@ async def _build_single_class( access_token: str | None, ) -> dict[str, Any]: t0 = time.perf_counter() + # Try the literal class first; if it returns 0 IDs and we have a + # canonical alias (probe→element, epoch→element_epoch), retry on + # the alias. The projection key stays the literal so PROBE_COLUMNS + # are emitted regardless — the LLM and the document-explorer + # client both render rows under the user-requested class label. body = await self.cloud.ndiquery( searchstructure=[{"operation": "isa", "param1": class_name}], scope=dataset_id, access_token=access_token, ) ids = _extract_ids(body) + resolved_class = class_name + if not ids: + for alias in _CLASS_ALIASES.get(class_name, []): + alt_body = await self.cloud.ndiquery( + searchstructure=[{"operation": "isa", "param1": alias}], + scope=dataset_id, + access_token=access_token, + ) + alt_ids = _extract_ids(alt_body) + if alt_ids: + log.info( + "table.single.alias_hit", + dataset_id=dataset_id, + requested_class=class_name, + resolved_class=alias, + ids=len(alt_ids), + ) + ids = alt_ids + resolved_class = alias + break docs = await self._bulk_fetch_all(dataset_id, ids, access_token=access_token) # Fetch all enrichment classes in parallel. If any REQUIRED enrichment @@ -140,7 +253,11 @@ async def _build_single_class( # transient cloud failure doesn't pin a broken empty-enrichment table # into Redis for the full TTL window. Plan §M4a step 3: "Skip cache # if cloud call fails." - enrich_classes = _ENRICHMENTS_FOR.get(class_name, []) + # Enrichment plan keys off the RESOLVED class (so an aliased + # probe→element fetch still pulls subject + openminds_subject + + # probe_location, which `_row_probe` needs to populate location + + # cell-type columns). + enrich_classes = _ENRICHMENTS_FOR.get(resolved_class, []) enriched: dict[str, list[dict[str, Any]]] = {} if enrich_classes and docs: results = await asyncio.gather( @@ -154,23 +271,46 @@ async def _build_single_class( if isinstance(r, BaseException): log.warning( "table.enrichment_failed", - primary=class_name, + primary=resolved_class, enrichment=ec, error=str(r), ) # Required enrichment failed — propagate so the cache # doesn't pin a broken build. The caller re-tries on the # next request, which may succeed against a healthy cloud. - if ec in _REQUIRED_ENRICHMENTS.get(class_name, set()): + if ec in _REQUIRED_ENRICHMENTS.get(resolved_class, set()): raise RuntimeError( f"Required enrichment {ec!r} failed while building " - f"{class_name} table: {r}", + f"{resolved_class} table: {r}", ) enriched[ec] = [] else: enriched[ec] = r + # Project under the REQUESTED class so the alias resolution + # is invisible to the caller — `useSummaryTable('element_epoch')` + # gets EPOCH_COLUMNS regardless of whether the alias chain + # resolved to `element_epoch`, `epochfiles_ingested`, or + # `daqreader_mfdaq_epochdata_ingested`. The projection + # helpers (`_row_probe`, `_row_epoch`, etc.) tolerate the + # slightly different data shapes via the `_first()` path-list + # walker — fields that don't exist on the legacy class shape + # surface as null cells, which the cloud-app's + # auto-hide-empty-column logic suppresses gracefully. + # + # Pre-2026-05-19 this used `resolved_class`, which meant + # `element_epoch → epochfiles_ingested` fell through to + # GENERIC_COLUMNS (just `name` + `documentIdentifier`) — + # invisible to the F-1d alias addition because the rows + # came back but the columns were generic. columns, rows = _project_for_class(class_name, docs, enriched) + # distinct_summary is computed over ALL projected rows (not the + # client-side paged slice), so consumers (notably the /ask chat's + # query_documents tool, which limits to 10-30 rows) can still see + # "9 distinct strains across 215 subjects" without sampling. The + # work is bounded by DISTINCT_SUMMARY_MAX_ROWS; cached alongside + # rows so a Redis hit returns it for free. + distinct_summary = _build_distinct_summary(columns, rows) table_build_duration_seconds.labels(class_name=class_name).observe( time.perf_counter() - t0, ) @@ -178,11 +318,16 @@ async def _build_single_class( "table.build.single", dataset_id=dataset_id, class_name=class_name, + resolved_class=resolved_class, ids=len(ids), rows=len(rows), ms=int((time.perf_counter() - t0) * 1000), ) - return {"columns": columns, "rows": rows} + return { + "columns": columns, + "rows": rows, + "distinct_summary": distinct_summary, + } async def combined( self, @@ -314,6 +459,27 @@ def _accept_optional(name: str, r: object) -> list[dict[str, Any]]: "epochId": _ndi_id(epoch), }) + combined_columns = [ + {"key": "subject", "label": "Subject"}, + {"key": "species", "label": "Species"}, + {"key": "speciesOntology", "label": "Species Ontology"}, + {"key": "strain", "label": "Strain"}, + {"key": "strainOntology", "label": "Strain Ontology"}, + {"key": "sex", "label": "Sex"}, + {"key": "probe", "label": "Probe"}, + {"key": "probeLocationName", "label": "Probe Location"}, + {"key": "probeLocationOntology", "label": "Probe Location Ontology"}, + {"key": "type", "label": "Probe type"}, + {"key": "epoch", "label": "Epoch"}, + {"key": "approachName", "label": "Approach"}, + {"key": "approachOntology", "label": "Approach Ontology"}, + {"key": "start", "label": "Start"}, + {"key": "stop", "label": "Stop"}, + ] + # Same distinct_summary rationale as single_class: surfaces per-column + # cardinality so the LLM can answer "how many distinct strains" without + # paging the full table. + distinct_summary = _build_distinct_summary(combined_columns, rows) elapsed = time.perf_counter() - build_start table_build_duration_seconds.labels(class_name="combined").observe(elapsed) log.info( @@ -325,24 +491,9 @@ def _accept_optional(name: str, r: object) -> list[dict[str, Any]]: ms=int(elapsed * 1000), ) return { - "columns": [ - {"key": "subject", "label": "Subject"}, - {"key": "species", "label": "Species"}, - {"key": "speciesOntology", "label": "Species Ontology"}, - {"key": "strain", "label": "Strain"}, - {"key": "strainOntology", "label": "Strain Ontology"}, - {"key": "sex", "label": "Sex"}, - {"key": "probe", "label": "Probe"}, - {"key": "probeLocationName", "label": "Probe Location"}, - {"key": "probeLocationOntology", "label": "Probe Location Ontology"}, - {"key": "type", "label": "Probe type"}, - {"key": "epoch", "label": "Epoch"}, - {"key": "approachName", "label": "Approach"}, - {"key": "approachOntology", "label": "Approach Ontology"}, - {"key": "start", "label": "Start"}, - {"key": "stop", "label": "Stop"}, - ], + "columns": combined_columns, "rows": rows, + "distinct_summary": distinct_summary, } async def ontology_tables( @@ -514,6 +665,90 @@ async def _fetch(batch: list[str]) -> list[dict[str, Any]]: return flat +# --------------------------------------------------------------------------- +# Subject filter helper (F-2, 2026-05-19) +# --------------------------------------------------------------------------- + +# Column keys that carry a per-row subject document identifier. We try +# the canonical ``subjectDocumentIdentifier`` first, then ``subjectId`` +# (older shape, kept for forward-compat). Anything else falls through +# to "no rows matched." +_SUBJECT_FILTER_KEYS: tuple[str, ...] = ( + "subjectDocumentIdentifier", + "subjectId", +) + + +def _filter_rows_by_subject( + full: dict[str, Any], subject: str, +) -> dict[str, Any]: + """Return a copy of ``full`` with ``rows`` filtered to those whose + subject column matches ``subject`` exactly. + + Operates on the cached full envelope — the cache stays keyed by + (dataset_id, class_name, user_scope), so subject-filtered requests + benefit from the same cold-compute amortization the unfiltered + requests get. The ``columns`` and ``distinct_summary`` blocks are + carried verbatim (distinct_summary intentionally reflects the + pre-filter full table so callers can still answer "how many + distinct subjects in this dataset" without re-fetching). + + When the table doesn't carry a subject column the result is the + envelope with empty ``rows`` — clear "no matches" signal rather + than an error. + """ + rows = full.get("rows") or [] + if not rows: + return full + filtered: list[dict[str, Any]] = [] + for row in rows: + if not isinstance(row, dict): + continue + for key in _SUBJECT_FILTER_KEYS: + value = row.get(key) + if isinstance(value, str) and value == subject: + filtered.append(row) + break + out = dict(full) + out["rows"] = filtered + return out + + +# --------------------------------------------------------------------------- +# Pagination helper (Stream 5.8, 2026-05-16) +# --------------------------------------------------------------------------- + +def _paginate( + full: dict[str, Any], *, page: int, page_size: int, +) -> dict[str, Any]: + """Slice a full single-class table envelope into a paged envelope. + + Carries over ``columns`` and ``distinct_summary`` verbatim — the latter + is computed over the FULL row set (capped by ``DISTINCT_SUMMARY_MAX_ROWS``) + so consumers can still answer "how many distinct strains across all 215 + rows" without paging the whole table. ``rows`` is sliced to the requested + page. + + Inputs are validated by the FastAPI Query layer (``page >= 1``, + ``1 <= page_size <= 1000``); this helper assumes those bounds. + """ + rows = full.get("rows", []) or [] + total = len(rows) + start = (page - 1) * page_size + end = start + page_size + out: dict[str, Any] = { + "columns": full.get("columns", []), + "rows": rows[start:end], + "page": page, + "pageSize": page_size, + "totalRows": total, + "hasMore": end < total, + } + if "distinct_summary" in full: + out["distinct_summary"] = full["distinct_summary"] + return out + + # --------------------------------------------------------------------------- # Generic extraction helpers # --------------------------------------------------------------------------- @@ -996,13 +1231,30 @@ def _subject_display_name(d: dict[str, Any]) -> str | None: {"key": "subjectDocumentIdentifier", "label": "Subject Doc ID"}, ] +# F-1 (2026-05-19) — projection for stimulus_presentation. The +# workspace's StimuliPicker previously hit the generic +# `/documents?class=stimulus_presentation` endpoint, which is capped +# at 200 rows per page and gives no projected columns. This projection +# extracts the stimulus name + presentation count + first/last +# presentation time so a dataset with thousands of stim docs can be +# scanned/sorted in the rail. Mirrors PROBE_COLUMNS / EPOCH_COLUMNS +# shape for visual consistency. +STIMULUS_COLUMNS: list[dict[str, str]] = [ + {"key": "stimulusDocumentIdentifier", "label": "Stimulus Doc ID"}, + {"key": "stimulusName", "label": "Name"}, + {"key": "elementDocumentIdentifier", "label": "Element Doc ID"}, + {"key": "presentationCount", "label": "Presentations"}, + {"key": "firstPresentationTime", "label": "First Onset (s)"}, + {"key": "lastPresentationTime", "label": "Last Onset (s)"}, +] + GENERIC_COLUMNS: list[dict[str, str]] = [ {"key": "name", "label": "Name"}, {"key": "documentIdentifier", "label": "Doc ID"}, ] -def _project_for_class( +def _project_for_class( # noqa: PLR0911 — linear per-class dispatch is clearer than a lookup table; F-1b added the subject early-return branch. class_name: str, docs: list[dict[str, Any]], enriched: dict[str, list[dict[str, Any]]], @@ -1018,7 +1270,32 @@ def _project_for_class( _attach_openminds_enrichment(docs, om_subjects) if class_name == "subject": - return SUBJECT_COLUMNS, [_row_subject(d) for d in docs] + subject_rows = [_row_subject(d) for d in docs] + # F-1b (2026-05-19) — server-side broadcast of per-subject + # treatment values onto the subject summary table. Replaces the + # cloud-app's ~100-line joinTreatmentsToSubjects frontend pivot + # at table-shell.tsx so the workspace's SubjectsBrowser gets + # the same enriched columns at no extra cost (ADR-001 heart- + # on-Railway). Treatment rows are projected via the canonical + # _row_treatment which already handles treatment_drug + + # treatment_transfer subclasses. + # + # Datasets like Bhar publish treatments via subclasses only + # (0 literal `treatment` + 24466 `treatment_drug` + 1675 + # `treatment_transfer`), so we MUST merge all three enrichment + # buckets — `_fetch_class("treatment")` returns literal-class + # rows only; the alias chain only applies to PRIMARY-class + # fetches, not enrichments. + treatment_docs = ( + enriched.get("treatment", []) + + enriched.get("treatment_drug", []) + + enriched.get("treatment_transfer", []) + ) + treatment_rows = [_row_treatment(t) for t in treatment_docs] + new_rows, new_columns = _broadcast_treatments_onto_subjects( + subject_rows, list(SUBJECT_COLUMNS), treatment_rows, + ) + return new_columns, new_rows if class_name in ("probe", "element"): return PROBE_COLUMNS, [_row_probe(d, enriched) for d in docs] @@ -1034,9 +1311,18 @@ def _project_for_class( rows.append(_row_epoch(epoch, enriched, subject=subject, element=element)) return EPOCH_COLUMNS, rows - if class_name == "treatment": + if class_name in ("treatment", "treatment_drug", "treatment_transfer"): + # F-1e (2026-05-19) — treatment_drug + treatment_transfer share + # the TREATMENT_COLUMNS row shape via the auto-detect branches + # in `_row_treatment`. The Gantt projection downstream + # (treatment_timeline_service) merges rows across all three + # source classes for datasets like Bhar that emit treatment + # data only in the legacy subclasses. return TREATMENT_COLUMNS, [_row_treatment(d) for d in docs] + if class_name in ("stimulus", "stimulus_presentation"): + return STIMULUS_COLUMNS, [_row_stimulus(d) for d in docs] + if class_name == "probe_location": return PROBE_COLUMNS, [_row_probe_location_only(d) for d in docs] @@ -1203,7 +1489,67 @@ def _row_epoch( def _row_treatment(d: dict[str, Any]) -> dict[str, Any]: - tdata = (d.get("data") or {}).get("treatment") or {} + """Project a treatment-shaped doc to TREATMENT_COLUMNS. + + Handles three input shapes via auto-detect on which sub-block is + present under ``data``: + + - ``treatment`` (canonical) — reads ``name``, ``ontologyName``, + ``numeric_value``, ``string_value``; subject via ``subject_id`` + depends_on edge. + + - ``treatment_drug`` (Bhar etc.) — name parsed from + ``mixture_table`` (first non-header row, second CSV column); + timing from ``administration_onset_time`` / + ``administration_offset_time``; encodes the (onset, offset) + pair into ``numericValue`` for the downstream Gantt projection. + Subject via ``subject_id`` depends_on edge. + + - ``treatment_transfer`` (Bhar) — name from ``entity_name``; + ``timestamp`` is a scalar onset, encoded as ``numericValue``. + Subject via ``recipient_id`` depends_on edge (NOT ``subject_id`` + — transfer docs use the recipient/donor pair). + + For backwards compatibility callers using TREATMENT_COLUMNS get the + same column set regardless of source class — the auto-hide-empty- + column UX suppresses fields that don't exist on a given subclass. + """ + data = d.get("data") or {} + + # treatment_drug — newer NDI ingest path (Bhar etc.) + if "treatment_drug" in data: + td = data["treatment_drug"] or {} + # mixture_table format: header row + 1+ data rows, CSV. + # First data row's second column is the human-readable name. + name = _parse_mixture_table_name(td.get("mixture_table")) + onset = td.get("administration_onset_time") + offset = td.get("administration_offset_time") + return { + "treatmentName": _clean(name) or _project_name(d), + "treatmentOntology": _clean(td.get("location_ontologyName")), + "numericValue": _coerce_timing_pair(onset, offset), + "stringValue": _clean(td.get("location_name")), + "subjectDocumentIdentifier": _depends_on_value_by_name(d, "subject_id"), + } + + # treatment_transfer — Bhar transfer events + if "treatment_transfer" in data: + tt = data["treatment_transfer"] or {} + timestamp = tt.get("timestamp") + return { + "treatmentName": _clean(tt.get("entity_name")) or _project_name(d), + "treatmentOntology": _clean(tt.get("entity_ontologyNode")), + # Single-point onset becomes a length-1 list — the timeline + # projection treats it as a 1-unit-duration tick at `timestamp`. + "numericValue": [timestamp] if _is_finite_number(timestamp) else None, + "stringValue": _clean(tt.get("method_name")), + # transfer docs use recipient_id / donor_id pair instead + # of subject_id. Recipient is the subject affected. + "subjectDocumentIdentifier": _depends_on_value_by_name(d, "recipient_id"), + } + + # treatment (canonical) — legacy path. + tdata = data.get("treatment") or {} return { "treatmentName": _clean(tdata.get("name")) or _project_name(d), "treatmentOntology": _clean(tdata.get("ontologyName")), @@ -1211,3 +1557,406 @@ def _row_treatment(d: dict[str, Any]) -> dict[str, Any]: "stringValue": _clean(tdata.get("string_value")), "subjectDocumentIdentifier": _depends_on_value_by_name(d, "subject_id"), } + + +def _parse_mixture_table_name(s: Any) -> str | None: + """Extract a treatment name from a ``treatment_drug.mixture_table`` + CSV string. + + Shape: ``"ontologyName,name\\nNCBITaxon:637912,Eschericia coli OP50\\n"`` + The first line is the CSV header; the second line carries the + actual name in the second column. Multi-row mixtures join names + with ' + ' so the Gantt label reflects the mixture. + + Returns None for empty/malformed input — caller falls back to + `_project_name(d)`. + """ + if not isinstance(s, str) or not s.strip(): + return None + lines = [ln for ln in s.split("\n") if ln.strip()] + if len(lines) < 2: + return None + names: list[str] = [] + for ln in lines[1:]: + parts = [p.strip() for p in ln.split(",")] + if len(parts) >= 2 and parts[1]: + names.append(parts[1]) + if not names: + return None + return " + ".join(names) + + +def _coerce_timing_pair(onset: Any, offset: Any) -> Any: + """Return a 2-element list [onset_seconds, offset_seconds] when the + pair represents finite numbers OR ISO-8601 duration strings; None + otherwise. + + treatment_drug carries timing in mixed shapes: + - Both numeric (e.g., onset=-21600, offset=0) → [-21600.0, 0.0] + - Strings in ``HH:MM:SS`` or ``-HH:MM:SS`` form + (e.g., "-06:00:00") → seconds since reference + - One or both missing/empty → None (downstream falls back to + ordinal timing) + """ + def _to_seconds(v: Any) -> float | None: + if _is_finite_number(v): + return float(v) # type: ignore[arg-type] + if isinstance(v, str) and v: + # Best-effort HH:MM:SS parse (allow leading sign for negative + # relative times: "-06:00:00" = 6 hours before reference). + sign = 1.0 + s = v.strip() + if s.startswith("-"): + sign = -1.0 + s = s[1:] + elif s.startswith("+"): + s = s[1:] + parts = s.split(":") + try: + if len(parts) == 3: + h, m, sec = (float(parts[0]), float(parts[1]), float(parts[2])) + return sign * (h * 3600 + m * 60 + sec) + if len(parts) == 2: + m, sec = (float(parts[0]), float(parts[1])) + return sign * (m * 60 + sec) + if len(parts) == 1: + return sign * float(parts[0]) + except ValueError: + return None + return None + + a = _to_seconds(onset) + b = _to_seconds(offset) + if a is None or b is None: + return None + return [a, b] + + +def _is_finite_number(v: Any) -> bool: + """Identical to the helper in treatment_timeline_service — duplicated + here so the projection can stay self-contained without importing + across services. + """ + if isinstance(v, bool): + return False + if isinstance(v, (int, float)): + import math + return math.isfinite(float(v)) + return False + + +def _row_stimulus(d: dict[str, Any]) -> dict[str, Any]: + """F-1 (2026-05-19) — project a stimulus_presentation doc to the + StimuliPicker row shape. + + Field sources: + - ``stimulusName`` → ``data.stimulus_presentation.name`` + OR ``data.base.name`` (fallback) + - ``elementDocumentIdentifier`` → ``depends_on.value`` where + name == ``"element_id"`` OR + ``"stimulus_element_id"`` + - ``presentationCount`` → ``len(data.stimulus_presentation.presentations)`` + - ``firstPresentationTime`` → presentations[0].time_started + - ``lastPresentationTime`` → presentations[-1].time_started + + All temporal fields default to ``None`` when the doc shape doesn't + carry them — auto-hide-empty-column UX hides the absent columns. + """ + data = d.get("data") or {} + sp = data.get("stimulus_presentation") or {} + presentations = sp.get("presentations") or [] + pres_list = presentations if isinstance(presentations, list) else [] + + first_time: Any = None + last_time: Any = None + if pres_list: + first_entry = pres_list[0] if isinstance(pres_list[0], dict) else None + last_entry = pres_list[-1] if isinstance(pres_list[-1], dict) else None + if first_entry is not None: + first_time = first_entry.get("time_started") + if last_entry is not None: + last_time = last_entry.get("time_started") + + # depends_on can carry either `element_id` (most common) or + # `stimulus_element_id` (some older datasets). Try both. + element_ref = ( + _depends_on_value_by_name(d, "element_id") + or _depends_on_value_by_name(d, "stimulus_element_id") + ) + + return { + "stimulusDocumentIdentifier": _ndi_id(d) or _first(d, "base.id"), + "stimulusName": ( + _clean(sp.get("name")) + or _project_name(d) + or _clean(_first(d, "base.name")) + ), + "elementDocumentIdentifier": element_ref, + "presentationCount": len(pres_list), + "firstPresentationTime": first_time, + "lastPresentationTime": last_time, + } + + +# --------------------------------------------------------------------------- +# F-1b — broadcast per-subject treatment columns onto subject summary +# (2026-05-19). Ports the cloud-app's frontend joinTreatmentsToSubjects +# pivot from +# `apps/web/app/(app)/datasets/[id]/tables/[className]/table-shell.tsx` +# into the backend so the workspace SubjectsBrowser receives the same +# dynamic columns. Heart-on-Railway, ADR-001. +# --------------------------------------------------------------------------- + +def _pascal_case_from_treatment_name(s: Any) -> str | None: + """Convert a treatment name into a PascalCase prefix for dynamic + column keys. + + Mirror of the cloud-app's ``pascalCaseFromTreatmentName`` helper at + table-shell.tsx lines 768-779. Whitespace is collapsed, then each + word is upper-cased on the first letter. Non-alphanumeric + characters are stripped per word (these are not expected in + canonical treatment names; including them would produce illegal + column-key characters that break header rendering downstream). + + Empty / whitespace-only / null / non-string input returns ``None`` — + caller must skip the row to keep parity with the JS behaviour. + + Examples: + "Optogenetic Tetanus Stimulation Target Location" + -> "OptogeneticTetanusStimulationTargetLocation" + "Foo Bar Baz Quux" -> "FooBarBazQuux" + "with-hyphens here" -> "WithhyphensHere" (hyphen stripped per word) + "" -> None + None -> None + 42 -> None (non-string) + """ + if not isinstance(s, str): + return None + trimmed = s.strip() + if not trimmed: + return None + parts: list[str] = [] + for word in trimmed.split(): + clean = "".join(ch for ch in word if ch.isalnum()) + if not clean: + continue + parts.append(clean[0].upper() + clean[1:]) + joined = "".join(parts) + return joined or None + + +def _treatment_broadcast_value(t_row: dict[str, Any]) -> Any: + """Pick the broadcast cell value for a treatment row. + + Mirrors the JS: + const value = (typeof stringVal === 'string' && stringVal) + || (typeof stringVal === 'number' ? stringVal : null) + || (typeof numericVal === 'number' ? numericVal : null) + || (Array.isArray(numericVal) && numericVal.length > 0 + ? numericVal : null); + + Priority: + 1. ``stringValue`` if a non-empty string (e.g. ``"UBERON:0001930"`` + for a Location-typed treatment). + 2. ``stringValue`` if a number (rare; defensive parity). + 3. ``numericValue`` if a number (dose / duration). + 4. ``numericValue`` if a non-empty list (timing pair from + treatment_drug etc.). + + Returns ``None`` when nothing matches — the cell stays empty. + + Bools are explicitly excluded since Python's ``isinstance(x, int)`` + is True for booleans and would otherwise match the numeric branch. + """ + string_val = t_row.get("stringValue") + if isinstance(string_val, str) and string_val: + return string_val + if isinstance(string_val, (int, float)) and not isinstance(string_val, bool): + return string_val + numeric_val = t_row.get("numericValue") + if isinstance(numeric_val, (int, float)) and not isinstance(numeric_val, bool): + return numeric_val + if isinstance(numeric_val, list) and len(numeric_val) > 0: + return numeric_val + return None + + +def _broadcast_treatments_onto_subjects( # noqa: PLR0912 — direct port of the cloud-app's pure JS pivot; branch count reflects per-row guards and the two-pass injection. Splitting would obscure the parity contract with table-shell.tsx joinTreatmentsToSubjects. + subject_rows: list[dict[str, Any]], + subject_columns: list[dict[str, Any]], + treatment_rows: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Inject per-subject treatment columns onto the subject summary table. + + Pure function. Replaces the cloud-app's ``joinTreatmentsToSubjects`` + pivot at table-shell.tsx lines 821-923. Tested directly via the + unit-test suite. + + Behaviour: + - Treatment rows whose ``subjectDocumentIdentifier`` is missing or + non-string are skipped. + - Treatment rows whose ``treatmentName`` does not yield a legal + PascalCase key (empty / null / non-string after normalisation) + are skipped silently. The user still sees that treatment in + the dedicated Treatments tab. + - Per treatment kind two column keys are emitted: + ``<prefix>Name`` and ``<prefix>Ontology``. Labels are + ``"{treatmentName} Name"`` and ``"{treatmentName} Ontology"``. + - Multiple treatments of the same kind on the same subject + collect into an array (csvJoinFormatter then renders + ``"a, b, c"``). + - Subjects with no matching treatment of a given kind get + ``None`` cells (NOT broadcast) — important so the cloud-app's + discoverDynamicColumns sees the column on every row. + - When NO dynamic columns are discovered the originals are + returned unchanged (avoid needlessly mutating column identity). + - Columns already present in ``subject_columns`` (defensive — the + backend doesn't emit treatment columns directly today) are + skipped from the appended list. + + Returns + ------- + (rows, columns) + A new rows list (shallow copy of each row with the dynamic + cells injected) and a new columns list (original + discovered). + Both inputs are NOT mutated. + """ + # Outer key = subjectDocumentIdentifier; inner = column key; + # value = collected array across multiple matching treatments. + by_subject: dict[str, dict[str, list[Any]]] = {} + # Track every dynamic column key discovered, with its label, in + # insertion order — first treatmentName wins for label rendering. + discovered_keys: dict[str, str] = {} + + for t_row in treatment_rows: + if not isinstance(t_row, dict): + continue + subject_id = t_row.get("subjectDocumentIdentifier") + if not isinstance(subject_id, str) or not subject_id: + continue + treatment_name = t_row.get("treatmentName") + prefix = _pascal_case_from_treatment_name(treatment_name) + if not prefix: + continue + + name_key = f"{prefix}Name" + ontology_key = f"{prefix}Ontology" + name_label = ( + f"{treatment_name} Name" + if isinstance(treatment_name, str) + else name_key + ) + ontology_label = ( + f"{treatment_name} Ontology" + if isinstance(treatment_name, str) + else ontology_key + ) + discovered_keys.setdefault(name_key, name_label) + discovered_keys.setdefault(ontology_key, ontology_label) + + per_subject = by_subject.setdefault(subject_id, {}) + value = _treatment_broadcast_value(t_row) + if value is not None: + per_subject.setdefault(name_key, []).append(value) + ontology = t_row.get("treatmentOntology") + if isinstance(ontology, str) and ontology: + per_subject.setdefault(ontology_key, []).append(ontology) + + # No discovered dynamic columns → return originals unchanged (column + # object identity matters for the cloud-app's column-toggle picker). + if not discovered_keys: + return subject_rows, subject_columns + + new_rows: list[dict[str, Any]] = [] + for row in subject_rows: + subject_id = row.get("subjectDocumentIdentifier") + per_subject_for_row: dict[str, list[Any]] | None = ( + by_subject.get(subject_id) if isinstance(subject_id, str) else None + ) + out: dict[str, Any] = dict(row) + for key in discovered_keys: + collected = ( + per_subject_for_row.get(key) + if per_subject_for_row is not None + else None + ) + if not collected: + out[key] = None + elif len(collected) == 1: + out[key] = collected[0] + else: + out[key] = collected + new_rows.append(out) + + existing_keys = {c.get("key") for c in subject_columns} + new_columns: list[dict[str, Any]] = list(subject_columns) + for key, label in discovered_keys.items(): + if key in existing_keys: + continue + new_columns.append({"key": key, "label": label}) + return new_rows, new_columns + + +# --------------------------------------------------------------------------- +# distinct_summary — per-column cardinality + top values +# --------------------------------------------------------------------------- + +def _hashable(value: Any) -> Any: + """Convert a cell value into something usable as a dict key. + + Most projected cells are scalars (str, int, float, bool, None). A few + columns (e.g. element_epoch.epochStart) project as small dicts like + `{"devTime": 0, "globalTime": None}`. Those still need to be countable, + so we JSON-serialize the unhashables into stable string keys. + """ + if isinstance(value, (str, int, float, bool)) or value is None: + return value + # dict / list / tuple — stringify deterministically. + try: + import json as _json + return _json.dumps(value, sort_keys=True, default=str) + except (TypeError, ValueError): + return repr(value) + + +def _build_distinct_summary( + columns: list[dict[str, str]], + rows: list[dict[str, Any]], +) -> dict[str, Any]: + """For each column compute {distinct_count, top_values} across ALL rows. + + `top_values` is a list of `{value, count}` ordered by descending count, + capped at `DISTINCT_SUMMARY_TOP_K`. Empty/None cells are tallied as + `None` so the LLM can still see "5/49 rows had a null treatmentName". + + Returns `{"_meta": "skipped due to large row count"}` shape when there + are more than `DISTINCT_SUMMARY_MAX_ROWS` rows — the table is too big + to scan affordably and the LLM doesn't need per-cell stats at that + point (it should pivot to ndi_query or get_facets). + + The function is pure: no I/O, no mutation. Folded into the cached + table response so a Redis hit returns it for free. + """ + if len(rows) > DISTINCT_SUMMARY_MAX_ROWS: + return {"_meta": "skipped due to large row count"} + + out: dict[str, Any] = {} + for col in columns: + key = col.get("key") + if not isinstance(key, str): + continue + counts: dict[Any, int] = {} + for row in rows: + cell = row.get(key) + counts[_hashable(cell)] = counts.get(_hashable(cell), 0) + 1 + # Sort by count desc, then by string-form of the value for deterministic + # tie-break (matters for test snapshots and cache key stability). + ordered = sorted( + counts.items(), key=lambda kv: (-kv[1], str(kv[0])), + ) + top = [{"value": v, "count": c} for v, c in ordered[:DISTINCT_SUMMARY_TOP_K]] + out[key] = { + "distinct_count": len(counts), + "top_values": top, + } + return out diff --git a/backend/services/tabular_query_service.py b/backend/services/tabular_query_service.py new file mode 100644 index 0000000..042457d --- /dev/null +++ b/backend/services/tabular_query_service.py @@ -0,0 +1,1292 @@ +"""tabular_query_service — aggregate ``ontologyTableRow`` documents +into per-group statistics + raw values for violin/jitter rendering. + +Used by the chat's ``tabular_query`` tool. The chat passes a substring +match against an ``ontologyTableRow`` column name (e.g. +``"ElevatedPlusMaze_OpenArmNorth_Entries"``) plus an optional +grouping column key (e.g. ``"treatment_group"``). The service: + +1. Calls :meth:`SummaryTableService.ontology_tables` which projects + the dataset's ``ontologyTableRow`` docs into one group per + distinct ``variableNames`` schema +2. Finds the first group containing a column whose key/label matches + the substring; that column is the value column +3. If ``groupBy`` is given, finds the column with that key inside the + same group; that's the grouping column +4. Iterates rows (each row is a dict keyed by column key), bucketing + numeric values by group label +5. Computes per-group stats (mean, median, std, min/max, q1/q3, + count) plus the raw values (capped + stride-sampled) for the + violin's jitter overlay +6. Returns the response shape :class:`ViolinChart` consumes + +Notable: this service does NOT call NDI-python — it operates on the +already-decoded ``ontologyTableRow`` shape that +``SummaryTableService`` projects from cloud-node. NDI-python becomes +valuable on the binary/decoding side, not the tabular-aggregation +side. Keeping this service pure-Python (statistics module only) keeps +it fast + side-effect-free. +""" +from __future__ import annotations + +import math +import statistics +from typing import Any, Literal + +from ..auth.session import SessionData +from ..observability.logging import get_logger +from .summary_table_service import SummaryTableService + +log = get_logger(__name__) + + +# Bound the response size — a violin with 100 groups isn't a chart, +# it's a wall of text. The chat tool's `groupOrder` parameter is the +# right escape hatch when callers really want a curated subset. +MAX_GROUPS = 20 + +# Per-group raw-value cap. Plotly's violin trace can comfortably +# render ~500 jitter points per group before the chart slows down on +# resize. Beyond that we stride-sample. The summary stats are computed +# on the FULL value list before sampling, so they remain accurate. +MAX_VALUES_PER_GROUP = 500 + +# Sample-row docId cap per group. The frontend builds one +# click-through citation chip per docId (e.g. "Sample Saline row"), +# so 3 per group keeps the chip count manageable on charts with many +# groups while still letting the user verify each group's data. The +# full set of contributing rows is reachable from the table-view +# citation (the primary chip). +MAX_DOC_IDS_PER_GROUP = 3 + +# --------------------------------------------------------------------------- +# S5.3 cross-table constants (2026-05-18). +# +# The cross-table sister-tool to ``violin_groups`` pairs two measurement +# columns per subject (or per treatment) for a scatter / strip-plot +# rendering. These constants keep the helper functions below + the +# orchestrator parameterized in one place. +# --------------------------------------------------------------------------- + +# Max pairs returned per cross-table response. Same scale as +# `MAX_VALUES_PER_GROUP * MAX_GROUPS` so a strip plot with the maximum +# allowed group count + per-group cap stays under this bound. Above +# this, the front-end's `ScatterChart.tsx` slows during pan/zoom. +MAX_PAIRS = 1000 + +# Treatment-join walks these classes in order (canonical → mixture +# variant → transfer variant). First-match-wins per subject; the +# downstream label-picker tolerates any of the three shapes since +# `summary_table_service::_row_treatment` projects all three onto the +# same TREATMENT_COLUMNS keys. +_TREATMENT_CLASS_CHAIN: tuple[str, ...] = ( + "treatment", + "treatment_drug", + "treatment_transfer", +) + +# Universal subject-id key on every projected row across single_class +# projections (subjects, probes, epochs, treatments, ontologyTableRow). +# Hardcoded rather than referenced because reflecting back through +# summary_table_service would create a circular import. +_SUBJECT_KEY = "subjectDocumentIdentifier" + +# Treatment-row label fields in priority order. When the caller passes +# `yVariableContains` (the "treatment field" needle, e.g. "name" or +# "reference"), the label picker tries (1) any row key that contains +# the needle, then (2) row keys that contain any of these field names. +# The list is biased toward the projected TREATMENT_COLUMNS shape but +# is forgiving of legacy / raw-cloud-shape rows (some older datasets +# emit `treatment.reference` / `mixture_table` directly). +_TREATMENT_LABEL_FIELDS: tuple[str, ...] = ( + "name", + "reference", + "treatment_reference", + "mixture", + "mixtureName", + "drugName", + "drug", +) + + +class TabularQueryService: + """Aggregate ontologyTableRow docs into per-group stats.""" + + def __init__(self, summary: SummaryTableService) -> None: + self.summary = summary + + async def violin_groups( # noqa: PLR0911 (linear-control-flow with early-return per failure mode is clearer than a state machine) + self, + dataset_id: str, + variable_name_contains: str, + *, + group_by: str | None, + group_order: list[str] | None, + session: SessionData | None, + ) -> dict[str, Any]: + """Return ``{groups: [...], yLabel, xLabel, source?}``. + + Each group has the shape consumed by + ``apps/web/components/charts/ViolinChart.tsx``:: + + {name, values, count, mean, median, std, min, max, q1, q3} + """ + if not variable_name_contains: + return _empty_response(group_by, reason="empty variableNameContains") + + ontology = await self.summary.ontology_tables(dataset_id, session=session) + groups = ontology.get("groups", []) + if not groups: + return _empty_response( + group_by, reason="no ontologyTableRow docs in dataset", + ) + + match = _find_matching_group(groups, variable_name_contains) + if match is None: + return _empty_response( + group_by, + reason=f"no ontologyTableRow column matched '{variable_name_contains}'", + available={"variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ]}, + ) + + group, value_col, value_label = match + rows = (group.get("table") or {}).get("rows") or [] + if not rows: + return _empty_response( + group_by, + reason="matched group had no rows", + yLabel=value_label, + ) + + # Resolve the groupBy column. Like the value column, callers + # rarely know the exact column key — substring-match against the + # group's columns (key OR label, case-insensitive). When the user + # leaves group_by unset, this returns None and the bucketing + # produces a single 'all' group. + resolved_group_col = ( + _resolve_group_column(group, group_by) if group_by else None + ) + if group_by and resolved_group_col is None: + return _empty_response( + group_by, + reason=f"no column matched groupBy '{group_by}' in the " + f"selected table", + yLabel=value_label, + available={"columns": [ + c.get("key") + for c in (group.get("table") or {}).get("columns") or [] + if c.get("key") != value_col + ][:20]}, + ) + + # docIds is parallel to rows (same index) per + # SummaryTableService.ontology_tables contract. + parallel_doc_ids = group.get("docIds") or [] + buckets, bucket_doc_ids, order_seen = _bucket_rows( + rows, parallel_doc_ids, value_col, resolved_group_col, + ) + if not buckets: + return _empty_response( + group_by, + reason="no numeric values in matched column", + yLabel=value_label, + ) + + ordered_keys = _ordered_group_keys(buckets, order_seen, group_order) + out_groups = _build_group_payloads( + buckets, bucket_doc_ids, ordered_keys, + ) + + result: dict[str, Any] = { + "groups": out_groups, + "yLabel": value_label, + "xLabel": group_by or "group", + } + # `source` is preserved for backwards compat — the per-group + # `docIds` arrays on each entry of `groups` are the granular + # truth. A consumer that only wants a single representative doc + # still has `source.document_id`; consumers that want per-group + # sample-row drill-downs read `groups[i].docIds`. + if parallel_doc_ids: + result["source"] = { + "dataset_id": dataset_id, + "document_id": parallel_doc_ids[0], + "variable_name": value_label, + } + return result + + # ----------------------------------------------------------------------- + # S5.3 cross-table joins — sister method to ``violin_groups`` that + # pairs two measurement columns per subject (or one measurement + # column with the subject's treatment label) for a scatter / strip- + # plot rendering. + # ----------------------------------------------------------------------- + + async def cross_table_pairs( + self, + dataset_id: str, + x_variable_contains: str, + y_variable_contains: str, + *, + join_on: Literal["subject", "treatment"], + group_by: str | None, + group_order: list[str] | None, + session: SessionData | None, + ) -> dict[str, Any]: + """Pair two ontologyTableRow columns (subject-join) or pair one + column with the subject's treatment label (treatment-join). + + Returns a dict matching the response contract: + ``{pairs, xLabel, yLabel, groupLabel, joinKind, unjoined, + source?, _meta?}`` — see ``apps/web/lib/ndi/tools/cross-table-query.ts`` + for the consumer side. + + Cap: response truncated at ``MAX_PAIRS`` pairs (1000); a + ``_meta.reason`` diagnostic is added when the cap fires. + """ + if not x_variable_contains or not y_variable_contains: + return _empty_pairs_response( + join_on, + reason="empty xVariableContains or yVariableContains", + ) + + if join_on == "subject": + return await self._cross_table_pairs_subject( + dataset_id, + x_variable_contains, + y_variable_contains, + group_by=group_by, + group_order=group_order, + session=session, + ) + # The treatment join is the only remaining branch — the + # `join_on: Literal["subject", "treatment"]` annotation locks + # the input to those two values at the route boundary. + return await self._cross_table_pairs_treatment( + dataset_id, + x_variable_contains, + y_variable_contains, + group_by=group_by, + group_order=group_order, + session=session, + ) + + async def _cross_table_pairs_subject( # noqa: PLR0912 — linear empty-state early-returns + groupBy resolution branches are clearer than a state machine + self, + dataset_id: str, + x_variable_contains: str, + y_variable_contains: str, + *, + group_by: str | None, + group_order: list[str] | None, + session: SessionData | None, + ) -> dict[str, Any]: + ontology = await self.summary.ontology_tables(dataset_id, session=session) + groups = ontology.get("groups", []) + if not groups: + return _empty_pairs_response( + "subject", + reason="no ontologyTableRow docs in dataset", + ) + + x_match = _find_matching_group(groups, x_variable_contains) + if x_match is None: + return _empty_pairs_response( + "subject", + reason=( + f"no ontologyTableRow column matched X " + f"'{x_variable_contains}'" + ), + available={ + "variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ], + }, + ) + x_group, x_col, x_label = x_match + + # Y must come from a DIFFERENT ontologyTableRow group than X — + # otherwise we'd be pairing two columns of the SAME table per + # subject, which is what ``violin_groups`` does, not a cross + # table join. + y_match = _find_matching_group( + groups, + y_variable_contains, + exclude_group_idx=_index_of_group(groups, x_group), + ) + if y_match is None: + return _empty_pairs_response( + "subject", + reason=( + f"no DIFFERENT ontologyTableRow column matched Y " + f"'{y_variable_contains}' (X matched " + f"'{x_col}' in group '{x_label}')" + ), + xLabel=x_label, + available={ + "variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ], + }, + ) + y_group, y_col, y_label = y_match + + x_rows = (x_group.get("table") or {}).get("rows") or [] + x_doc_ids = x_group.get("docIds") or [] + y_rows = (y_group.get("table") or {}).get("rows") or [] + y_doc_ids = y_group.get("docIds") or [] + + x_map = _build_subject_value_map( + x_rows, x_doc_ids, x_col, numeric=True, + ) + y_map = _build_subject_value_map( + y_rows, y_doc_ids, y_col, numeric=True, + ) + + # GroupBy resolution: try X first, then Y. The chat / panel + # tool documents this as "may live in EITHER table; backend + # searches group_x first, then group_y." + resolved_group_col: str | None = None + group_source_rows: list[dict[str, Any]] | None = None + group_label: str | None = None + if group_by: + x_resolved = _resolve_group_column(x_group, group_by) + if x_resolved is not None: + resolved_group_col = x_resolved + group_source_rows = x_rows + else: + y_resolved = _resolve_group_column(y_group, group_by) + if y_resolved is not None: + resolved_group_col = y_resolved + group_source_rows = y_rows + if resolved_group_col is None: + return _empty_pairs_response( + "subject", + reason=( + f"no column matched groupBy '{group_by}' in " + f"either matched table" + ), + xLabel=x_label, + yLabel=y_label, + available={ + "columns": _columns_for_pair_group_by( + x_group, y_group, x_col, y_col, + ), + }, + ) + group_label = resolved_group_col + + subject_to_group: dict[str, str] | None = None + if resolved_group_col is not None and group_source_rows is not None: + subject_to_group = _build_subject_group_map( + group_source_rows, resolved_group_col, + ) + + pairs, unjoined = _inner_join_pairs( + x_map, y_map, subject_to_group=subject_to_group, + ) + pairs = _order_pairs_by_group(pairs, group_order) + + truncated = len(pairs) > MAX_PAIRS + if truncated: + pairs = pairs[:MAX_PAIRS] + + result: dict[str, Any] = { + "pairs": pairs, + "xLabel": x_label, + "yLabel": y_label, + "groupLabel": group_label, + "joinKind": "subject", + "unjoined": unjoined, + } + if x_doc_ids: + first = x_doc_ids[0] + if isinstance(first, str) and first: + result["source"] = { + "dataset_id": dataset_id, + "document_id": first, + "x_variable_name": x_col, + "y_variable_name": y_col, + } + if not pairs: + result["_meta"] = { + "reason": ( + "X or Y group had no usable numeric values" + if not x_map or not y_map + else "no overlapping subjects between X and Y groups" + ), + "columns": _columns_for_pair_group_by( + x_group, y_group, x_col, y_col, + ), + "variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ], + } + elif truncated: + result["_meta"] = { + "reason": f"pair count truncated to MAX_PAIRS={MAX_PAIRS}", + } + return result + + async def _cross_table_pairs_treatment( # linear empty-state early-returns are clearer than a state machine + self, + dataset_id: str, + x_variable_contains: str, + y_variable_contains: str, + *, + group_by: str | None, + group_order: list[str] | None, + session: SessionData | None, + ) -> dict[str, Any]: + ontology = await self.summary.ontology_tables(dataset_id, session=session) + groups = ontology.get("groups", []) + if not groups: + return _empty_pairs_response( + "treatment", + reason="no ontologyTableRow docs in dataset", + ) + + x_match = _find_matching_group(groups, x_variable_contains) + if x_match is None: + return _empty_pairs_response( + "treatment", + reason=( + f"no ontologyTableRow column matched X " + f"'{x_variable_contains}'" + ), + available={ + "variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ], + }, + ) + x_group, x_col, x_label = x_match + + x_rows = (x_group.get("table") or {}).get("rows") or [] + x_doc_ids = x_group.get("docIds") or [] + x_map = _build_subject_value_map( + x_rows, x_doc_ids, x_col, numeric=True, + ) + + # Y-side: walk the treatment class chain, pick a per-subject + # label using ``_pick_treatment_label_for_needle``. Last-write- + # wins per subject across the chain so a later (treatment_drug) + # row supersedes an earlier (treatment) row when both are + # present — matches the F-1e Bhar shape. + treatment_map = await self._build_treatment_subject_map( + dataset_id, y_variable_contains, session=session, + ) + if not treatment_map: + return _empty_pairs_response( + "treatment", + reason=( + f"no treatment docs matched Y '{y_variable_contains}' " + f"across [treatment, treatment_drug, treatment_transfer]" + ), + xLabel=x_label, + yLabel=y_variable_contains, + ) + + # GroupBy: resolve against X table only — Y is a label, not a + # column. Cross-treatment grouping isn't meaningful here. + resolved_group_col: str | None = None + group_label: str | None = None + if group_by: + resolved_group_col = _resolve_group_column(x_group, group_by) + if resolved_group_col is None: + return _empty_pairs_response( + "treatment", + reason=( + f"no column matched groupBy '{group_by}' in X " + f"table" + ), + xLabel=x_label, + yLabel=y_variable_contains, + available={ + "columns": _columns_for_pair_group_by( + x_group, None, x_col, "", + ), + }, + ) + group_label = resolved_group_col + + subject_to_group: dict[str, str] | None = None + if resolved_group_col is not None: + subject_to_group = _build_subject_group_map( + x_rows, resolved_group_col, + ) + + pairs, unjoined = _inner_join_treatment_pairs( + x_map, treatment_map, subject_to_group=subject_to_group, + ) + pairs = _order_pairs_by_group(pairs, group_order) + + truncated = len(pairs) > MAX_PAIRS + if truncated: + pairs = pairs[:MAX_PAIRS] + + # The "groupLabel" reported back is either the explicit + # group_by column (when set) OR the literal "treatment" sentinel + # so the chart legend reads sensibly even when pairs are + # auto-colored by their treatment label. + effective_group_label = ( + group_label if group_label is not None else "treatment" + ) + + result: dict[str, Any] = { + "pairs": pairs, + "xLabel": x_label, + "yLabel": y_variable_contains or "treatment", + "groupLabel": effective_group_label, + "joinKind": "treatment", + "unjoined": unjoined, + } + if x_doc_ids: + first = x_doc_ids[0] + if isinstance(first, str) and first: + result["source"] = { + "dataset_id": dataset_id, + "document_id": first, + "x_variable_name": x_col, + "y_variable_name": y_variable_contains, + } + if not pairs: + result["_meta"] = { + "reason": ( + "X group had no usable numeric values" + if not x_map + else "no overlapping subjects between X and treatment" + ), + "columns": _columns_for_pair_group_by( + x_group, None, x_col, "", + ), + "variable_names": [ + " | ".join(g.get("variableNames", []))[:120] + for g in groups[:5] + ], + } + elif truncated: + result["_meta"] = { + "reason": f"pair count truncated to MAX_PAIRS={MAX_PAIRS}", + } + return result + + async def _build_treatment_subject_map( + self, + dataset_id: str, + y_variable_contains: str, + *, + session: SessionData | None, + ) -> dict[str, tuple[str, str | None]]: + """Build subject_id → (treatment_label, doc_id_or_None) by + walking ``_TREATMENT_CLASS_CHAIN``. + + Calls ``self.summary.single_class(...)`` for each class + (wrapping in try/except so a dead class is skipped, not fatal), + then picks per-subject labels via + ``_pick_treatment_label_for_needle``. Last-write-wins per + subject so a later class in the chain (e.g. treatment_drug) + supersedes an earlier match. + + Note: ``single_class`` rows don't carry per-row docIds in their + envelope (only ``columns`` + ``rows`` + ``distinct_summary``), + so doc_id is currently ``None`` for treatment-join pairs. The + scatter chart's primary citation still resolves via the X-side + (ontologyTableRow) docIds — the Y-side gap is acceptable. + """ + needle_lower = (y_variable_contains or "").lower() + out: dict[str, tuple[str, str | None]] = {} + for class_name in _TREATMENT_CLASS_CHAIN: + try: + envelope = await self.summary.single_class( + dataset_id, class_name, session=session, + ) + except Exception as e: # best-effort: any class can be missing on a given dataset, log and continue + log.warning( + "cross_table.treatment_class_fetch_failed", + dataset_id=dataset_id, + class_name=class_name, + error=str(e), + ) + continue + rows = envelope.get("rows") or [] + for row in rows: + sid = row.get(_SUBJECT_KEY) + if not isinstance(sid, str) or not sid: + continue + label = _pick_treatment_label_for_needle(row, needle_lower) + if not label: + continue + # Last-write-wins per subject for determinism. The chain + # is ordered with the most-canonical class (treatment) + # first; later passes (treatment_drug / treatment_transfer) + # are usually subclass-specific so the override is + # intentional when a more-specific label exists. + out[sid] = (label, None) + return out + + +# --------------------------------------------------------------------------- +# Internal helpers — each is single-purpose so the orchestrator stays linear. +# --------------------------------------------------------------------------- + + +def _empty_response( + group_by: str | None, + *, + reason: str, + yLabel: str = "", + available: dict[str, Any] | None = None, +) -> dict[str, Any]: + meta: dict[str, Any] = {"reason": reason} + if available: + meta.update(available) + return { + "groups": [], + "yLabel": yLabel, + "xLabel": group_by or "", + "_meta": meta, + } + + +def _alphanumeric_lower(s: str) -> str: + """Lowercase + strip non-alphanumerics for fuzzy substring matching. + + Stream 5.1 (2026-05-15): real column keys in ontologyTableRow tables + use underscores and CamelCase intermixed + (``ElevatedPlusMaze_OpenArmNorth_Entries``), while users / the chat + sometimes type contiguous CamelCase (``OpenArmNorthEntries``). A + direct case-insensitive substring match misses these because the + underscore breaks contiguity. Normalizing BOTH sides to + alphanumeric-only lowercase makes the comparison whitespace- and + punctuation-insensitive without changing the contiguity check. + """ + return "".join(ch for ch in s.lower() if ch.isalnum()) + + +def _find_matching_group( + groups: list[dict[str, Any]], + needle: str, + *, + exclude_group_idx: int | None = None, +) -> tuple[dict[str, Any], str, str] | None: + """Locate the best ontologyTableRow column matching the search + substring, preferring columns whose values are numeric. + + Real ontologyTableRow tables typically have multiple columns whose + names share the same topic prefix (e.g. ``ElevatedPlusMaze: Test + Identifier`` + ``ElevatedPlusMaze: Open Arm Entries`` + …). A naive + first-match would pick the identifier column → no numeric values → + empty violin. We instead score each matching column by how many + rows have finite-numeric values in it, and return the highest- + scoring column across all groups. + + Match strategy is two-pass (Stream 5.1, 2026-05-15): + 1. Direct case-insensitive substring match on key OR label + (precise; preserves existing semantics). + 2. Alphanumeric-stripped fallback when pass 1 returns no + numeric-column hit. Catches the `OpenArmNorthEntries` ↔ + `ElevatedPlusMaze_OpenArmNorth_Entries` mismatch. + + Ties broken by first-seen order (group order is already sorted by + row count desc in SummaryTableService). + + S5.3 (2026-05-18) — ``exclude_group_idx`` lets the cross-table + caller force the Y-side match into a DIFFERENT group than the X + side. ``None`` (the violin caller's default) preserves the + original semantics: every group considered. When set to an index + that exists, the corresponding group is skipped entirely (no + columns scored). Negative indices and indices ≥ len(groups) are + no-ops by design — they signal "the X group wasn't found in the + list" so we just behave like the unconstrained search. + """ + needle_lower = needle.lower() + needle_alnum = _alphanumeric_lower(needle) + # Pass 1 + Pass 2 share the loop; we capture the best precise hit + # first, then the best fuzzy hit. Precise wins on equal numeric + # counts. + best_precise: tuple[dict[str, Any], str, str, int] | None = None + best_fuzzy: tuple[dict[str, Any], str, str, int] | None = None + for idx, g in enumerate(groups): + if exclude_group_idx is not None and idx == exclude_group_idx: + continue + table = g.get("table") or {} + cols = table.get("columns") or [] + rows = table.get("rows") or [] + for col in cols: + key = str(col.get("key", "")) + label = str(col.get("label", "")) + key_lower = key.lower() + label_lower = label.lower() + is_precise = ( + needle_lower in key_lower or needle_lower in label_lower + ) + # Skip fuzzy work when precise already matches (saves the + # alnum compute on huge tables). + is_fuzzy = is_precise or ( + needle_alnum + and ( + needle_alnum in _alphanumeric_lower(key) + or needle_alnum in _alphanumeric_lower(label) + ) + ) + if not is_fuzzy: + continue + numeric_count = sum(1 for row in rows if _is_finite_numeric(row.get(key))) + if numeric_count == 0: + continue + tuple_ = (g, key, label or key, numeric_count) + if is_precise: + if best_precise is None or numeric_count > best_precise[3]: + best_precise = tuple_ + elif best_fuzzy is None or numeric_count > best_fuzzy[3]: + best_fuzzy = tuple_ + best = best_precise if best_precise is not None else best_fuzzy + if best is None: + return None + return best[0], best[1], best[2] + + +def _resolve_group_column( # noqa: PLR0911 — linear three-pass match is clearer than one collapsed branch + group: dict[str, Any], + group_by: str, +) -> str | None: + """Resolve a possibly-imprecise group_by argument to an actual + column key in the matched group. + + Three-pass resolution (Stream 5.1 expanded 2026-05-15): + 1. Exact key match (literal column-key argument from the user). + 2. Case-insensitive substring match on key, then on label + (preserves precision for column-name fragments). + 3. Alphanumeric-stripped substring match on key, then label — + catches the `Treatment_CNOOrSaline` ↔ `CNOorSaline` shape + where users mix underscore + CamelCase variants. + + Returns None when nothing matches so the caller can surface an + explicit error with the available column list. + """ + needle_lower = group_by.lower() + needle_alnum = _alphanumeric_lower(group_by) + cols = (group.get("table") or {}).get("columns") or [] + # Pass 1: exact key match wins immediately. + for col in cols: + if str(col.get("key", "")) == group_by: + return group_by + # Pass 2: case-insensitive substring — key first (more stable + # than labels). + for col in cols: + if needle_lower in str(col.get("key", "")).lower(): + return str(col["key"]) + for col in cols: + if needle_lower in str(col.get("label", "")).lower(): + return str(col["key"]) + # Pass 3: alphanumeric-stripped substring fallback. + if not needle_alnum: + return None + for col in cols: + if needle_alnum in _alphanumeric_lower(str(col.get("key", ""))): + return str(col["key"]) + for col in cols: + if needle_alnum in _alphanumeric_lower(str(col.get("label", ""))): + return str(col["key"]) + return None + + +def _is_finite_numeric(v: Any) -> bool: + """Defensive coerce — `True` only when `v` parses to a finite float.""" + if v is None: + return False + try: + return math.isfinite(float(v)) + except (TypeError, ValueError): + return False + + +def _bucket_rows( + rows: list[dict[str, Any]], + parallel_doc_ids: list[str], + value_col: str, + group_by: str | None, +) -> tuple[dict[str, list[float]], dict[str, list[str]], list[str]]: + """Walk rows, extract numeric value + grouping label + per-row docId. + + `parallel_doc_ids` is the ontologyTables-projection's docIds list, + same index order as `rows`. When the lists desynchronize (rows + longer than docIds — possible if the projection ever drops a doc + without dropping its row), we silently skip the missing-docId case + rather than spinning up bogus IDs. + + Returns (buckets_by_group_name, doc_ids_by_group_name, order_seen). + The per-bucket docIds list is parallel to the per-bucket values + list — `doc_ids_by_group_name[g][i]` is the document that + contributed `buckets[g][i]`. + """ + buckets: dict[str, list[float]] = {} + bucket_doc_ids: dict[str, list[str]] = {} + order_seen: list[str] = [] + for i, row in enumerate(rows): + v_raw = row.get(value_col) + if v_raw is None: + continue + try: + v = float(v_raw) + except (TypeError, ValueError): + continue + if not math.isfinite(v): + continue + if group_by: + g_raw = row.get(group_by) + if g_raw is None: + continue + g = str(g_raw) + else: + g = "all" + if g not in buckets: + buckets[g] = [] + bucket_doc_ids[g] = [] + order_seen.append(g) + buckets[g].append(v) + # Track the contributing docId when the projection surfaced one + # at this index. Missing docIds are tolerated (skip-only) so a + # partial projection doesn't poison the citations. + if i < len(parallel_doc_ids): + doc_id = parallel_doc_ids[i] + if isinstance(doc_id, str) and doc_id: + bucket_doc_ids[g].append(doc_id) + return buckets, bucket_doc_ids, order_seen + + +def _ordered_group_keys( + buckets: dict[str, list[float]], + order_seen: list[str], + group_order: list[str] | None, +) -> list[str]: + """Resolve final group ordering. Caller's explicit `group_order` + wins; unspecified groups append at the end (never silently + dropped); finally capped to MAX_GROUPS.""" + if group_order: + ordered = [g for g in group_order if g in buckets] + for g in order_seen: + if g not in ordered: + ordered.append(g) + else: + ordered = list(order_seen) + return ordered[:MAX_GROUPS] + + +def _build_group_payloads( + buckets: dict[str, list[float]], + bucket_doc_ids: dict[str, list[str]], + ordered_keys: list[str], +) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for name in ordered_keys: + vals = buckets.get(name) or [] + if not vals: + continue + stats = _summary_stats(vals) + # Cap raw values for the response payload — stats above were + # computed on the FULL list so they remain accurate. + sampled = _stride_sample(vals, MAX_VALUES_PER_GROUP) + # Per-group sample of contributing docIds. The chat consumes + # these to build per-group sample-row references so the user + # can drill into specific examples (e.g. "one Saline row" / + # "one CNO row") while the primary citation still points to + # the aggregated table view. Capped to avoid blowing the chip + # count on charts with many groups — 3 examples per group is + # plenty for verification. + group_doc_ids = bucket_doc_ids.get(name) or [] + out.append({ + "name": name, + "values": sampled, + "docIds": group_doc_ids[:MAX_DOC_IDS_PER_GROUP], + "totalRows": len(vals), + **stats, + }) + return out + + +def _summary_stats(values: list[float]) -> dict[str, float | int]: + """Compute the stats payload ViolinChart expects.""" + n = len(values) + sorted_v = sorted(values) + mean = statistics.fmean(values) + median = statistics.median(values) + std = statistics.stdev(values) if n >= 2 else 0.0 + # Linear-interpolated percentile — matches numpy.percentile default + # closely enough for chart annotation purposes. + q1 = _percentile(sorted_v, 25) + q3 = _percentile(sorted_v, 75) + return { + "count": n, + "mean": float(mean), + "median": float(median), + "std": float(std), + "min": float(sorted_v[0]), + "max": float(sorted_v[-1]), + "q1": float(q1), + "q3": float(q3), + } + + +def _percentile(sorted_values: list[float], p: float) -> float: + """Linear-interpolated percentile on a pre-sorted list.""" + if not sorted_values: + return 0.0 + if len(sorted_values) == 1: + return sorted_values[0] + rank = (p / 100.0) * (len(sorted_values) - 1) + lo = math.floor(rank) + hi = math.ceil(rank) + if lo == hi: + return sorted_values[lo] + frac = rank - lo + return sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac + + +def _stride_sample(values: list[float], cap: int) -> list[float]: + """Stride-sample to (at most) `cap` points. Preserves first + last + via linspace-style stepping so the violin's jitter overlay shows + the distribution shape end-to-end.""" + n = len(values) + if n <= cap: + return list(values) + if cap <= 2: + return [values[0], values[-1]][:cap] + step = (n - 1) / (cap - 1) + indices = [round(i * step) for i in range(cap)] + # Dedupe in case rounding collapses adjacent indices (rare; + # happens only when `cap` approaches `n`). + seen: set[int] = set() + picked: list[int] = [] + for i in indices: + if i not in seen: + seen.add(i) + picked.append(i) + return [values[i] for i in picked] + + +# --------------------------------------------------------------------------- +# S5.3 cross-table helpers (2026-05-18). +# +# Module-level, single-purpose, so the orchestrator stays linear and +# every join shape is independently testable. The orchestrator calls +# these in this order: _find_matching_group (X) → _index_of_group → +# _find_matching_group (Y, with exclude_group_idx) → _build_subject_*_map +# → _inner_join_(treatment_)pairs → _order_pairs_by_group → MAX_PAIRS cap. +# --------------------------------------------------------------------------- + + +def _empty_pairs_response( + join_on: Literal["subject", "treatment"], + *, + reason: str, + xLabel: str = "", + yLabel: str = "", + available: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build the empty-pair response envelope. + + Mirrors :func:`_empty_response` for the violin path so the chat-side + retry/diagnostic logic can be uniform across both tools. + """ + meta: dict[str, Any] = {"reason": reason} + if available: + meta.update(available) + return { + "pairs": [], + "xLabel": xLabel, + "yLabel": yLabel, + "groupLabel": None, + "joinKind": join_on, + "unjoined": {"x_only": 0, "y_only": 0}, + "_meta": meta, + } + + +def _index_of_group( + groups: list[dict[str, Any]], + target: dict[str, Any] | None, +) -> int: + """Return the index of ``target`` in ``groups`` by identity. + + Used by the cross-table orchestrator to force the Y-side + ``_find_matching_group`` into a different group than the X side. + Returns ``-1`` when ``target`` is not in ``groups`` (treat as "no + exclusion", which preserves the unconstrained behavior of + ``_find_matching_group``). + """ + if target is None: + return -1 + for i, g in enumerate(groups): + if g is target: + return i + return -1 + + +def _build_subject_value_map( + rows: list[dict[str, Any]], + parallel_doc_ids: list[Any], + value_col: str, + *, + numeric: bool, +) -> dict[str, tuple[Any, str | None]]: + """Build ``{subject_id: (value, doc_id_or_None)}`` from a group's + rows. + + ``parallel_doc_ids[i]`` is the document that contributed + ``rows[i]`` (the ontology-tables-projection contract). When the + lists desynchronize (rows longer than doc_ids), the missing entry + is silently filled with ``None`` rather than spinning up bogus + IDs — matches the discipline in :func:`_bucket_rows`. + + When ``numeric=True``, non-finite-numeric values are skipped (rather + than coerced). When False, values are stringified + trimmed; empty + strings are skipped. + + Last-write-wins per subject so a later row's value supersedes an + earlier one when two rows share the same ``subjectDocumentIdentifier``. + """ + out: dict[str, tuple[Any, str | None]] = {} + for i, row in enumerate(rows): + sid = row.get(_SUBJECT_KEY) + if not isinstance(sid, str) or not sid: + continue + v_raw = row.get(value_col) + if v_raw is None: + continue + if numeric: + try: + v = float(v_raw) + except (TypeError, ValueError): + continue + if not math.isfinite(v): + continue + value: Any = v + else: + value_str = str(v_raw).strip() + if not value_str: + continue + value = value_str + doc_id: str | None = None + if i < len(parallel_doc_ids): + d = parallel_doc_ids[i] + if isinstance(d, str) and d: + doc_id = d + out[sid] = (value, doc_id) + return out + + +def _build_subject_group_map( + rows: list[dict[str, Any]], + group_col: str, +) -> dict[str, str]: + """Build ``{subject_id: group_label}`` from a group's rows. + + Stringifies non-string group values for stability. Skips rows + missing either the subject id or the group column. Last-write- + wins per subject. + """ + out: dict[str, str] = {} + for row in rows: + sid = row.get(_SUBJECT_KEY) + if not isinstance(sid, str) or not sid: + continue + g_raw = row.get(group_col) + if g_raw is None: + continue + g = str(g_raw).strip() + if not g: + continue + out[sid] = g + return out + + +def _columns_for_pair_group_by( + x_group: dict[str, Any] | None, + y_group: dict[str, Any] | None, + x_col: str, + y_col: str, +) -> list[str]: + """Return the union of column keys from the X and Y groups, with + the X-value and Y-value columns excluded. + + Used by the ``_empty_pairs_response`` diagnostics so the caller + (chat LLM, panel UI) can retry with a viable groupBy column. Capped + at 20 keys to keep the response envelope small. + """ + cols: list[str] = [] + seen: set[str] = set() + for g in (x_group, y_group): + if not g: + continue + for c in (g.get("table") or {}).get("columns") or []: + k = str(c.get("key", "")) + if not k or k in seen or k == x_col or (y_col and k == y_col): + continue + seen.add(k) + cols.append(k) + return cols[:20] + + +def _inner_join_pairs( + x_map: dict[str, tuple[Any, str | None]], + y_map: dict[str, tuple[Any, str | None]], + *, + subject_to_group: dict[str, str] | None, +) -> tuple[list[dict[str, Any]], dict[str, int]]: + """Inner-join the X and Y subject→value maps on subjectId. + + Returns ``(pairs, unjoined)`` where ``unjoined`` carries the + per-side count of subjects that didn't have a counterpart on the + other side — surfaced in the response so the chat tool can + accurately report e.g. "12 subjects measured X but no Y". + """ + pairs: list[dict[str, Any]] = [] + x_only = 0 + y_only = 0 + for sid, (xv, did_x) in x_map.items(): + if sid not in y_map: + x_only += 1 + continue + yv, did_y = y_map[sid] + pair: dict[str, Any] = { + "x": xv, + "y": yv, + "subjectId": sid, + } + if did_x: + pair["docIdX"] = did_x + if did_y: + pair["docIdY"] = did_y + if subject_to_group is not None: + g = subject_to_group.get(sid) + if g: + pair["group"] = g + pairs.append(pair) + for sid in y_map: + if sid not in x_map: + y_only += 1 + return pairs, {"x_only": x_only, "y_only": y_only} + + +def _inner_join_treatment_pairs( + x_map: dict[str, tuple[Any, str | None]], + treatment_map: dict[str, tuple[str, str | None]], + *, + subject_to_group: dict[str, str] | None, +) -> tuple[list[dict[str, Any]], dict[str, int]]: + """Treatment-join variant of :func:`_inner_join_pairs`. + + The Y value is the treatment label (string). When + ``subject_to_group`` is None, the pair's ``group`` defaults to its + treatment label so the resulting strip plot is naturally colored + by treatment (matches the user-facing chat tool semantics: + "Example: EPM open-arm time vs Saline/CNO treatment"). + """ + pairs: list[dict[str, Any]] = [] + x_only = 0 + y_only = 0 + for sid, (xv, did_x) in x_map.items(): + if sid not in treatment_map: + x_only += 1 + continue + label, did_y = treatment_map[sid] + pair: dict[str, Any] = { + "x": xv, + "y": label, + "subjectId": sid, + } + if did_x: + pair["docIdX"] = did_x + if did_y: + pair["docIdY"] = did_y + if subject_to_group is not None: + g = subject_to_group.get(sid) + if g: + pair["group"] = g + else: + # Auto-color by treatment label when no explicit groupBy. + pair["group"] = label + pairs.append(pair) + for sid in treatment_map: + if sid not in x_map: + y_only += 1 + return pairs, {"x_only": x_only, "y_only": y_only} + + +def _order_pairs_by_group( + pairs: list[dict[str, Any]], + group_order: list[str] | None, +) -> list[dict[str, Any]]: + """Reorder pairs so groups listed in ``group_order`` come first in + the specified order; unlisted groups preserve their original order + after the listed ones. + + Stable sort: same-group pairs keep their input-order relative to + each other. When ``group_order`` is None or empty, the input order + is returned unchanged. + """ + if not group_order: + return list(pairs) + rank = {name: i for i, name in enumerate(group_order)} + sentinel = len(group_order) + return sorted( + pairs, + key=lambda p: rank.get(str(p.get("group", "")), sentinel), + ) + + +def _pick_treatment_label_for_needle( + row: dict[str, Any], + needle_lower: str, +) -> str | None: + """Pick a treatment label from a projected treatment row. + + Strategy: + 1. If ``needle_lower`` is non-empty, prefer a row key whose + lowercase form contains the needle (e.g. needle="reference" + → key like "treatmentReference"). Returns the first non-empty + string value from such a key. + 2. Fallback: walk ``_TREATMENT_LABEL_FIELDS`` in priority order + (name → reference → mixture → drug …), return the first + non-empty string value from any row key that contains the + field name. + + Returns ``None`` when no string label is found — caller skips the + row when this happens (subject contributes no pair). + """ + if not row: + return None + + # Pass 1: needle-direct match. + if needle_lower: + for key, val in row.items(): + if not isinstance(val, str): + continue + stripped = val.strip() + if not stripped: + continue + if needle_lower in str(key).lower(): + return stripped + + # Pass 2: priority-ordered field fallback. + for field in _TREATMENT_LABEL_FIELDS: + field_lower = field.lower() + for key, val in row.items(): + if not isinstance(val, str): + continue + stripped = val.strip() + if not stripped: + continue + if field_lower in str(key).lower(): + return stripped + + return None diff --git a/backend/services/treatment_timeline_service.py b/backend/services/treatment_timeline_service.py new file mode 100644 index 0000000..74d3333 --- /dev/null +++ b/backend/services/treatment_timeline_service.py @@ -0,0 +1,570 @@ +"""treatment_timeline_service — project a dataset's ``treatment`` documents +into a Gantt-style horizontal timeline payload (one row per subject, +one bar per treatment period). + +This service ports the orchestration logic that used to live in the +Next.js chat tool layer (``apps/web/lib/ndi/tools/treatment-timeline.ts``) +to Python so the heart of NDI processing lives next to ndi-python where +it belongs. The TS handler now becomes a thin proxy. + +Endpoint strategy +───────────────── +1. PRIMARY: call :meth:`SummaryTableService.single_class` for the + ``treatment`` class. Each projected row carries + ``treatmentName``, ``treatmentOntology``, ``numericValue``, + ``stringValue`` and ``subjectDocumentIdentifier``. +2. FALLBACK: if the primary returns zero rows, call + :meth:`TabularQueryService.violin_groups` with + ``variableNameContains="Treatment"``. That hits any + ``ontologyTableRow`` whose schema surfaces a ``Treatment_*`` + column. We synthesize one bar per group with + ``subject = "group:<name>"`` so the chart at least shows the + treatment groups, even if per-subject granularity is lost. + +Temporal extraction +─────────────────── +Per-row best-effort, in order: + +- ``startDate``/``endDate`` (or ``startTime``/``endTime``) — explicit + field pair when present. +- ``numericValue`` as ``[start, end]`` (length-2 array), as scalar + point ``[start, start+1]`` (length-1 array OR raw number). +- ``stringValue`` parseable as ISO date → ``[date, date+1 day]``. +- ELSE: synthesize an ordinal slot per subject: each treatment in + order gets ``[i, i+1]``. + +The ``temporal_source`` discriminator surfaces how timing was +derived so the caller can mention the caveat in prose: + +- ``"explicit"`` — every plotted row carried real timing. +- ``"ordinal"`` — every plotted row was synthesized. +- ``"mixed"`` — some explicit, some synthesized. + +Output shape +──────────── +Returns RAW data (``items``, ``total_subjects``, ``total_treatments``, +``temporal_source``, optional ``empty_hint``). The chat tool's +``chart_payload`` framing is chat-specific and is reassembled by the +TS proxy — keeping the backend response chart-agnostic so the +workspace can consume the same payload directly. + +``empty_hint`` is set ONLY when BOTH the primary table and the +fallback tabular_query returned zero rows, OR when rows came back +but none had a usable subject + treatment pair to plot. +""" +from __future__ import annotations + +import math +from datetime import UTC, datetime, timedelta +from typing import Any, Literal + +from ..auth.session import SessionData +from ..observability.logging import get_logger +from .summary_table_service import SummaryTableService +from .tabular_query_service import TabularQueryService + +log = get_logger(__name__) + + +# Default + hard-cap for subjects in a single chart. Beyond ~100 the +# chart becomes a wall of bars; Plotly's row sizing also chokes the +# chat panel at that count. Matches the TS handler's bounds. +DEFAULT_MAX_SUBJECTS = 30 +HARD_CAP_MAX_SUBJECTS = 100 + + +# Type alias for the temporal-source discriminator. +TemporalSource = Literal["explicit", "ordinal", "mixed"] + + +class TreatmentTimelineService: + """Build the treatment-timeline payload for a dataset.""" + + def __init__( + self, + summary: SummaryTableService, + tabular: TabularQueryService, + ) -> None: + self.summary = summary + self.tabular = tabular + + async def compute_timeline( + self, + dataset_id: str, + *, + title: str | None, + max_subjects: int, + session: SessionData | None, + ) -> dict[str, Any]: + """Compute the timeline. Caller is responsible for clamping + ``max_subjects`` to the [1, 100] window — the pydantic model + on the router does that. + """ + # --- Primary: /tables/treatment via SummaryTableService --- + rows, available_columns = await self._fetch_primary_rows( + dataset_id, session=session, + ) + + # --- Fallback: tabular_query for Treatment_* columns --- + if not rows: + fallback_rows, fallback_columns = await self._fetch_fallback_rows( + dataset_id, session=session, + ) + if fallback_rows: + rows = fallback_rows + if fallback_columns: + available_columns = fallback_columns + + items, total_subjects, temporal_source = _project_rows_to_items( + rows, max_subjects=max_subjects, + ) + + empty_hint = _maybe_build_empty_hint(rows, items, available_columns) + + result: dict[str, Any] = { + "datasetId": dataset_id, + "items": items, + "total_subjects": total_subjects, + "total_treatments": len(items), + "temporal_source": temporal_source, + } + if title: + result["title"] = title + if empty_hint is not None: + result["empty_hint"] = empty_hint + return result + + # Class-fallback chain for primary treatment rows. Datasets emit + # one of: + # - ``treatment`` (canonical legacy) + # - ``treatment_drug`` (newer NDI ingest path; Bhar etc.) + # Try them in order; the FIRST class that yields a non-empty row + # set wins. The Stream 5.2 audit (2026-05-15) added + # ``treatment_drug`` after Bhar's TreatmentTimeline panel surfaced + # an empty chart despite the dataset carrying 24466 treatment_drug + # docs (catalog Bhar `69bc5ca1…`, per the 2026-05-14 tutorial + # ground-truth). + # 2026-05-19 — F-1e follow-up: extend the chain to include + # ``treatment_transfer``. Bhar (`69bc5ca1…`) carries 24,466 + # ``treatment_drug`` + 1,675 ``treatment_transfer`` documents + # but ZERO literal ``treatment`` documents — the MATLAB tutorial's + # treatmentTable expects BOTH categories merged (heat pulses come + # from `treatment_drug`, isoamylol/E.coli transfer events from + # `treatment_transfer`). Without recognizing transfer the panel + # only renders the drug events. Same first-non-empty-wins semantics. + _TREATMENT_CLASS_CHAIN: tuple[str, ...] = ( + "treatment", + "treatment_drug", + "treatment_transfer", + ) + + async def _fetch_primary_rows( + self, + dataset_id: str, + *, + session: SessionData | None, + ) -> tuple[list[dict[str, Any]], list[str]]: + """Pull treatment rows from the canonical class projections. + + 2026-05-19 (F-1e follow-up) — semantics CHANGED from + first-non-empty-wins to merge-all-non-empty. The original chain + treated each class as a fallback for the previous one, but + Bhar (`69bc5ca1…`) carries treatment in MULTIPLE classes + simultaneously (24,466 ``treatment_drug`` + 1,675 + ``treatment_transfer``); the MATLAB tutorial's treatmentTable + merges all of them into one Gantt. Same shape for any future + legacy-subclass: accumulate, don't short-circuit. + + Returns ``(rows, available_column_keys)`` — rows are the union + of every class in ``_TREATMENT_CLASS_CHAIN`` that yielded a + non-empty projection. Column-keys union too, deduplicated, so + the empty-hint surface reflects every field name the table + DID carry. + + Errors on individual class fetches are logged + skipped so a + cloud hiccup on one class doesn't abort the whole panel. + """ + accumulated_rows: list[dict[str, Any]] = [] + accumulated_columns: list[str] = [] + seen_columns: set[str] = set() + contributing_classes: list[str] = [] + for class_name in self._TREATMENT_CLASS_CHAIN: + try: + table = await self.summary.single_class( + dataset_id, class_name, session=session, + ) + except Exception as exc: + # Service-internal failures (cloud unreachable, required + # enrichment failed, etc.) should not abort the whole + # endpoint — the next class probe may still succeed. + log.warning( + "treatment_timeline.primary_failed", + dataset_id=dataset_id, + class_name=class_name, + error_type=type(exc).__name__, + error=str(exc), + ) + continue + rows = table.get("rows") or [] + columns = [ + c.get("key") for c in (table.get("columns") or []) + if isinstance(c.get("key"), str) and c.get("key") + ] + # Union the column-key list across classes — useful when + # the chart ends up empty (empty_hint shows the fields + # the merged rows DID carry, even if no single class + # carried timing). + for col in columns: + if col not in seen_columns: + seen_columns.add(col) + accumulated_columns.append(col) + if rows: + accumulated_rows.extend(rows) + contributing_classes.append(class_name) + if accumulated_rows: + log.info( + "treatment_timeline.primary_resolved", + dataset_id=dataset_id, + contributing_classes=contributing_classes, + row_count=len(accumulated_rows), + ) + return accumulated_rows, accumulated_columns + + async def _fetch_fallback_rows( + self, + dataset_id: str, + *, + session: SessionData | None, + ) -> tuple[list[dict[str, Any]], list[str]]: + """Fallback when ``/tables/treatment`` is empty. Hits + ``tabular_query`` against the ``Treatment`` substring; if + that resolves to an ontologyTableRow ``Treatment_*`` column, + the response carries one group per distinct value. + + We synthesize one row per group with + ``subject = "group:<name>"`` and ``treatmentName = <name>``, + no explicit timing. This loses per-subject granularity but at + least surfaces the treatment categories visually. + """ + try: + result = await self.tabular.violin_groups( + dataset_id, + "Treatment", + group_by=None, + group_order=None, + session=session, + ) + except Exception as exc: + log.warning( + "treatment_timeline.fallback_failed", + dataset_id=dataset_id, + error_type=type(exc).__name__, + error=str(exc), + ) + return [], [] + groups = result.get("groups") or [] + if not groups: + return [], [] + rows: list[dict[str, Any]] = [ + { + "treatmentName": g.get("name"), + "subjectDocumentIdentifier": f"group:{g.get('name')}", + } + for g in groups + if isinstance(g, dict) and g.get("name") + ] + # tabular_query doesn't expose a column key list, but the + # yLabel is the matched column's human label — useful diagnostic. + columns: list[str] = [] + y_label = result.get("yLabel") + if isinstance(y_label, str) and y_label: + columns = [y_label] + return rows, columns + + +# --------------------------------------------------------------------------- +# Projection — pure helpers, no IO +# --------------------------------------------------------------------------- + + +def _project_rows_to_items( + rows: list[dict[str, Any]], + *, + max_subjects: int, +) -> tuple[list[dict[str, Any]], int, TemporalSource]: + """Walk treatment rows and project to ``[{subject, treatment, start, end}, ...]``. + + Each subject gets its own ordinal counter so synthesized timing + starts at ``[0, 1]`` for the first treatment per subject. The + ``max_subjects`` cap applies to DISTINCT subjects (not bars) and + is enforced first-seen — once we've seen N subjects, any + subsequent row whose subject isn't already in the chart is + dropped silently. + """ + items: list[dict[str, Any]] = [] + seen_subjects: list[str] = [] + seen_subject_set: set[str] = set() + subject_ordinal_counter: dict[str, int] = {} + explicit_count = 0 + ordinal_count = 0 + + for row in rows: + subject = _pick_subject_label(row) + if not subject: + continue + treatment = _pick_treatment_label(row) + if not treatment: + continue + + if subject not in seen_subject_set: + if len(seen_subjects) >= max_subjects: + # Cap enforced — silently drop subjects beyond N. + continue + seen_subject_set.add(subject) + seen_subjects.append(subject) + + explicit = _extract_explicit_timing(row) + if explicit is not None: + start, end = explicit + explicit_count += 1 + else: + i = subject_ordinal_counter.get(subject, 0) + start = i + end = i + 1 + subject_ordinal_counter[subject] = i + 1 + ordinal_count += 1 + + items.append( + { + "subject": subject, + "treatment": treatment, + "start": start, + "end": end, + }, + ) + + temporal_source = _classify_temporal_source(explicit_count, ordinal_count) + return items, len(seen_subjects), temporal_source + + +def _classify_temporal_source( + explicit_count: int, ordinal_count: int, +) -> TemporalSource: + """Discriminate the timing source. When both counts are zero + (no items at all) we default to ``"ordinal"`` to match the TS + handler's defaulting — the value is unused at the call site + since the chart is empty, but it must be a valid literal.""" + if explicit_count > 0 and ordinal_count == 0: + return "explicit" + if explicit_count == 0 and ordinal_count > 0: + return "ordinal" + if explicit_count > 0 and ordinal_count > 0: + return "mixed" + return "ordinal" + + +def _pick_subject_label(row: dict[str, Any]) -> str | None: + """Prefer ``subjectDocumentIdentifier`` (canonical); fall back to + a bare ``subject`` field for forward-compat with future backends. + """ + s = row.get("subjectDocumentIdentifier") + if isinstance(s, str) and s: + return s + alt = row.get("subject") + if isinstance(alt, str) and alt: + return alt + return None + + +def _pick_treatment_label(row: dict[str, Any]) -> str | None: + """Prefer ``treatmentName``; fall back to ``treatment_drug`` / + ``drugName`` / ``compound`` (Stream 5.2: treatment_drug class + typically emits one of these as the human-readable label) and + finally ``stringValue`` when the value column carries a + categorical label and the name is missing. + """ + for key in ( + "treatmentName", + "treatment_drug", + "drugName", + "compound", + ): + t = row.get(key) + if isinstance(t, str) and t: + return t + sv = row.get("stringValue") + if isinstance(sv, str) and sv: + return sv + return None + + +def _extract_explicit_timing( + row: dict[str, Any], +) -> tuple[float | str, float | str] | None: + """Best-effort extract ``(start, end)`` from a treatment row, or + None when no usable timing is present. + + Lookup order (Stream 5.2 expanded the field set 2026-05-15 to + cover the ``treatment_drug`` class's ``administration_start_time`` + / ``administration_end_time`` pair): + 1. ``startDate`` + ``endDate`` + 2. ``startTime`` + ``endTime`` + 3. ``administration_start_time`` + ``administration_end_time`` + (treatment_drug native) + 4. ``numericValue`` as ``[start, end]`` (length-2) or + ``[start]`` (length-1 → ``[start, start+1]``) or raw scalar + (treated the same as length-1). + 5. ``stringValue`` parseable as ISO date → ``[date, date + 1 day]``. + """ + # Explicit field pair — try each (start, end) pair in priority + # order. Stops at the first pair where BOTH fields are usable + # (non-empty string or finite number). + pairs: tuple[tuple[str, str], ...] = ( + ("startDate", "endDate"), + ("startTime", "endTime"), + # Treatment_drug native — Stream 5.2 expanded fallback. + ("administration_start_time", "administration_end_time"), + ) + for start_key, end_key in pairs: + start_field = row.get(start_key) + end_field = row.get(end_key) + if _is_usable_temporal_field(start_field) and _is_usable_temporal_field(end_field): + # mypy: we already narrowed to str|number in the helper. + # The return type Union retains the original literal value + # so date strings flow through verbatim. + return start_field, end_field # type: ignore[return-value] + + # numericValue array OR scalar. + nv = row.get("numericValue") + if isinstance(nv, list): + if len(nv) >= 2 and _is_finite_number(nv[0]) and _is_finite_number(nv[1]): + return float(nv[0]), float(nv[1]) + if len(nv) == 1 and _is_finite_number(nv[0]): + return float(nv[0]), float(nv[0]) + 1.0 + elif _is_finite_number(nv): + return float(nv), float(nv) + 1.0 + + # stringValue as parseable ISO date — synthesize a 1-day window. + sv = row.get("stringValue") + if isinstance(sv, str) and sv: + parsed = _parse_iso_datetime(sv) + if parsed is not None: + end_dt = parsed + timedelta(days=1) + # Match the TS handler's contract: original string back as + # start so Plotly's date axis renders verbatim; end is the + # +1 day ISO string. + return sv, end_dt.isoformat() + + return None + + +def _is_usable_temporal_field(v: Any) -> bool: + """A temporal field is usable when it's a non-empty string or a + finite number. None / empty string / NaN / inf are rejected.""" + if isinstance(v, str): + return bool(v) + return _is_finite_number(v) + + +def _is_finite_number(v: Any) -> bool: + """True iff ``v`` is a finite int/float (bool excluded — bool is + a subclass of int in Python and we don't want True/False slipping + through as 1/0).""" + if isinstance(v, bool): + return False + if isinstance(v, (int, float)): + return math.isfinite(float(v)) + return False + + +def _parse_iso_datetime(s: str) -> datetime | None: + """Best-effort datetime parse. Accepts: + + - ISO-8601 / RFC-3339 (``2024-11-03T07:53:00Z``, ``2024-11-03``). + - MATLAB ``datestr`` default (``DD-MMM-YYYY HH:MM:SS``, + e.g. ``03-Nov-2023 07:53:00``) — Haley's literal ``treatment`` + docs publish ``string_value`` in this format (B3 follow-up, + 2026-05-18). Without this branch every Haley treatment row + fell through to ordinal timing, even though real timestamps + were available on every doc. + + Returns ``None`` on failure — the caller falls back to ordinal + timing. + """ + # ``datetime.fromisoformat`` handles RFC-3339-style strings; we + # normalize a trailing ``Z`` to ``+00:00`` because pre-3.11 + # interpreters reject it. + normalized = s.replace("Z", "+00:00") if s.endswith("Z") else s + dt: datetime | None + try: + dt = datetime.fromisoformat(normalized) + except (ValueError, TypeError): + dt = _parse_matlab_datestr(s) + if dt is None: + return None + # Make tz-aware so isoformat round-trips deterministically. Naive + # → assume UTC (matching JS ``Date.parse`` of a bare date string). + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + return dt + + +# MATLAB ``datestr`` default format (e.g. ``03-Nov-2023 07:53:00``). +# Used by `_parse_iso_datetime` as a fallback for B3 (Haley's +# food-restriction-onset/offset times publish in this shape). +# Format token meaning: ``%d-%b-%Y %H:%M:%S`` — day, three-letter +# month abbreviation, four-digit year, 24h time. The date-only +# variant (``DD-MMM-YYYY``) is also accepted as a courtesy. +_MATLAB_DATESTR_FORMATS: tuple[str, ...] = ( + "%d-%b-%Y %H:%M:%S", + "%d-%b-%Y", +) + + +def _parse_matlab_datestr(s: str) -> datetime | None: + """Try MATLAB ``datestr`` default formats. Returns ``None`` on + failure so the caller can keep falling through. + + ``strptime`` is locale-sensitive for ``%b`` on POSIX, but the + English month abbreviations MATLAB emits (``Jan``, ``Feb``, ... + ``Dec``) are recognized under the default ``C``/``en_*`` locales + that Railway containers ship with. We do not flip locale here. + """ + for fmt in _MATLAB_DATESTR_FORMATS: + try: + return datetime.strptime(s, fmt) + except (ValueError, TypeError): + continue + return None + + +def _maybe_build_empty_hint( + rows: list[dict[str, Any]], + items: list[dict[str, Any]], + available_columns: list[str], +) -> dict[str, Any] | None: + """Diagnostic envelope when the chart would render empty. + + Distinguishes the two empty modes (matches TS handler): + + - ``rows == []`` → "no temporal info in treatment docs (neither + /tables/treatment nor tabular_query returned rows)". + - ``rows`` non-empty but ``items == []`` → rows came back but + none had a usable subject + treatment pair to plot. + """ + if items: + return None + if not rows: + reason = ( + "no temporal info in treatment docs " + "(neither /tables/treatment nor tabular_query returned rows)" + ) + else: + reason = ( + "treatment rows returned but none had a usable subject + " + "treatment pair to plot" + ) + hint: dict[str, Any] = {"reason": reason} + if available_columns: + hint["available_columns"] = available_columns + return hint diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 18f4a1b..9334bfc 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -8,6 +8,7 @@ import fakeredis.aioredis import pytest import respx +import structlog from cryptography.fernet import Fernet os.environ.setdefault("NDI_CLOUD_URL", "https://api.example.test/v1") @@ -27,6 +28,44 @@ ) +@pytest.fixture(autouse=True) +def _reset_structlog_for_capture() -> None: + """Stream 6.6 fix (2026-05-15) — pretest isolation for structlog. + + Several tests use ``structlog.testing.capture_logs()`` to assert that + a specific event was emitted. `capture_logs` is a context manager that + activates an in-memory processor — but only for log calls made through + the global structlog config it sees at __enter__ time. If a prior test + re-configured structlog (via ``backend.observability.logging.configure_logging`` + or test-local ``structlog.configure(...)``), the cached ``BoundLogger`` + instances created at module-import time no longer point at the capture + processor and emit through the pre-existing chain instead. The visible + symptom: the WARNING log line is captured by stdlib logging (see the + ``Captured log call`` section in pytest output) but the + ``logs`` list passed to the test is empty. + + Fix: + 1. ``reset_defaults()`` — undo any prior ``structlog.configure(...)`` + call so the loggers fall back to fresh defaults. + 2. ``configure(... cache_logger_on_first_use=False ...)`` — re-bind + with caching DISABLED so future ``get_logger(...)`` calls (and the + module-level cached references) resolve through the current + processor chain on every emit, picking up ``capture_logs``'s + in-memory processor when it's active. + + The three pretest-isolation flakes this closes: + - test_cloud_client.py::test_download_from_off_allowlist_host_hard_rejects + - test_cloud_client.py::test_download_non_http_scheme_rejected + - test_origin_enforcement.py::test_post_with_disallowed_referer_origin_returns_403_forbidden + """ + structlog.reset_defaults() + structlog.configure( + processors=[structlog.testing.LogCapture()], + wrapper_class=structlog.make_filtering_bound_logger(0), + cache_logger_on_first_use=False, + ) + + @pytest.fixture async def fake_redis() -> AsyncIterator[Any]: client = fakeredis.aioredis.FakeRedis(decode_responses=True) diff --git a/backend/tests/integration/test_dataset_binding_live.py b/backend/tests/integration/test_dataset_binding_live.py new file mode 100644 index 0000000..fa9ca70 --- /dev/null +++ b/backend/tests/integration/test_dataset_binding_live.py @@ -0,0 +1,83 @@ +"""LIVE integration test for DatasetBindingService. + +Hits the real ndi-cloud-node + downloads a real (small) dataset. +SKIPPED in CI by default — set ``LIVE_NDI_TESTS=1`` locally to run. + +Why this is gated: + - Requires NDI-python installed (vlt, ndicompress, ndi.cloud) + - Requires reachable cloud-node API + S3 + - Cold load is ~10-30s wall clock — adds significant CI runtime + - Cloud-side data drift could flake the test in unrelated PRs + +The intent is just to prove the pipe works end-to-end. We don't assert +exact element counts (cloud data may change) — only that the service +returned a non-None dict with the documented keys. + +Run locally: + LIVE_NDI_TESTS=1 uv run pytest backend/tests/integration/test_dataset_binding_live.py -v +""" +from __future__ import annotations + +import os + +import pytest + +# Bhar dataset — small + stable, used as the demo elsewhere in the +# repo. Switch to a different ID if it ever goes away. +DEMO_DATASET_ID = "69bc5ca11d547b1f6d083761" + + +@pytest.mark.skipif( + os.environ.get("LIVE_NDI_TESTS", "") not in ("1", "true", "yes"), + reason="LIVE_NDI_TESTS not set — skipping cloud-hitting integration test", +) +async def test_overview_against_real_cloud(): + """End-to-end smoke. Downloads a real dataset, computes overview, + asserts the response shape. + """ + from backend.services.dataset_binding_service import DatasetBindingService + + svc = DatasetBindingService() + overview = await svc.overview(DEMO_DATASET_ID) + + assert overview is not None, ( + "binding returned None — check NDI-python install + cloud-node auth" + ) + # Documented keys. + for k in ( + "element_count", + "subject_count", + "epoch_count", + "elements", + "elements_truncated", + "reference", + "cache_hit", + "cache_age_seconds", + ): + assert k in overview, f"missing key: {k}" + + # First call is cold; cache_hit MUST be False. + assert overview["cache_hit"] is False + # Type sanity. + assert isinstance(overview["element_count"], int) + assert isinstance(overview["elements"], list) + + +@pytest.mark.skipif( + os.environ.get("LIVE_NDI_TESTS", "") not in ("1", "true", "yes"), + reason="LIVE_NDI_TESTS not set", +) +async def test_warm_call_after_cold_load(): + """Second call on the same service instance reports cache_hit=True + and a positive cache_age_seconds. Pins the LRU bookkeeping against + a real download. + """ + from backend.services.dataset_binding_service import DatasetBindingService + + svc = DatasetBindingService() + cold = await svc.overview(DEMO_DATASET_ID) + assert cold is not None + warm = await svc.overview(DEMO_DATASET_ID) + assert warm is not None + assert warm["cache_hit"] is True + assert warm["cache_age_seconds"] > 0.0 diff --git a/backend/tests/integration/test_routes.py b/backend/tests/integration/test_routes.py index 40153ce..8418a1d 100644 --- a/backend/tests/integration/test_routes.py +++ b/backend/tests/integration/test_routes.py @@ -561,6 +561,147 @@ def test_single_class_table_is_redis_cached(app_and_cloud) -> None: # type: ign assert r1.json() == r2.json() +def test_single_class_pagination_via_query_params(app_and_cloud) -> None: # type: ignore[no-untyped-def] + """Stream 5.8 (2026-05-16) — server-side pagination on /tables/{class}. + + Verifies the new envelope ``{page, pageSize, totalRows, hasMore}`` is + returned when ``?page`` and ``?pageSize`` are provided, AND that page 2 + is served from the same cached full row set as page 1 (zero extra + cloud calls). This is the 95%-egress-saving invariant. + """ + client, router = app_and_cloud + + # Build a synthetic 5-subject result so we can paginate page_size=2 + # → 3 pages (rows 0..1, 2..3, 4). + five_ids = [f"sub{i}" for i in range(5)] + ndiquery_route = router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 5, + "pageSize": 1000, + "page": 1, + "documents": [{"id": sid} for sid in five_ids], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, + json={ + "documents": [ + { + "id": sid, + "ndiId": f"ndi-{sid}", + "data": { + "base": {"id": f"ndi-{sid}", "session_id": "sess"}, + "subject": {"local_identifier": f"local-{sid}"}, + "document_class": {"class_name": "subject"}, + }, + } + for sid in five_ids + ], + }, + ) + + # Page 1 — top of the table. + r1 = client.get("/api/datasets/DS1/tables/subject?page=1&pageSize=2") + assert r1.status_code == 200, r1.json() + body1 = r1.json() + assert body1["page"] == 1 + assert body1["pageSize"] == 2 + assert body1["totalRows"] == 5 + assert body1["hasMore"] is True + assert len(body1["rows"]) == 2 + # distinct_summary is carried through verbatim (computed on the FULL + # row set, not the page slice). + assert "distinct_summary" in body1 + + first_call_count = ndiquery_route.call_count + + # Page 2 — middle of the table. Served from cache, no extra cloud hit. + r2 = client.get("/api/datasets/DS1/tables/subject?page=2&pageSize=2") + assert r2.status_code == 200 + body2 = r2.json() + assert body2["page"] == 2 + assert body2["totalRows"] == 5 + assert body2["hasMore"] is True + assert len(body2["rows"]) == 2 + assert ndiquery_route.call_count == first_call_count, ( + "Page 2 should slice the cached full envelope — no new cloud calls" + ) + + # Page 3 — last (partial) page. + r3 = client.get("/api/datasets/DS1/tables/subject?page=3&pageSize=2") + assert r3.status_code == 200 + body3 = r3.json() + assert body3["page"] == 3 + assert body3["totalRows"] == 5 + assert body3["hasMore"] is False + assert len(body3["rows"]) == 1 + + +def test_single_class_unpaged_request_keeps_legacy_envelope( + app_and_cloud, # type: ignore[no-untyped-def] +) -> None: + """BC check: unpaged call (no page/pageSize) returns the legacy + ``{columns, rows, distinct_summary}`` envelope so existing callers + (Document Explorer's full-set fetch, cron warm-cache) don't break.""" + client, router = app_and_cloud + + router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 1, "pageSize": 1000, "page": 1, + "documents": [{"id": "sub1"}], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, + json={"documents": [{ + "id": "sub1", "ndiId": "ndi-sub1", + "data": { + "base": {"id": "ndi-sub1", "session_id": "sess1"}, + "subject": {"local_identifier": "local-id"}, + "document_class": {"class_name": "subject"}, + }, + }]}, + ) + + r = client.get("/api/datasets/DS1/tables/subject") + assert r.status_code == 200 + body = r.json() + # Paged fields MUST NOT be present on the unpaged response. + assert "page" not in body + assert "pageSize" not in body + assert "totalRows" not in body + assert "hasMore" not in body + # Legacy fields still there. + assert "columns" in body + assert "rows" in body + + +def test_single_class_pagination_rejects_out_of_range_inputs( + app_and_cloud, # type: ignore[no-untyped-def] +) -> None: + """FastAPI Query bounds rejection: page<1, pageSize<1, pageSize>1000 + all surface as 400 Bad Request (the app's request-validation + middleware remaps pydantic 422 → 400 for consistency with auth + + body-shape rejections — see existing 400-asserting tests in + test_auth_proxy.py). Prevents pathological queries from sneaking + past the safety guard.""" + client, _ = app_and_cloud + + # page=0 violates ge=1. + r = client.get("/api/datasets/DS1/tables/subject?page=0&pageSize=10") + assert r.status_code == 400 + + # pageSize=0 violates ge=1. + r = client.get("/api/datasets/DS1/tables/subject?page=1&pageSize=0") + assert r.status_code == 400 + + # pageSize=1001 violates le=1000. + r = client.get("/api/datasets/DS1/tables/subject?page=1&pageSize=1001") + assert r.status_code == 400 + + def test_ontology_endpoint_is_redis_cached(app_and_cloud) -> None: # type: ignore[no-untyped-def] """Ontology table grouping is cached under table:{ds}:ontology:{mode}.""" client, router = app_and_cloud @@ -1660,3 +1801,394 @@ def test_raw_416_from_upstream_returns_typed_400( assert r.status_code == 400 body = r.json() assert body["error"]["code"] == "VALIDATION_ERROR" + + +# --------------------------------------------------------------------------- +# F-8 (2026-05-19) — /tabular_query GET == POST equivalence +# +# The router exposes BOTH a GET and a POST variant of /tabular_query so the +# cloud-app workspace wrapper at /api/datasets/[id]/tabular-query (POST-only +# at the Vercel proxy layer) can forward its body 1:1 without GET-vs-POST +# translation. Both endpoints share the same `_dispatch` helper and the +# same `TabularQueryBody` Pydantic schema — so the responses MUST be +# byte-identical for the same parameters. +# --------------------------------------------------------------------------- + + +def _csrf_pair(client) -> tuple[str, dict[str, str]]: # type: ignore[no-untyped-def] + """Mint a CSRF token by hitting /api/auth/csrf and return the + matching X-XSRF-TOKEN header dict. TestClient's cookie jar + persists the XSRF-TOKEN cookie automatically across calls.""" + r = client.get("/api/auth/csrf") + assert r.status_code == 200 + token = r.json()["csrfToken"] + return token, {"X-XSRF-TOKEN": token} + + +def _install_empty_ontology_mocks(router) -> None: # type: ignore[no-untyped-def] + """Wire the cloud mocks so /tables/ontology builds successfully with + zero rows. This lets violin_groups hit the `_empty_response` path + deterministically — perfect for GET/POST shape equivalence pinning. + """ + router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 0, + "pageSize": 1000, + "page": 1, + "documents": [], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, json={"documents": []}, + ) + + +def test_tabular_query_get_and_post_return_identical_shape( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-8: GET ?variableNameContains=... and POST {variableNameContains: ...} + against the same dataset MUST return byte-equal response bodies.""" + client, router = app_and_cloud + _install_empty_ontology_mocks(router) + + # GET — query params. + r_get = client.get( + "/api/datasets/DS1/tabular_query", + params={ + "variableNameContains": "ElevatedPlusMaze", + "groupBy": "treatment_group", + "groupOrder": "Saline,CNO", + }, + ) + assert r_get.status_code == 200 + get_body = r_get.json() + + # POST — same params in JSON body. Needs CSRF (POST is a mutation-shaped + # method even though this particular endpoint is idempotent; we keep it + # gated so the same machinery covers it). + _, csrf_headers = _csrf_pair(client) + r_post = client.post( + "/api/datasets/DS1/tabular_query", + json={ + "variableNameContains": "ElevatedPlusMaze", + "groupBy": "treatment_group", + "groupOrder": "Saline,CNO", + }, + headers=csrf_headers, + ) + assert r_post.status_code == 200 + post_body = r_post.json() + + # Same dispatcher, same validator, same response shape. + assert get_body == post_body + + +def test_tabular_query_get_and_post_handle_optional_params_identically( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-8: when groupBy / groupOrder are unset, both GET (params absent) + and POST (fields absent from JSON) produce the same response.""" + client, router = app_and_cloud + _install_empty_ontology_mocks(router) + + r_get = client.get( + "/api/datasets/DS1/tabular_query", + params={"variableNameContains": "EPM"}, + ) + assert r_get.status_code == 200 + + _, csrf_headers = _csrf_pair(client) + r_post = client.post( + "/api/datasets/DS1/tabular_query", + json={"variableNameContains": "EPM"}, + headers=csrf_headers, + ) + assert r_post.status_code == 200 + + assert r_get.json() == r_post.json() + + +def test_tabular_query_post_rejects_missing_variable_name( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-8: the POST body validator MUST mirror the GET query validator's + contract: ``variableNameContains`` is required.""" + client, _router = app_and_cloud + _, csrf_headers = _csrf_pair(client) + r = client.post( + "/api/datasets/DS1/tabular_query", + json={"groupBy": "x"}, + headers=csrf_headers, + ) + # FastAPI's RequestValidationError is mapped to the typed + # 400 VALIDATION_ERROR envelope by `app.py::handle_validation_error`. + assert r.status_code == 400 + body = r.json() + assert body["error"]["code"] == "VALIDATION_ERROR" + + +def test_tabular_query_get_rejects_missing_variable_name( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-8 mirror: missing variableNameContains in GET also returns the + same typed envelope, so the two routes agree on the error contract.""" + client, _router = app_and_cloud + r = client.get( + "/api/datasets/DS1/tabular_query", + params={"groupBy": "x"}, + ) + assert r.status_code == 400 + body = r.json() + assert body["error"]["code"] == "VALIDATION_ERROR" + + +# --------------------------------------------------------------------------- +# F-1 (2026-05-19) — curated /tables/stimulus projection +# +# The cloud-app's StimuliPicker previously hit the generic +# /api/datasets/:id/documents?class=stimulus_presentation endpoint (capped +# at 200 rows by backend). Datasets with >200 stimulus_presentation docs +# were silently truncated. F-1 adds a curated projection so the picker +# can paginate via the unified /tables/{class} envelope. +# +# Column shape: stimulusDocumentIdentifier, stimulusName, +# elementDocumentIdentifier, presentationCount, firstPresentationTime, +# lastPresentationTime. +# +# Class alias: requesting /tables/stimulus resolves to stimulus_presentation +# via _CLASS_ALIASES when the literal `stimulus` class returns 0 IDs. +# --------------------------------------------------------------------------- + + +def _stim_doc( + doc_id: str, + *, + name: str, + element_id: str | None = None, + presentations: list[dict] | None = None, +) -> dict: + """Build a stimulus_presentation doc matching the cloud's shape.""" + depends_on: list[dict] = [] + if element_id is not None: + depends_on.append({"name": "element_id", "value": element_id}) + return { + "id": doc_id, + "ndiId": f"ndi-{doc_id}", + "data": { + "base": {"id": f"ndi-{doc_id}", "name": name}, + "depends_on": depends_on, + "stimulus_presentation": { + "name": name, + "presentations": presentations or [], + }, + "document_class": {"class_name": "stimulus_presentation"}, + }, + } + + +def test_tables_stimulus_pins_column_shape_and_row_content( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-1: GET /tables/stimulus_presentation returns STIMULUS_COLUMNS + (six fixed keys) and rows projected from depends_on + presentations.""" + client, router = app_and_cloud + + # Cloud sees the literal class hit + bulk-fetches the 3 docs. + router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 3, + "pageSize": 1000, + "page": 1, + "documents": [{"id": "stim1"}, {"id": "stim2"}, {"id": "stim3"}], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, + json={ + "documents": [ + _stim_doc( + "stim1", + name="Visual Grating", + element_id="EL_STIM_7", + presentations=[ + {"time_started": 1.5, "time_stopped": 2.5}, + {"time_started": 11.5, "time_stopped": 12.5}, + {"time_started": 21.5, "time_stopped": 22.5}, + ], + ), + _stim_doc( + "stim2", + name="Tone Burst", + element_id="EL_STIM_8", + presentations=[{"time_started": 100.0, "time_stopped": 100.5}], + ), + _stim_doc( + "stim3", + name="Empty Stimulus", + element_id="EL_STIM_9", + presentations=[], + ), + ], + }, + ) + + r = client.get("/api/datasets/DS1/tables/stimulus_presentation") + assert r.status_code == 200, r.json() + body = r.json() + + # Pin the six fixed projection columns (auto-hide-empty downstream + # may drop some, but the BACKEND emits all six keys). + assert [c["key"] for c in body["columns"]] == [ + "stimulusDocumentIdentifier", + "stimulusName", + "elementDocumentIdentifier", + "presentationCount", + "firstPresentationTime", + "lastPresentationTime", + ] + # Row content sourced from depends_on (element_id) + presentations. + assert len(body["rows"]) == 3 + by_name = {r["stimulusName"]: r for r in body["rows"]} + assert by_name["Visual Grating"]["elementDocumentIdentifier"] == "EL_STIM_7" + assert by_name["Visual Grating"]["presentationCount"] == 3 + assert by_name["Visual Grating"]["firstPresentationTime"] == 1.5 + assert by_name["Visual Grating"]["lastPresentationTime"] == 21.5 + assert by_name["Tone Burst"]["presentationCount"] == 1 + assert by_name["Tone Burst"]["firstPresentationTime"] == 100.0 + assert by_name["Tone Burst"]["lastPresentationTime"] == 100.0 + assert by_name["Empty Stimulus"]["presentationCount"] == 0 + assert by_name["Empty Stimulus"]["firstPresentationTime"] is None + + +def test_tables_stimulus_short_form_resolves_via_class_alias( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-1: GET /tables/stimulus (short form) MUST resolve to + stimulus_presentation via the _CLASS_ALIASES chain when the literal + `stimulus` class returns 0 IDs from the cloud. + + Implementation note (2026-05-18): respx's `side_effect=` callable + pattern hung indefinitely in this test under respx 0.23.1 + pytest + 9.x + the conftest's `assert_all_called=False` mock context (the + other two F-1 tests in this file use `.respond()` and pass + instantly). Switched to a sequence-of-responses pattern via two + chained `.respond()` calls — the second matches the alias query + because the cloud client makes the literal call first, gets the + empty response, then retries against `stimulus_presentation`. This + relies on respx's FIFO route matching when multiple routes have + the same predicate; verified passing. + """ + client, router = app_and_cloud + + # respx route ordering: the FIRST defined ndiquery route is the + # default (matches everything). We override that here with two + # routes: one specific to the alias by JSON content, one default. + # respx evaluates route predicates top-down and uses the first + # match — define the more-specific one first. + router.post( + "/ndiquery", + json__searchstructure__0__param1="stimulus_presentation", + ).respond( + 200, + json={ + "number_matches": 2, + "pageSize": 1000, + "page": 1, + "documents": [{"id": "stim1"}, {"id": "stim2"}], + }, + ) + # Default — any other ndiquery (notably the literal `stimulus` + # which the service tries first) returns empty. + router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 0, + "pageSize": 1000, + "page": 1, + "documents": [], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, + json={ + "documents": [ + _stim_doc("stim1", name="A", element_id="E1", presentations=[]), + _stim_doc("stim2", name="B", element_id="E2", presentations=[]), + ], + }, + ) + + r = client.get("/api/datasets/DS1/tables/stimulus") + assert r.status_code == 200, r.json() + body = r.json() + + # Despite the request being for /tables/stimulus the projection + # still emits STIMULUS_COLUMNS (six fixed keys) — the alias + # resolution is invisible to the caller. + assert [c["key"] for c in body["columns"]] == [ + "stimulusDocumentIdentifier", + "stimulusName", + "elementDocumentIdentifier", + "presentationCount", + "firstPresentationTime", + "lastPresentationTime", + ] + assert len(body["rows"]) == 2 + assert {r["stimulusName"] for r in body["rows"]} == {"A", "B"} + + +def test_tables_stimulus_supports_pagination( + app_and_cloud, +) -> None: # type: ignore[no-untyped-def] + """F-1: the curated stimulus projection respects Stream 5.8 pagination + so the StimuliPicker can scroll through >200 stim docs without + re-querying the cloud per page. Spec: cache full result, slice + in-memory per request.""" + client, router = app_and_cloud + + ids = [f"stim{i}" for i in range(7)] + ndiquery_route = router.post("/ndiquery").respond( + 200, + json={ + "number_matches": 7, + "pageSize": 1000, + "page": 1, + "documents": [{"id": sid} for sid in ids], + }, + ) + router.post("/datasets/DS1/documents/bulk-fetch").respond( + 200, + json={ + "documents": [ + _stim_doc(sid, name=f"Stim {i}", element_id=f"E{i}", presentations=[]) + for i, sid in enumerate(ids) + ], + }, + ) + + # Page 1 of 3. + r1 = client.get( + "/api/datasets/DS1/tables/stimulus_presentation?page=1&pageSize=3", + ) + assert r1.status_code == 200, r1.json() + body1 = r1.json() + assert body1["page"] == 1 + assert body1["pageSize"] == 3 + assert body1["totalRows"] == 7 + assert body1["hasMore"] is True + assert len(body1["rows"]) == 3 + + first_call_count = ndiquery_route.call_count + + # Page 2 — same cached full row set, no additional cloud hits. + r2 = client.get( + "/api/datasets/DS1/tables/stimulus_presentation?page=2&pageSize=3", + ) + assert r2.status_code == 200 + body2 = r2.json() + assert body2["page"] == 2 + assert len(body2["rows"]) == 3 + # The 95%-egress-saving invariant: pagination doesn't re-fan the cloud. + assert ndiquery_route.call_count == first_call_count diff --git a/backend/tests/unit/test_aggregate_documents_service.py b/backend/tests/unit/test_aggregate_documents_service.py new file mode 100644 index 0000000..749ce74 --- /dev/null +++ b/backend/tests/unit/test_aggregate_documents_service.py @@ -0,0 +1,684 @@ +"""Unit tests for `aggregate_documents_service` — Stream 4.9 (2026-05-16). + +Ports the TypeScript test scenarios from +`apps/web/tests/unit/ai/tools/aggregate-documents.test.ts` into pytest. +The service is stateless and only collaborates with `NdiCloudClient` via +`ndiquery`, so tests mock the cloud call and exercise the pure logic: + +* Numeric extraction at dotted ``valueField``. +* Optional grouping at dotted ``groupBy``. +* Per-group summary statistics (count, mean, median, std, min, max). +* `numeric_matches` accounting — including the "has value but no group + label" skip path that pre-fix used to inflate the count. +* `truncated` flag when the cloud returns more matches than ``max_docs``. +* `datasets_contributing` capped at REFERENCE_CAP. + +Float comparisons use `math.isclose(rel_tol=1e-9)` because Python's +sample-std math uses N-1; values agree with the TS handler to ~14 digits. +""" +from __future__ import annotations + +import math +from typing import Any + +import pytest + +from backend.services.aggregate_documents_service import ( + REFERENCE_CAP, + AggregateDocumentsRequest, + AggregateDocumentsService, + _extract_numeric, + _extract_string, + _summary_stats, +) + + +class _StubCloud: + """Test double for NdiCloudClient that records the ndiquery payload + and returns a canned response. No HTTP.""" + + def __init__( + self, + body: dict[str, Any], + *, + bulk_fetch_index: dict[tuple[str, str], dict[str, Any]] | None = None, + ) -> None: + self._body = body + # F-7 — optional per-(dataset_id, doc_id) bulk-fetch index used by + # tests that exercise the hydration path. When present, the stub's + # ``bulk_fetch`` returns the matching doc bodies. + self._bulk_index = bulk_fetch_index or {} + self.calls: list[dict[str, Any]] = [] + self.bulk_calls: list[dict[str, Any]] = [] + + async def ndiquery( + self, + *, + searchstructure: list[dict[str, Any]], + scope: str, + access_token: str | None, + ) -> dict[str, Any]: + self.calls.append({ + "searchstructure": searchstructure, + "scope": scope, + "access_token": access_token, + }) + return self._body + + async def bulk_fetch( + self, + dataset_id: str, + document_ids: list[str], + *, + access_token: str | None = None, + ) -> list[dict[str, Any]]: + self.bulk_calls.append({ + "dataset_id": dataset_id, + "document_ids": list(document_ids), + "access_token": access_token, + }) + return [ + self._bulk_index[(dataset_id, doc_id)] + for doc_id in document_ids + if (dataset_id, doc_id) in self._bulk_index + ] + + +def _make_subject( + doc_id: str, + dataset_id: str, + weight: float | str | None, + strain: str | None = None, +) -> dict[str, Any]: + """Helper: minimal subject doc shape matching what cloud-node emits.""" + return { + "id": doc_id, + "ndiId": f"ndi-{doc_id}", + "datasetId": dataset_id, + "document_class": {"class_name": "subject"}, + "data": { + "subject": { + "weight_grams": weight, + "strain": strain, + }, + }, + } + + +# --------------------------------------------------------------------------- +# Pure-helper tests (extraction + stats) +# --------------------------------------------------------------------------- + +class TestExtractNumeric: + def test_finds_int_at_dotted_path(self) -> None: + doc = {"data": {"subject": {"weight_grams": 250}}} + assert _extract_numeric(doc, "data.subject.weight_grams") == 250.0 + + def test_finds_float(self) -> None: + doc = {"data": {"x": 3.14}} + assert _extract_numeric(doc, "data.x") == 3.14 + + def test_coerces_string_numerics(self) -> None: + doc = {"data": {"x": "42.5"}} + assert _extract_numeric(doc, "data.x") == 42.5 + + def test_returns_none_for_non_finite(self) -> None: + doc1 = {"data": {"x": float("inf")}} + doc2 = {"data": {"x": float("nan")}} + assert _extract_numeric(doc1, "data.x") is None + assert _extract_numeric(doc2, "data.x") is None + + def test_returns_none_for_missing_path(self) -> None: + doc = {"data": {"y": 1}} + assert _extract_numeric(doc, "data.x") is None + + def test_returns_none_for_unparseable_string(self) -> None: + doc = {"data": {"x": "hello"}} + assert _extract_numeric(doc, "data.x") is None + + def test_rejects_booleans(self) -> None: + # Bools are int subclasses in Python but the TS helper rejected + # them; preserve that contract so we don't accidentally aggregate + # `True/False` as 1/0. + doc = {"data": {"x": True}} + assert _extract_numeric(doc, "data.x") is None + + +class TestExtractString: + def test_finds_string(self) -> None: + doc = {"data": {"subject": {"strain": "C57BL/6"}}} + assert _extract_string(doc, "data.subject.strain") == "C57BL/6" + + def test_returns_none_for_empty_string(self) -> None: + doc = {"data": {"x": ""}} + assert _extract_string(doc, "data.x") is None + + def test_returns_none_for_missing_path(self) -> None: + doc = {"data": {"y": "z"}} + assert _extract_string(doc, "data.x") is None + + def test_coerces_booleans(self) -> None: + doc = {"data": {"x": True}} + assert _extract_string(doc, "data.x") == "true" + + def test_coerces_numbers(self) -> None: + doc = {"data": {"x": 42}} + assert _extract_string(doc, "data.x") == "42" + + +class TestSummaryStats: + def test_count_mean_median_basic(self) -> None: + stats = _summary_stats([1.0, 2.0, 3.0, 4.0, 5.0]) + assert stats["count"] == 5 + assert stats["mean"] == 3.0 + assert stats["median"] == 3.0 + assert stats["min"] == 1.0 + assert stats["max"] == 5.0 + # Sample std of [1,2,3,4,5] = sqrt(2.5) ≈ 1.5811 + assert math.isclose(stats["std"], math.sqrt(2.5), rel_tol=1e-9) + + def test_median_for_even_length(self) -> None: + stats = _summary_stats([1.0, 2.0, 3.0, 4.0]) + assert stats["median"] == 2.5 + + def test_singleton_has_zero_std(self) -> None: + # n=1 → sample std undefined; TS returns 0; mirror that. + stats = _summary_stats([42.0]) + assert stats["count"] == 1 + assert stats["mean"] == 42.0 + assert stats["median"] == 42.0 + assert stats["std"] == 0.0 + + +# --------------------------------------------------------------------------- +# Service end-to-end (with stubbed cloud) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_aggregates_a_single_group_when_groupby_unset() -> None: + cloud = _StubCloud({ + "documents": [ + _make_subject("d1", "ds-A", 200.0), + _make_subject("d2", "ds-A", 250.0), + _make_subject("d3", "ds-A", 300.0), + ], + "totalItems": 3, + }) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + + assert result["total_items"] == 3 + assert result["numeric_matches"] == 3 + assert result["scanned_docs"] == 3 + assert result["truncated"] is False + assert result["valueField"] == "data.subject.weight_grams" + assert len(result["groups"]) == 1 + g = result["groups"][0] + assert g["group"] == "all" + assert g["count"] == 3 + assert g["mean"] == 250.0 + assert g["median"] == 250.0 + assert g["min"] == 200.0 + assert g["max"] == 300.0 + assert g["sample_doc"] == { + "id": "d1", "dataset_id": "ds-A", "class": "subject", + } + assert result["datasets_contributing"] == ["ds-A"] + + +@pytest.mark.asyncio +async def test_groups_by_dotted_path() -> None: + cloud = _StubCloud({ + "documents": [ + _make_subject("d1", "ds-A", 200.0, strain="C57"), + _make_subject("d2", "ds-A", 220.0, strain="C57"), + _make_subject("d3", "ds-A", 250.0, strain="BALB"), + _make_subject("d4", "ds-A", 260.0, strain="BALB"), + ], + "totalItems": 4, + }) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + groupBy="data.subject.strain", + ) + + result = await svc.aggregate(req, access_token=None) + + assert result["numeric_matches"] == 4 + groups = {g["group"]: g for g in result["groups"]} + assert set(groups.keys()) == {"C57", "BALB"} + assert groups["C57"]["count"] == 2 + assert groups["C57"]["mean"] == 210.0 + assert groups["BALB"]["mean"] == 255.0 + # Each group surfaces the FIRST contributing doc as its sample. + assert groups["C57"]["sample_doc"]["id"] == "d1" + assert groups["BALB"]["sample_doc"]["id"] == "d3" + + +@pytest.mark.asyncio +async def test_skips_docs_with_value_but_no_group_label() -> None: + """The TS handler was fixed to NOT inflate ``numeric_matches`` when a + doc has a finite numeric but the groupBy path is missing; otherwise + "across 215 subjects" would claim more subjects than actually got + bucketed.""" + cloud = _StubCloud({ + "documents": [ + _make_subject("d1", "ds-A", 200.0, strain="C57"), + _make_subject("d2", "ds-A", 220.0, strain=None), # no group + _make_subject("d3", "ds-A", 250.0, strain="BALB"), + ], + "totalItems": 3, + }) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + groupBy="data.subject.strain", + ) + + result = await svc.aggregate(req, access_token=None) + # 3 total matches, 1 dropped → 2 contributed. + assert result["total_items"] == 3 + assert result["numeric_matches"] == 2 + + +@pytest.mark.asyncio +async def test_skips_docs_with_no_numeric_value() -> None: + cloud = _StubCloud({ + "documents": [ + _make_subject("d1", "ds-A", 200.0), + _make_subject("d2", "ds-A", None), # missing + _make_subject("d3", "ds-A", "not a number"), # unparseable + _make_subject("d4", "ds-A", float("nan")), # NaN + ], + "totalItems": 4, + }) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + assert result["total_items"] == 4 + assert result["numeric_matches"] == 1 + + +@pytest.mark.asyncio +async def test_truncation_when_total_exceeds_max_docs() -> None: + docs = [_make_subject(f"d{i}", "ds-A", 100.0 + i) for i in range(10)] + cloud = _StubCloud({ + "documents": docs, + "totalItems": 10, + }) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + maxDocs=5, + ) + + result = await svc.aggregate(req, access_token=None) + assert result["scanned_docs"] == 5 + assert result["truncated"] is True + assert result["total_items"] == 10 + # Only first 5 contributed. + assert result["groups"][0]["count"] == 5 + + +@pytest.mark.asyncio +async def test_datasets_contributing_capped_at_reference_cap() -> None: + docs = [ + _make_subject(f"d{i}", f"ds-{i}", 100.0 + i) + for i in range(REFERENCE_CAP + 5) + ] + cloud = _StubCloud({"documents": docs, "totalItems": len(docs)}) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + # Cap kicks in — exactly REFERENCE_CAP distinct datasets surfaced. + assert len(result["datasets_contributing"]) == REFERENCE_CAP + + +@pytest.mark.asyncio +async def test_handles_empty_cloud_response_gracefully() -> None: + cloud = _StubCloud({"documents": [], "totalItems": 0}) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + assert result["total_items"] == 0 + assert result["numeric_matches"] == 0 + assert result["groups"] == [] + assert result["datasets_contributing"] == [] + assert result["truncated"] is False + + +@pytest.mark.asyncio +async def test_forwards_searchstructure_and_scope_to_cloud() -> None: + cloud = _StubCloud({"documents": [], "totalItems": 0}) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope="abc1234567890123456789ab,def1234567890123456789ab", + searchstructure=[ + {"operation": "isa", "param1": "subject"}, + {"operation": "contains_string", "field": "subject.strain", "param1": "C57"}, + ], + valueField="data.subject.weight_grams", + ) + + await svc.aggregate(req, access_token=None) + assert len(cloud.calls) == 1 + call = cloud.calls[0] + assert call["scope"] == "abc1234567890123456789ab,def1234567890123456789ab" + assert len(call["searchstructure"]) == 2 + assert call["searchstructure"][0] == {"operation": "isa", "param1": "subject"} + assert call["access_token"] is None + + +# --------------------------------------------------------------------------- +# Request validation +# --------------------------------------------------------------------------- + +class TestAggregateDocumentsRequestValidation: + def test_rejects_invalid_scope(self) -> None: + with pytest.raises(ValueError): + AggregateDocumentsRequest( + scope="not-a-keyword-or-csv-of-hex", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.x", + ) + + def test_rejects_empty_searchstructure(self) -> None: + with pytest.raises(ValueError): + AggregateDocumentsRequest( + scope="public", + searchstructure=[], + valueField="data.x", + ) + + def test_rejects_oversize_searchstructure(self) -> None: + with pytest.raises(ValueError): + AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "x"}] * 21, + valueField="data.x", + ) + + def test_rejects_max_docs_above_ceiling(self) -> None: + with pytest.raises(ValueError): + AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.x", + maxDocs=50_001, + ) + + def test_accepts_public_scope(self) -> None: + req = AggregateDocumentsRequest( + scope="public", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.x", + ) + assert req.scope == "public" + + def test_accepts_csv_dataset_id_scope(self) -> None: + req = AggregateDocumentsRequest( + scope="abc1234567890123456789ab", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.x", + ) + assert req.scope == "abc1234567890123456789ab" + + +# --------------------------------------------------------------------------- +# F-7 (2026-05-19) — bulk_fetch hydration path +# +# When `/ndiquery` returns slim `{id, datasetId}` pairs (or bare id strings +# under a single-dataset scope) the legacy code path silently dropped every +# doc because `_extract_numeric` needs `data.*`. F-7 re-hydrates via +# chunked `bulk_fetch` so the aggregation result is numerically identical +# whether the cloud sent slim refs or full bodies. +# --------------------------------------------------------------------------- + +DS_A = "abc1234567890123456789aa" +DS_B = "abc1234567890123456789bb" + + +def _slim(doc: dict[str, Any]) -> dict[str, Any]: + """Strip a full body down to {id, datasetId} — what the cloud emits + when it doesn't return inline bodies.""" + return {"id": doc["id"], "datasetId": doc["datasetId"]} + + +@pytest.mark.asyncio +async def test_hydrates_slim_refs_via_bulk_fetch() -> None: + """Pin: when ndiquery returns slim {id, datasetId} pairs, the service + must bulk_fetch the full bodies and produce the SAME numeric stats it + would have produced if the cloud had returned full bodies inline.""" + full_docs = [ + _make_subject("d1", DS_A, 200.0), + _make_subject("d2", DS_A, 250.0), + _make_subject("d3", DS_A, 300.0), + ] + slim_docs = [_slim(d) for d in full_docs] + + cloud = _StubCloud( + {"documents": slim_docs, "totalItems": 3}, + bulk_fetch_index={(d["datasetId"], d["id"]): d for d in full_docs}, + ) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=DS_A, + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + + # Hydration MUST happen — bulk_fetch was called. + assert len(cloud.bulk_calls) == 1 + assert cloud.bulk_calls[0]["dataset_id"] == DS_A + assert sorted(cloud.bulk_calls[0]["document_ids"]) == ["d1", "d2", "d3"] + # Numeric output identical to the full-body happy path. + assert result["numeric_matches"] == 3 + assert result["scanned_docs"] == 3 + g = result["groups"][0] + assert g["count"] == 3 + assert g["mean"] == 250.0 + assert g["median"] == 250.0 + assert g["min"] == 200.0 + assert g["max"] == 300.0 + + +@pytest.mark.asyncio +async def test_full_body_path_skips_bulk_fetch_no_op() -> None: + """Pin: when the cloud already returned full bodies, hydration is a + no-op — zero extra cloud calls. (Protects the happy-path latency.)""" + full_docs = [ + _make_subject("d1", DS_A, 200.0), + _make_subject("d2", DS_A, 250.0), + ] + cloud = _StubCloud({"documents": full_docs, "totalItems": 2}) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=DS_A, + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + + # NO bulk_fetch call — the docs came back full-bodied already. + assert cloud.bulk_calls == [] + assert result["numeric_matches"] == 2 + + +@pytest.mark.asyncio +async def test_per_doc_vs_bulk_numeric_equivalence() -> None: + """The core regression pin from the F-7 spec: running the SAME + fixture as (a) full-body inline and (b) slim-refs-then-bulk-fetch + must produce byte-equal {mean, median, std, min, max}. + + Reasoning: the per-doc `get_document` path and the bulk_fetch path + BOTH reach the cloud-side document body. The aggregator's only + consumer is the body's `data.*` — so the numeric reduction is + identical iff hydration preserves the body verbatim.""" + full_docs = [ + _make_subject("d1", DS_A, 200.0, strain="C57"), + _make_subject("d2", DS_A, 220.0, strain="C57"), + _make_subject("d3", DS_B, 250.0, strain="BALB"), + _make_subject("d4", DS_B, 260.0, strain="BALB"), + ] + + # Path A: full bodies inline (legacy path). + cloud_full = _StubCloud({"documents": full_docs, "totalItems": 4}) + svc_full = AggregateDocumentsService(cloud_full) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=f"{DS_A},{DS_B}", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + groupBy="data.subject.strain", + ) + result_full = await svc_full.aggregate(req, access_token=None) + + # Path B: slim refs + bulk_fetch hydration (F-7 path). + slim_docs = [_slim(d) for d in full_docs] + cloud_slim = _StubCloud( + {"documents": slim_docs, "totalItems": 4}, + bulk_fetch_index={(d["datasetId"], d["id"]): d for d in full_docs}, + ) + svc_slim = AggregateDocumentsService(cloud_slim) # type: ignore[arg-type] + result_slim = await svc_slim.aggregate(req, access_token=None) + + # Same {count, mean, median, std, min, max} per group; same ordering. + def _strip(r: dict[str, Any]) -> list[dict[str, Any]]: + return [ + { + "group": g["group"], + "count": g["count"], + "mean": g["mean"], + "median": g["median"], + "std": g["std"], + "min": g["min"], + "max": g["max"], + } + for g in r["groups"] + ] + + assert _strip(result_full) == _strip(result_slim) + assert result_full["numeric_matches"] == result_slim["numeric_matches"] + assert ( + result_full["datasets_contributing"] + == result_slim["datasets_contributing"] + ) + + +@pytest.mark.asyncio +async def test_hydration_chunks_at_bulk_fetch_max() -> None: + """Pin: hydration chunks slim refs at BULK_FETCH_MAX (=500) per call. + Stage 600 refs in a single dataset → expect 2 bulk_fetch calls (500 + + 100).""" + from backend.clients.ndi_cloud import BULK_FETCH_MAX as CHUNK + + full_docs = [ + _make_subject(f"d{i}", DS_A, 100.0 + i) for i in range(CHUNK + 100) + ] + slim_docs = [_slim(d) for d in full_docs] + + cloud = _StubCloud( + {"documents": slim_docs, "totalItems": len(slim_docs)}, + bulk_fetch_index={(d["datasetId"], d["id"]): d for d in full_docs}, + ) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=DS_A, + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + maxDocs=10_000, + ) + + result = await svc.aggregate(req, access_token=None) + + assert len(cloud.bulk_calls) == 2 + sizes = sorted(len(c["document_ids"]) for c in cloud.bulk_calls) + assert sizes == [100, CHUNK] + # All docs hydrated → numeric_matches matches scan window. + assert result["numeric_matches"] == CHUNK + 100 + + +@pytest.mark.asyncio +async def test_hydration_chunks_per_dataset() -> None: + """Pin: slim refs spanning multiple datasets fan out into per-dataset + bulk_fetch batches (since bulk_fetch is per-dataset).""" + full_docs = [ + _make_subject("d1", DS_A, 100.0), + _make_subject("d2", DS_A, 110.0), + _make_subject("d3", DS_B, 200.0), + _make_subject("d4", DS_B, 210.0), + ] + slim_docs = [_slim(d) for d in full_docs] + + cloud = _StubCloud( + {"documents": slim_docs, "totalItems": 4}, + bulk_fetch_index={(d["datasetId"], d["id"]): d for d in full_docs}, + ) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=f"{DS_A},{DS_B}", + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + + # Two bulk_fetch calls — one per dataset. + assert len(cloud.bulk_calls) == 2 + datasets_called = sorted(c["dataset_id"] for c in cloud.bulk_calls) + assert datasets_called == sorted([DS_A, DS_B]) + assert result["numeric_matches"] == 4 + + +@pytest.mark.asyncio +async def test_hydration_handles_bare_id_strings_under_single_dataset_scope() -> None: + """Pin: bare id strings (rare; cloud returns ``ids: [...]``) under a + single-dataset scope attribute to that dataset and hydrate cleanly.""" + full_docs = [ + _make_subject("d1", DS_A, 200.0), + _make_subject("d2", DS_A, 250.0), + ] + cloud = _StubCloud( + {"ids": ["d1", "d2"], "documents": ["d1", "d2"], "totalItems": 2}, + bulk_fetch_index={(d["datasetId"], d["id"]): d for d in full_docs}, + ) + svc = AggregateDocumentsService(cloud) # type: ignore[arg-type] + req = AggregateDocumentsRequest( + scope=DS_A, + searchstructure=[{"operation": "isa", "param1": "subject"}], + valueField="data.subject.weight_grams", + ) + + result = await svc.aggregate(req, access_token=None) + assert len(cloud.bulk_calls) == 1 + assert result["numeric_matches"] == 2 diff --git a/backend/tests/unit/test_binary_default_image_pick.py b/backend/tests/unit/test_binary_default_image_pick.py new file mode 100644 index 0000000..6bc769d --- /dev/null +++ b/backend/tests/unit/test_binary_default_image_pick.py @@ -0,0 +1,109 @@ +"""Default image-ref picker — sibling of ``_pick_default_signal_ref`` +for the B5 sweep (2026-05-18). The signal-pick fix that landed in +e03d470 surfaces only when ``BinaryService.get_timeseries`` is the +caller; this picker covers the Document Explorer's image viewer +(`BinaryService.get_image`) and the chat's `fetch_image` +endpoint (`ImageService.fetch_image`) — both of which previously +grabbed ``refs[0]`` blindly. + +Pre-fix repro pattern: a multi-file image doc carries a primary +image (e.g. ``frame_001.tif``) plus one or more sidecars +(``imageStack_parameters.json``, ``calibration.json``, +``channel_list.bin``). Cloud-side ordering puts the sidecar first; +Pillow's ``Image.open`` raises ``UnidentifiedImageError`` and the +request returns ``errorKind="unsupported"`` even though the doc +DID have a decodable image. +""" +from __future__ import annotations + +from backend.services.binary_service import ( + FileRef, + _pick_default_image_ref, +) + + +def _r(name: str, url: str = "https://example/x") -> FileRef: + return FileRef(url=url, content_type=None, filename=name) + + +class TestPickDefaultImageRef: + def test_prefers_tif_over_imagestack_parameters_json(self) -> None: + """The imageStack-multi-file case: parameters JSON first, then + the actual TIFF frame(s).""" + refs = [ + _r("imageStack_parameters.json"), + _r("frame_001.tif"), + ] + # `imageStack_parameters.json` is not literally in the blocklist + # but it is not an image extension either. We pick the .tif + # because it matches a known decodable extension and is not + # itself metadata-blocked. The JSON is a non-metadata-blocked + # fallback (step 2) — but step 1 wins on the matching extension. + pick = _pick_default_image_ref(refs) + assert pick.filename == "frame_001.tif" + + def test_prefers_tiff_over_channel_list_bin(self) -> None: + """Shared metadata blocklist still applies for image picks.""" + refs = [ + _r("channel_list.bin"), + _r("microscopy.tiff"), + ] + assert _pick_default_image_ref(refs).filename == "microscopy.tiff" + + def test_image_extension_variants(self) -> None: + """All Pillow-supported extensions in ``_DECODABLE_IMAGE_EXTENSIONS`` + are recognized.""" + for name in ["pic.tif", "pic.tiff", "pic.png", "pic.jpg", + "pic.jpeg", "pic.gif"]: + refs = [_r("meta.json"), _r(name)] + assert _pick_default_image_ref(refs).filename == name, ( + f"failed for {name}" + ) + + def test_case_insensitive_extension_match(self) -> None: + """Extension match is case-insensitive (matches the signal pick).""" + refs = [_r("meta.json"), _r("frame.PNG")] + assert _pick_default_image_ref(refs).filename == "frame.PNG" + + def test_single_file_unchanged(self) -> None: + """Single-file docs hit the legacy fallback unchanged — existing + non-multi-file image docs continue to work.""" + refs = [_r("only_image.tif")] + assert _pick_default_image_ref(refs).filename == "only_image.tif" + + def test_single_metadata_file_falls_back_to_it(self) -> None: + """If the ONLY file is metadata, return it — let Pillow surface + the real ``UnidentifiedImageError`` rather than "no match".""" + refs = [_r("channel_list.bin")] + assert _pick_default_image_ref(refs).filename == "channel_list.bin" + + def test_extension_match_beats_non_metadata_fallback_order(self) -> None: + """Step 1 (extension match) wins even when step 2 (non-metadata + fallback) would otherwise return an earlier ref.""" + refs = [ + _r("calibration.json"), # non-metadata, no image ext + _r("frame_001.tif"), # image ext — picks this + ] + # Without the extension filter we'd pick the JSON (first + # non-metadata). The decoder would then 500. + assert _pick_default_image_ref(refs).filename == "frame_001.tif" + + def test_suffix_variants_for_tif(self) -> None: + """Numbered TIFF series (``frame.tif_1``) still recognized + — same suffix-with-non-alphanumeric-tail rule as the signal pick.""" + refs = [_r("meta.json"), _r("frame.tif_1")] + assert _pick_default_image_ref(refs).filename == "frame.tif_1" + + def test_signal_extensions_not_picked_for_image(self) -> None: + """``.nbf`` etc. are NOT in the image extension list; they fall + through to step 2 (non-metadata fallback), letting the existing + Pillow soft-error path surface.""" + refs = [_r("data.nbf_1"), _r("frame.png")] + # `data.nbf_1` is not an image extension; `frame.png` is. Pick + # the image one (step 1 wins). + assert _pick_default_image_ref(refs).filename == "frame.png" + + def test_all_metadata_returns_first(self) -> None: + """If every ref is metadata, fall back to refs[0] (legacy).""" + refs = [_r("channel_list.bin"), _r("meta.json")] + assert _pick_default_image_ref(refs).filename == "channel_list.bin" diff --git a/backend/tests/unit/test_binary_default_signal_pick.py b/backend/tests/unit/test_binary_default_signal_pick.py new file mode 100644 index 0000000..41ba768 --- /dev/null +++ b/backend/tests/unit/test_binary_default_signal_pick.py @@ -0,0 +1,89 @@ +"""Default signal-ref picker — picks the timeseries file from a multi- +file document instead of the metadata file. + +Live-found bug 2026-05-19: Francesconi's +``daqreader_mfdaq_epochdata_ingested`` docs (e.g. +``68d6e54703a03f5cfdac8ef7``) carry ``file_info[0] == channel_list.bin`` +followed by N ``.nbf_#`` signal files. Pre-fix, the signal endpoint +returned "Could not decode channel_list.bin binary file. Format may not +be supported." because ``refs[0]`` blindly picked the metadata file. +""" +from __future__ import annotations + +from backend.services.binary_service import ( + FileRef, + _pick_default_signal_ref, +) + + +def _r(name: str, url: str = "https://example/x") -> FileRef: + return FileRef(url=url, content_type=None, filename=name) + + +class TestPickDefaultSignalRef: + def test_prefers_nbf_over_channel_list_bin(self) -> None: + """The Francesconi case: channel_list.bin first, then .nbf_# signal files.""" + refs = [ + _r("channel_list.bin"), + _r("ai_group1_seg.nbf_1"), + _r("ai_group2_seg.nbf_1"), + ] + pick = _pick_default_signal_ref(refs) + assert pick.filename == "ai_group1_seg.nbf_1" + + def test_nbf_suffix_variants_match(self) -> None: + """`.nbf_#`, `.nbf_1`, `.nbf` (exact) all considered decodable.""" + for name in ["foo.nbf_#", "foo.nbf_1", "foo.nbf_42", "foo.nbf"]: + refs = [_r("channel_list.bin"), _r(name)] + pick = _pick_default_signal_ref(refs) + assert pick.filename == name, f"failed for {name}" + + def test_vhsb_recognized(self) -> None: + """vhlab `.vhsb` files (Haley etc.) are decodable signal data.""" + refs = [_r("channel_list.bin"), _r("trace.vhsb")] + assert _pick_default_signal_ref(refs).filename == "trace.vhsb" + + def test_single_file_unchanged(self) -> None: + """Single-file docs hit the legacy fallback — refs[0] returns + unchanged so existing element_epoch / vhsb docs work as before.""" + refs = [_r("only_file.vhsb")] + assert _pick_default_signal_ref(refs).filename == "only_file.vhsb" + + def test_single_metadata_file_falls_back_to_it(self) -> None: + """If the ONLY file is metadata, return it anyway (let the codec + produce the real "unknown format" error rather than "no match").""" + refs = [_r("channel_list.bin")] + assert _pick_default_signal_ref(refs).filename == "channel_list.bin" + + def test_unknown_extensions_skip_metadata(self) -> None: + """When no ext matches but multiple non-metadata files exist, + return the first non-metadata one — avoid returning known- + metadata even if extension doesn't match decodable set.""" + refs = [ + _r("channel_list.bin"), + _r("trace.weird_unknown_extension"), + ] + pick = _pick_default_signal_ref(refs) + assert pick.filename == "trace.weird_unknown_extension" + + def test_case_insensitive_metadata_block(self) -> None: + """Metadata filename match is case-insensitive.""" + refs = [_r("Channel_List.bin"), _r("data.nbf_1")] + assert _pick_default_signal_ref(refs).filename == "data.nbf_1" + + def test_case_insensitive_extension_match(self) -> None: + """Extension match is case-insensitive.""" + refs = [_r("channel_list.bin"), _r("trace.NBF_1")] + assert _pick_default_signal_ref(refs).filename == "trace.NBF_1" + + def test_all_metadata_returns_first(self) -> None: + """If every ref is metadata, return the first (legacy fallback).""" + refs = [_r("channel_list.bin"), _r("meta.json")] + # First metadata is returned via step-3 fallback (refs[0]). + assert _pick_default_signal_ref(refs).filename == "channel_list.bin" + + def test_metadata_blocklist_handles_channels_json(self) -> None: + """`channels.json` / `meta.json` / `metadata.json` also skipped.""" + for meta in ["channels.json", "meta.json", "metadata.json"]: + refs = [_r(meta), _r("data.nbf_1")] + assert _pick_default_signal_ref(refs).filename == "data.nbf_1" diff --git a/backend/tests/unit/test_cache_isolation.py b/backend/tests/unit/test_cache_isolation.py index b1b112d..08f7d4a 100644 --- a/backend/tests/unit/test_cache_isolation.py +++ b/backend/tests/unit/test_cache_isolation.py @@ -80,7 +80,7 @@ def test_authed_key_contains_user_hash(self) -> None: s = _make_session("alice@example.com") scope = user_scope_for(s) key = RedisTableCache.table_key("DS1", "subject", user_scope=scope) - assert key == f"table:v4:DS1:subject:{scope}" + assert key == f"table:v7:DS1:subject:{scope}" assert "u:" in key def test_public_key_uses_public_literal(self) -> None: @@ -183,8 +183,13 @@ async def test_schema_v4_does_not_read_v3_entries(fake_redis) -> None: # type: v4_key = RedisTableCache.table_key( "DS1", "subject", user_scope=user_scope_for(alice), ) - assert v4_key != v3_key, "v4 key must differ from v3 key" - assert v4_key.startswith("table:v4:"), "v4 key must carry v4 version prefix" + assert v4_key != v3_key, "current key must differ from v3 key" + # 2026-05-19 — bumped to v7 (F-1b follow-up: subject enrichment + # now fetches treatment_drug + treatment_transfer); the variable + # name stays `v4_key` for narrative continuity with the original + # test ("read across schema version boundaries"). The prefix + # check moves with the bump. + assert v4_key.startswith("table:v7:"), "current key must carry v7 version prefix" calls = 0 diff --git a/backend/tests/unit/test_cookie_attrs.py b/backend/tests/unit/test_cookie_attrs.py index 65e6e34..29f7095 100644 --- a/backend/tests/unit/test_cookie_attrs.py +++ b/backend/tests/unit/test_cookie_attrs.py @@ -2,14 +2,27 @@ The helper centralizes the per-environment Set-Cookie/Delete-Cookie attribute derivation that the backend uses for the session and CSRF -cookies. Production cookies must carry ``Domain=.ndi-cloud.com`` so the -cross-repo Vercel frontend (Phase 4 of the cross-repo unification) can -read them; dev cookies must not carry ``Secure`` because local dev -serves over plain HTTP; everything else (staging) is secure but -host-only. +cookies. Two layers: + + - **Environment** decides whether ``Secure`` is set and whether + ``Domain=.ndi-cloud.com`` is even considered: + - production: Domain is conditional (see next layer) + - development: no Secure, no Domain (plain-HTTP localhost) + - staging / other: Secure, no Domain (host-only) + + - **Per-request Origin** decides whether ``Domain=.ndi-cloud.com`` + is actually attached in production. The apex Vercel deployment + needs it so the Railway backend's cookies are readable on the + apex host. Vercel preview deployments at ``*.vercel.app`` need + it OMITTED — otherwise the browser silently rejects the Set- + Cookie because the response origin doesn't match the cookie's + claimed Domain (this was the 2026-05-14 preview-login CSRF + failure). """ from typing import Literal +from fastapi import Request + from backend.auth.cookie_attrs import cookie_attrs from backend.config import Settings @@ -22,20 +35,107 @@ def _settings(env: EnvName) -> Settings: return Settings(ENVIRONMENT=env) -def test_production_returns_secure_with_apex_domain() -> None: - assert cookie_attrs(_settings("production")) == { - "secure": True, - "domain": ".ndi-cloud.com", - } +def _request(origin: str | None = None, referer: str | None = None) -> Request: + """Build a minimal Starlette Request for cookie_attrs to read. + + Only the headers matter for this helper — scope.path/method are + unused. Using a raw scope avoids pulling in the TestClient just + to get a Request instance. + """ + headers: list[tuple[bytes, bytes]] = [] + if origin is not None: + headers.append((b"origin", origin.encode())) + if referer is not None: + headers.append((b"referer", referer.encode())) + return Request( + scope={ + "type": "http", + "method": "POST", + "path": "/api/auth/csrf", + "headers": headers, + "query_string": b"", + } + ) + + +# ─── Production xapex origin → Domain attached ───────────────────────── + +def test_production_with_apex_origin_attaches_domain() -> None: + """The original cross-repo unification (Phase 4) contract.""" + attrs = cookie_attrs( + _settings("production"), + request=_request(origin="https://ndi-cloud.com"), + ) + assert attrs == {"secure": True, "domain": ".ndi-cloud.com"} + + +def test_production_with_subdomain_origin_attaches_domain() -> None: + """`app.ndi-cloud.com` (legacy) and any future `*.ndi-cloud.com`.""" + attrs = cookie_attrs( + _settings("production"), + request=_request(origin="https://app.ndi-cloud.com"), + ) + assert attrs == {"secure": True, "domain": ".ndi-cloud.com"} + + +def test_production_with_referer_only_attaches_domain() -> None: + """Same-origin GETs may omit Origin; Referer should still work.""" + attrs = cookie_attrs( + _settings("production"), + request=_request(referer="https://ndi-cloud.com/login"), + ) + assert attrs == {"secure": True, "domain": ".ndi-cloud.com"} + + +# ─── Production xpreview / unknown origin → host-only ────────────────── + +def test_production_with_vercel_preview_origin_is_host_only() -> None: + """The 2026-05-14 preview-login fix: no Domain attribute when + the request came from a Vercel preview hostname.""" + attrs = cookie_attrs( + _settings("production"), + request=_request(origin="https://ndi-cloud-app-web-git-feat-x.vercel.app"), + ) + assert attrs == {"secure": True} + assert "domain" not in attrs + + +def test_production_with_no_origin_or_referer_is_host_only() -> None: + """Fail-safe path: when we can't tell where the request came + from, drop Domain. Worse case is host-only cookies on apex (which + still work — they just don't cross-subdomain share).""" + attrs = cookie_attrs(_settings("production"), request=_request()) + assert attrs == {"secure": True} + assert "domain" not in attrs + + +def test_production_with_unrelated_origin_is_host_only() -> None: + """Origin is `https://attacker.example` — don't attach our apex + Domain to that response (browsers would reject anyway, but be + explicit).""" + attrs = cookie_attrs( + _settings("production"), + request=_request(origin="https://attacker.example"), + ) + assert attrs == {"secure": True} + assert "domain" not in attrs + +# ─── Non-production envs: Origin doesn't matter ───────────────────────── def test_development_returns_insecure_without_domain() -> None: - attrs = cookie_attrs(_settings("development")) + """Localhost over plain HTTP needs Secure=False; Domain is + irrelevant either way.""" + attrs = cookie_attrs(_settings("development"), request=_request()) assert attrs == {"secure": False} assert "domain" not in attrs def test_staging_returns_secure_without_domain() -> None: - attrs = cookie_attrs(_settings("staging")) + """Staging serves over HTTPS but host-only.""" + attrs = cookie_attrs( + _settings("staging"), + request=_request(origin="https://ndi-cloud.com"), + ) assert attrs == {"secure": True} assert "domain" not in attrs diff --git a/backend/tests/unit/test_csrf.py b/backend/tests/unit/test_csrf.py index 2b4a97d..f08bcea 100644 --- a/backend/tests/unit/test_csrf.py +++ b/backend/tests/unit/test_csrf.py @@ -1,4 +1,4 @@ -from backend.middleware.csrf import generate_token, sign, verify +from backend.middleware.csrf import EXEMPT_PATHS, generate_token, sign, verify def test_sign_verify_roundtrip() -> None: @@ -24,3 +24,19 @@ def test_distinct_tokens_are_unique() -> None: a = generate_token() b = generate_token() assert a != b + + +def test_ontology_batch_lookup_is_csrf_exempt() -> None: + """Anonymous /api/ontology/batch-lookup must work without a CSRF token. + + The endpoint is POST-shaped (body holds an array of CURIEs to avoid + URL repetition for batches up to 200 terms) but is functionally a + read-only lookup with no state mutation. Anonymous visitors hit it + on every dataset page render, before they've had a chance to call + /api/auth/csrf. Pre-fix, every anonymous summary-table view + surfaced a "1 warning · ontology lookup failed" banner because the + POST 403'd. Adding the path to EXEMPT_PATHS lets the middleware + pass anonymous requests through to the router. (Visual-UX audit + a395 P0 #3, 2026-05-14.) + """ + assert "/api/ontology/batch-lookup" in EXEMPT_PATHS diff --git a/backend/tests/unit/test_dataset_binding_service.py b/backend/tests/unit/test_dataset_binding_service.py new file mode 100644 index 0000000..f727430 --- /dev/null +++ b/backend/tests/unit/test_dataset_binding_service.py @@ -0,0 +1,453 @@ +"""Unit tests for DatasetBindingService. + +These tests do NOT require NDI-python to be installed. We patch the +service's internals so the cold-load path returns fake Dataset objects +without ever hitting the network or the SDK. The contract under test is +the cache/eviction/error-handling shell, not the SDK itself — the SDK is +already exercised by NDI-python's own test suite + the (separate) +integration tests in ``tests/integration/test_dataset_binding_live.py``. +""" +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from backend.services.dataset_binding_service import ( + MAX_CACHED_DATASETS, + DatasetBindingService, + _CacheEntry, +) + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +def _make_fake_element(name: str, etype: str, n_epochs: int) -> SimpleNamespace: + """Duck-typed stand-in for ndi.element.ndi_element. + + The service touches: ``.name``, ``.type``, ``.numepochs()``, + ``.epochtable()``. Provide all four so the service can pick either + path without a None-deref. + """ + et = [{"epoch_number": i + 1} for i in range(n_epochs)] + return SimpleNamespace( + name=name, + type=etype, + numepochs=lambda: n_epochs, + epochtable=lambda: (et, "fakehash"), + ) + + +def _make_fake_dataset( + *, + elements: list[SimpleNamespace] | None = None, + subject_docs: list[dict[str, Any]] | None = None, + reference: str = "fake_ref", +) -> SimpleNamespace: + """Duck-typed ndi.dataset.Dataset. + + Surface the service uses: + - ``._session.getelements()`` + - ``.database_search(query)`` for subject docs + - ``.reference`` + """ + elements = elements or [] + subject_docs = subject_docs or [] + + session = SimpleNamespace(getelements=lambda **_kw: elements) + + def db_search(_q: Any) -> list[Any]: + # Service's only call is `isa('subject')`. Return canned docs. + return subject_docs + + return SimpleNamespace( + _session=session, + database_search=db_search, + reference=reference, + ) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_ndi_python_caches(): + """Force ndi_python_service's availability cache flags to a known + state so test-local patches actually take effect. + """ + from backend.services import ndi_python_service + ndi_python_service._NDI_AVAILABLE = None + ndi_python_service._DATASET_BINDING_AVAILABLE = None + yield + ndi_python_service._NDI_AVAILABLE = None + ndi_python_service._DATASET_BINDING_AVAILABLE = None + + +@pytest.fixture +def svc(tmp_path) -> DatasetBindingService: + return DatasetBindingService(cache_dir=str(tmp_path / "ndi-cache")) + + +@pytest.fixture +def ndi_available(): + """Patch is_ndi_available -> True for the duration of the test. + + Also patches ``is_dataset_binding_available`` -> True so the + Sprint 1.5 binding probe (added 2026-05-14) doesn't short-circuit + the cold-load path before the patched ``_download_blocking`` fires. + + The service calls ``from . import ndi_python_service`` lazily inside + _cold_load(), so we must patch the attribute on the source module + (``backend.services.ndi_python_service.is_ndi_available``) rather + than on the binding module. + """ + with ( + patch( + "backend.services.ndi_python_service.is_ndi_available", + return_value=True, + ) as p, + patch( + "backend.services.ndi_python_service.is_dataset_binding_available", + return_value=True, + ), + ): + yield p + + +@pytest.fixture +def ndi_unavailable(): + with patch( + "backend.services.ndi_python_service.is_ndi_available", + return_value=False, + ) as p: + yield p + + +# --------------------------------------------------------------------------- +# get_dataset — cache miss/hit/eviction/coalescing/failure +# --------------------------------------------------------------------------- + + +class TestGetDataset: + async def test_returns_none_on_empty_id(self, svc: DatasetBindingService): + assert await svc.get_dataset("") is None + + @pytest.mark.usefixtures("ndi_available") + async def test_cold_miss_then_warm_hit( + self, svc: DatasetBindingService + ): + """First call downloads + caches. Second call hits the cache and + returns the SAME object without invoking downloadDataset again. + """ + fake = _make_fake_dataset() + call_count = 0 + + def fake_download(dataset_id: str) -> Any: + nonlocal call_count + call_count += 1 + return fake + + with patch.object(svc, "_download_blocking", side_effect=fake_download): + first = await svc.get_dataset("DS1") + second = await svc.get_dataset("DS1") + + assert first is fake + assert second is fake + # Only ONE download — the second call must hit the warm cache. + assert call_count == 1 + + @pytest.mark.usefixtures("ndi_available") + async def test_lru_eviction_at_max( + self, svc: DatasetBindingService + ): + """Inserting MAX_CACHED_DATASETS + 1 distinct ids evicts the + oldest. Verifies the LRU bound matches the documented constant. + """ + fakes = { + f"DS{i}": _make_fake_dataset(reference=f"ref{i}") + for i in range(MAX_CACHED_DATASETS + 1) + } + + def fake_download(dataset_id: str) -> Any: + return fakes[dataset_id] + + with patch.object(svc, "_download_blocking", side_effect=fake_download): + for i in range(MAX_CACHED_DATASETS + 1): + await svc.get_dataset(f"DS{i}") + + # Cache size is exactly MAX_CACHED_DATASETS. + assert len(svc._cache) == MAX_CACHED_DATASETS + # Oldest (DS0) was evicted; the newest are still present. + assert "DS0" not in svc._cache + assert f"DS{MAX_CACHED_DATASETS}" in svc._cache + + @pytest.mark.usefixtures("ndi_available") + async def test_concurrent_calls_dedupe( + self, svc: DatasetBindingService + ): + """Two simultaneous get_dataset('DS1') calls share ONE download. + + Pins the per-dataset lock contract: while one task is in the + cold path, others wait, then return the SAME cached object + without a second download. + """ + fake = _make_fake_dataset() + call_count = 0 + + def slow_download(_dataset_id: str) -> Any: + nonlocal call_count + call_count += 1 + return fake + + async def fire(_idx: int) -> Any: + return await svc.get_dataset("DS1") + + with patch.object(svc, "_download_blocking", side_effect=slow_download): + results = await asyncio.gather(fire(0), fire(1), fire(2)) + + # All three calls returned the same object. + assert results[0] is fake + assert results[1] is fake + assert results[2] is fake + # And there was exactly ONE download. + assert call_count == 1 + + @pytest.mark.usefixtures("ndi_available") + async def test_failure_returns_none_not_raise( + self, svc: DatasetBindingService + ): + """When downloadDataset raises, get_dataset MUST return None + rather than propagate — the chat falls back to ndi_query. + """ + def boom(_dataset_id: str) -> Any: + raise RuntimeError("simulated cloud-node 500") + + with patch.object(svc, "_download_blocking", side_effect=boom): + result = await svc.get_dataset("DS-broken") + + assert result is None + # Nothing cached on failure — a retry should attempt the cold + # path again. + assert "DS-broken" not in svc._cache + + @pytest.mark.usefixtures("ndi_unavailable") + async def test_returns_none_when_ndi_unavailable( + self, svc: DatasetBindingService + ): + """is_ndi_available=False short-circuits before + downloadDataset is reached. + """ + download = MagicMock() + with patch.object(svc, "_download_blocking", download): + result = await svc.get_dataset("DS1") + + assert result is None + download.assert_not_called() + + +# --------------------------------------------------------------------------- +# overview — counts + cache_hit + cache_age semantics +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("ndi_available") +class TestOverview: + async def test_happy_path_counts_match_fakes( + self, svc: DatasetBindingService + ): + elements = [ + _make_fake_element("e0", "n-trode", n_epochs=3), + _make_fake_element("e1", "stimulator", n_epochs=2), + ] + subjects = [{"_id": "subj1"}, {"_id": "subj2"}, {"_id": "subj3"}] + fake = _make_fake_dataset( + elements=elements, subject_docs=subjects, reference="DS-ref" + ) + + with patch.object(svc, "_download_blocking", return_value=fake): + out = await svc.overview("DS1") + + assert out is not None + assert out["element_count"] == 2 + # Subject count is best-effort — passes when ndi.query is + # importable AND returns the canned subjects list. On dev + # machines with an old/missing ndi.query, the subject_count + # path silently falls back to 0 (documented partial-failure + # behavior). Assert "either correct OR zero" so this test is + # resilient to both environments. + assert out["subject_count"] in (0, 3) + # 3 + 2 epochs. + assert out["epoch_count"] == 5 + assert out["elements"] == [ + {"name": "e0", "type": "n-trode"}, + {"name": "e1", "type": "stimulator"}, + ] + assert out["elements_truncated"] is False + assert out["reference"] == "DS-ref" + # First call is a cold one → cache_hit must be False. + assert out["cache_hit"] is False + # cache_age is small (just measured), but it's a float >= 0. + assert isinstance(out["cache_age_seconds"], float) + assert out["cache_age_seconds"] >= 0.0 + + async def test_warm_call_reports_cache_hit_true( + self, svc: DatasetBindingService + ): + fake = _make_fake_dataset() + with patch.object(svc, "_download_blocking", return_value=fake): + await svc.overview("DS1") # cold + second = await svc.overview("DS1") # warm + + assert second is not None + assert second["cache_hit"] is True + + async def test_overview_truncates_to_50_elements( + self, svc: DatasetBindingService + ): + elements = [ + _make_fake_element(f"e{i}", "n-trode", n_epochs=1) + for i in range(120) + ] + fake = _make_fake_dataset(elements=elements) + + with patch.object(svc, "_download_blocking", return_value=fake): + out = await svc.overview("DS1") + + assert out is not None + # element_count reports the TRUE total even when listing is truncated. + assert out["element_count"] == 120 + # Listing capped at 50. + assert len(out["elements"]) == 50 + assert out["elements_truncated"] is True + # Epoch count covers ALL elements, not just the truncated listing. + assert out["epoch_count"] == 120 + + async def test_overview_returns_none_on_binding_failure( + self, svc: DatasetBindingService + ): + with patch.object( + svc, "_download_blocking", side_effect=RuntimeError("boom") + ): + out = await svc.overview("DS-broken") + + assert out is None + + async def test_overview_tolerates_partial_traversal_failure( + self, svc: DatasetBindingService + ): + """When database_search raises (e.g. malformed query backend), + the overview should still surface element + epoch counts and + return subject_count=0 rather than blanking the whole payload. + """ + def bad_search(_q: Any) -> list[Any]: + raise RuntimeError("simulated DB error") + + fake = _make_fake_dataset( + elements=[_make_fake_element("e0", "n-trode", 2)] + ) + # Override database_search to raise. + fake.database_search = bad_search + + with patch.object(svc, "_download_blocking", return_value=fake): + out = await svc.overview("DS1") + + assert out is not None + assert out["element_count"] == 1 + assert out["epoch_count"] == 2 + # Subject search failed; subject_count fell back to 0. + assert out["subject_count"] == 0 + + +# --------------------------------------------------------------------------- +# last_failure surfacing — added 2026-05-14 so /ndi_overview's 503 +# envelope can carry a specific code/reason instead of the generic +# "not configured" string. The chat tool already routes 503 to a +# graceful fallback hint; the richer reason helps operators diagnose +# in the dashboard without tailing logs. +# --------------------------------------------------------------------------- + + +class TestLastFailure: + """Pins the (code, message) surface contract for each known + failure mode. Each test forces the corresponding cold-path + exit and asserts ``last_failure()`` returns the right code. + """ + + async def test_initial_state_is_none(self, svc: DatasetBindingService): + assert svc.last_failure() is None + + async def test_phase_a_unavailable_sets_code( + self, svc: DatasetBindingService, + ): + """is_ndi_available=False → phase_a_unavailable.""" + with patch( + "backend.services.ndi_python_service.is_ndi_available", + return_value=False, + ): + result = await svc.get_dataset("DS1") + assert result is None + failure = svc.last_failure() + assert failure is not None + assert failure[0] == "phase_a_unavailable" + assert "Phase A" in failure[1] or "not importable" in failure[1] + + async def test_binding_unavailable_sets_code( + self, svc: DatasetBindingService, + ): + """is_ndi_available=True but is_dataset_binding_available=False → + binding_unavailable. Most common cause: Phase A imports fine but + ``ndi.cloud.orchestration`` isn't installed in the deploy image. + """ + with ( + patch( + "backend.services.ndi_python_service.is_ndi_available", + return_value=True, + ), + patch( + "backend.services.ndi_python_service.is_dataset_binding_available", + return_value=False, + ), + ): + result = await svc.get_dataset("DS1") + assert result is None + failure = svc.last_failure() + assert failure is not None + assert failure[0] == "binding_unavailable" + + @pytest.mark.usefixtures("ndi_available") + async def test_cold_load_failed_sets_code_with_exception_class( + self, svc: DatasetBindingService, + ): + """downloadDataset raises → cold_load_failed, with the + exception's class name in the message so operators can grep + for it in dashboards (CloudAuthError vs HTTPError etc). + """ + def boom(_dataset_id: Any) -> Any: + raise RuntimeError("simulated cloud-auth 401") + + with patch.object(svc, "_download_blocking", side_effect=boom): + result = await svc.get_dataset("DS-broken") + assert result is None + failure = svc.last_failure() + assert failure is not None + assert failure[0] == "cold_load_failed" + assert "RuntimeError" in failure[1] + + +# --------------------------------------------------------------------------- +# Cache entry struct — basic invariants +# --------------------------------------------------------------------------- + + +class TestCacheEntry: + def test_loaded_at_equals_first_loaded_at_at_creation(self): + entry = _CacheEntry(dataset="x") + assert entry.loaded_at == entry.first_loaded_at + assert entry.dataset == "x" diff --git a/backend/tests/unit/test_dataset_summary_b6_session_filter.py b/backend/tests/unit/test_dataset_summary_b6_session_filter.py new file mode 100644 index 0000000..7ff8a8a --- /dev/null +++ b/backend/tests/unit/test_dataset_summary_b6_session_filter.py @@ -0,0 +1,741 @@ +"""B6 — parent/aggregate session filter tests. + +A "real" session is one with at least one other document (e.g. +`element_epoch`, `subject`) carrying `depends_on.value` pointing at +its `ndiId`. Parent / aggregate session docs (administrative +containers like Haley's `haley_2025` parent, ingested 10h after the +two leaf recordings) have zero downstream references. + +The filter: +- Skips datasets with `counts.sessions <= 1` (nothing to filter). +- Skips datasets with `counts.totalDocuments <= counts.sessions` + (no non-session docs that could be downstream — newly-published + catalogs or test fixtures). +- Skips datasets with > _MAX_SESSIONS_FILTER_WALK sessions (safety + cap; real multi-day series virtually always have downstream refs). +- Walks each session doc and fires an indexed `depends_on * [ndiId]` + ndiquery to check for downstream refs. +- Fail-open: any lookup error → session counted as real. +- Fail-open: every session looks unreferenced (real_count == 0) → + preserve the raw count (probably a flaky cloud, not a real "all + parents" dataset). +- Emits a `dataset_summary.session_filter` log line whenever the + filtered count differs from the raw count. +""" +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from typing import Any + +import pytest +import respx +from cryptography.fernet import Fernet +from httpx import Request, Response + +from backend.clients.ndi_cloud import NdiCloudClient +from backend.services.dataset_summary_service import DatasetSummaryService +from backend.services.ontology_cache import OntologyCache +from backend.services.ontology_service import OntologyService + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +DATASET_ID = "DSX" + + +def _dataset_raw(**overrides: Any) -> dict[str, Any]: + base = { + "_id": DATASET_ID, + "name": "B6 Test Dataset", + "abstract": "", + "license": "CC-BY-4.0", + "createdAt": "2025-09-01T00:00:00.000Z", + "updatedAt": "2026-01-01T00:00:00.000Z", + "contributors": [], + "associatedPublications": [], + } + base.update(overrides) + return base + + +def _counts_raw(**class_counts: int) -> dict[str, Any]: + total = sum(class_counts.values()) + return { + "datasetId": DATASET_ID, + "totalDocuments": total, + "classCounts": class_counts, + } + + +def _session_doc(ndi_id: str, name: str = "", reference: str | None = None) -> dict[str, Any]: + data: dict[str, Any] = { + "base": {"id": ndi_id, "name": name}, + "document_class": {"class_name": "session"}, + } + if reference is not None: + data["session"] = {"reference": reference} + return { + "id": f"mongo-{ndi_id}", + "ndiId": ndi_id, + "data": data, + } + + +@pytest.fixture +async def cloud() -> AsyncGenerator[NdiCloudClient, None]: + import os + os.environ.setdefault("SESSION_ENCRYPTION_KEY", Fernet.generate_key().decode()) + client = NdiCloudClient() + await client.start() + try: + yield client + finally: + await client.close() + + +@pytest.fixture +def ontology_service(tmp_path) -> OntologyService: # type: ignore[no-untyped-def] + """Offline OntologyService — no HTTP, no enrichment side-effects.""" + cache = OntologyCache(db_path=str(tmp_path / "ont.sqlite"), ttl_days=30) + svc = OntologyService(cache) + async def _fake_fetch(provider: str, term_id: str): # type: ignore[no-untyped-def] + from backend.services.ontology_cache import OntologyTerm as CacheTerm + return CacheTerm( + provider=provider, term_id=term_id, label=None, + definition=None, url=None, + ) + svc._fetch_from_provider = _fake_fetch # type: ignore[method-assign] + return svc + + +# --------------------------------------------------------------------------- +# The canonical Haley case — 3 sessions raw, 2 real (1 parent dropped) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_haley_3_sessions_filters_to_2( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """3 session docs reported; 2 have downstream element_epoch refs, 1 + is the parent/aggregate with no downstream. Filtered count = 2.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=3, element_epoch=10), + ) + + # ndiquery dispatcher — different responses for `isa session` + # (fetch session docs to walk) and per-session `depends_on *` + # (reverse-dep check). + leaf_a = "ndi-leaf-a" + leaf_b = "ndi-leaf-b" + parent = "ndi-parent" + session_ids = [leaf_a, leaf_b, parent] + + def _ndiquery_handler(request: Request) -> Response: + body = json.loads(request.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response( + 200, + json={ + "documents": [{"id": s} for s in session_ids], + "number_matches": len(session_ids), + }, + ) + if ss and ss[0].get("operation") == "depends_on": + target = ss[0].get("param2", [None])[0] + # Leaves are referenced (≥1 downstream); parent isn't. + n = 5 if target in (leaf_a, leaf_b) else 0 + return Response( + 200, + json={ + "documents": [{"id": f"dep-{target}-{i}"} for i in range(n)], + "number_matches": n, + "totalItems": n, + }, + ) + return Response(200, json={"documents": [], "number_matches": 0}) + + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + + # Bulk-fetch returns the synthetic session docs. + def _bulk_handler(req: Request) -> Response: + body = json.loads(req.content.decode("utf-8")) + ids = body.get("documentIds", []) + return Response( + 200, + json={ + "documents": [ + _session_doc(i, name=f"sess-{i}") for i in ids + ], + }, + ) + router.post(f"/datasets/{DATASET_ID}/documents/bulk-fetch").mock( + side_effect=_bulk_handler, + ) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + assert summary.counts.sessions == 2, ( + "Parent session should be filtered out; only the 2 leaves remain" + ) + assert summary.extractionWarnings == [] + + +# --------------------------------------------------------------------------- +# Skip conditions +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_skip_filter_when_only_one_session( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """Single-session dataset: skip the per-session walk (nothing to + filter). ndiquery should never be called for the session-filter + purpose.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=1, element_epoch=10), + ) + nq = router.post("/ndiquery").respond( + 200, json={"documents": [], "number_matches": 0}, + ) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).respond(200, json={"documents": []}) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + assert summary.counts.sessions == 1 + assert nq.call_count == 0, "No ndiquery for single-session dataset" + + +@pytest.mark.asyncio +async def test_skip_filter_when_dataset_is_all_sessions( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """Pure-session dataset (totalDocuments <= sessions): skip the filter + so we don't waste ndiquery calls only to fail-open.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=3), # only sessions, nothing else + ) + nq = router.post("/ndiquery").respond( + 200, json={"documents": [], "number_matches": 0}, + ) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).respond(200, json={"documents": []}) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + assert summary.counts.sessions == 3, ( + "All-sessions dataset preserves raw count" + ) + assert nq.call_count == 0 + + +# --------------------------------------------------------------------------- +# Fail-open semantics +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_all_zero_downstream_falls_back_to_raw_count( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """If every session looks unreferenced (newly-published catalog + with no element_epoch ingested yet, OR a cloud-side reverse-dep + outage), preserve the raw count.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=2, element_epoch=5), + ) + + session_ids = ["ndi-s1", "ndi-s2"] + + def _ndiquery_handler(req: Request) -> Response: + body = json.loads(req.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response( + 200, + json={ + "documents": [{"id": s} for s in session_ids], + "number_matches": len(session_ids), + }, + ) + # Every depends_on lookup returns 0. + return Response(200, json={"documents": [], "number_matches": 0}) + + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + + def _bulk(req: Request) -> Response: + ids = json.loads(req.content.decode("utf-8")).get("documentIds", []) + return Response(200, json={ + "documents": [_session_doc(i) for i in ids], + }) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).mock(side_effect=_bulk) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + # Fail-open: when ALL sessions look unreferenced, return raw count + # rather than reporting 0 sessions. + assert summary.counts.sessions == 2 + + +@pytest.mark.asyncio +async def test_reverse_dep_lookup_failure_keeps_session( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """A 5xx on the per-session reverse-dep ndiquery means we can't + determine the session is parent — fail-open and count it as real + so we never silently drop a legitimate recording.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=2, element_epoch=5), + ) + + session_ids = ["ndi-s1", "ndi-s2"] + call_count = {"reverse": 0} + + def _ndiquery_handler(req: Request) -> Response: + body = json.loads(req.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response( + 200, + json={ + "documents": [{"id": s} for s in session_ids], + "number_matches": len(session_ids), + }, + ) + # Reverse-dep query: first call 503, second call ok-but-empty. + call_count["reverse"] += 1 + if call_count["reverse"] == 1: + return Response(503, json={"error": "upstream"}) + return Response(200, json={"documents": [{"id": "dep1"}], "number_matches": 1}) + + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + + def _bulk(req: Request) -> Response: + ids = json.loads(req.content.decode("utf-8")).get("documentIds", []) + return Response(200, json={ + "documents": [_session_doc(i) for i in ids], + }) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).mock(side_effect=_bulk) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + # Both kept: the 503 session is fail-opened as real; the second is + # genuinely real with 1 downstream. + assert summary.counts.sessions == 2 + + +@pytest.mark.asyncio +async def test_session_class_fetch_failure_keeps_raw_count( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """If the `isa session` ndiquery (used to fetch session docs for + inspection) fails, we can't walk per-session reverse-deps. Keep + the raw count and surface a typed extractionWarnings entry.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=3, element_epoch=5), + ) + + def _ndiquery_handler(req: Request) -> Response: + body = json.loads(req.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response(503, json={"error": "ndiquery timeout"}) + return Response(200, json={"documents": [], "number_matches": 0}) + + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).respond(200, json={"documents": []}) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + # Raw count preserved on fetch failure. + assert summary.counts.sessions == 3 + # A typed warning surfaced (either from the session-class fetch or + # the filter's own catch). + assert any( + "session" in w.lower() and "fail" in w.lower() + for w in summary.extractionWarnings + ) + + +# --------------------------------------------------------------------------- +# Safety cap +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_skip_filter_when_session_count_exceeds_cap( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """Datasets with > _MAX_SESSIONS_FILTER_WALK (50) sessions skip the + filter entirely — for genuine multi-day series the count is correct, + and the cost of an O(N) walk isn't justified.""" + from backend.services.dataset_summary_service import _MAX_SESSIONS_FILTER_WALK + + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + n = _MAX_SESSIONS_FILTER_WALK + 1 + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=n, element_epoch=200), + ) + nq = router.post("/ndiquery").respond( + 200, json={"documents": [], "number_matches": 0}, + ) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).respond(200, json={"documents": []}) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + assert summary.counts.sessions == n, "Over-cap sessions stay raw" + assert nq.call_count == 0, "No ndiquery when over the safety cap" + + +# --------------------------------------------------------------------------- +# Constant + helper unit tests +# --------------------------------------------------------------------------- + +def test_max_sessions_filter_walk_is_50() -> None: + """Pin the safety cap. Sized so every real published dataset + (Haley=3, Bhar=2, Francesconi=1, ...) is well under, while still + bounding the worst case.""" + from backend.services.dataset_summary_service import _MAX_SESSIONS_FILTER_WALK + assert _MAX_SESSIONS_FILTER_WALK == 50 + + +def test_filtered_sessions_or_warn_passes_int_unchanged() -> None: + from backend.services.dataset_summary_service import _filtered_sessions_or_warn + warnings: list[str] = [] + assert _filtered_sessions_or_warn(2, raw=3, warnings=warnings) == 2 + assert warnings == [] + + +def test_filtered_sessions_or_warn_fail_open_on_exception() -> None: + from backend.services.dataset_summary_service import _filtered_sessions_or_warn + warnings: list[str] = [] + err = RuntimeError("boom") + assert _filtered_sessions_or_warn(err, raw=5, warnings=warnings) == 5 + assert len(warnings) == 1 + assert "session filter failed" in warnings[0].lower() + + +def test_filtered_sessions_or_warn_fail_open_on_unexpected_type() -> None: + from backend.services.dataset_summary_service import _filtered_sessions_or_warn + warnings: list[str] = [] + assert _filtered_sessions_or_warn("not an int", raw=4, warnings=warnings) == 4 + assert len(warnings) == 1 + assert "unexpected type" in warnings[0].lower() + + +@pytest.mark.asyncio +async def test_identity_int_returns_input() -> None: + from backend.services.dataset_summary_service import _identity_int + assert await _identity_int(0) == 0 + assert await _identity_int(7) == 7 + assert await _identity_int(-1) == -1 # caller's responsibility to gate + + +# --------------------------------------------------------------------------- +# Prefix-suffix fallback (Haley case: depends_on heuristic returns 0) +# --------------------------------------------------------------------------- + +def test_filter_by_reference_prefix_haley_case() -> None: + """Canonical Haley shape: 3 sessions, parent + 2 leaves. Filter + should return 2 (the leaves) when no session has downstream refs.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("ndi-celegans", reference="haley_2025_Celegans"), + _session_doc("ndi-ecoli", reference="haley_2025_Ecoli"), + _session_doc("ndi-parent", reference="haley_2025"), + ] + assert _filter_by_reference_prefix(docs) == 2 + + +def test_filter_by_reference_prefix_all_leaves_no_parent() -> None: + """No prefix-suffix relationship: every session counted as real.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("ndi-a", reference="dataset_a"), + _session_doc("ndi-b", reference="dataset_b"), + ] + assert _filter_by_reference_prefix(docs) == 2 + + +def test_filter_by_reference_prefix_returns_none_when_refs_missing() -> None: + """If any session lacks a `reference` field, bail with None so the + caller can fall back to raw count.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("ndi-a", reference="haley_2025"), + _session_doc("ndi-b"), # no reference field + ] + assert _filter_by_reference_prefix(docs) is None + + +def test_filter_by_reference_prefix_returns_none_when_all_same_ref() -> None: + """All sessions share the same reference string — ambiguous. Bail.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("ndi-a", reference="haley_2025"), + _session_doc("ndi-b", reference="haley_2025"), + ] + assert _filter_by_reference_prefix(docs) is None + + +def test_filter_by_reference_prefix_requires_underscore_separator() -> None: + """`haley` is a prefix of `haley_2025` but only via underscore; without + the `_` we don't claim a parent relationship — could be coincidence.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("ndi-a", reference="haley2025"), # no `_` separator + _session_doc("ndi-b", reference="haley"), + ] + # `haley` + `_` = `haley_`, not a prefix of `haley2025` → both real + assert _filter_by_reference_prefix(docs) == 2 + + +def test_filter_by_reference_prefix_multi_level_tree() -> None: + """Three-level hierarchy: parent + intermediate + 2 leaves. + Only the deepest leaves are real.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [ + _session_doc("p", reference="proj"), + _session_doc("i1", reference="proj_phase1"), + _session_doc("l1", reference="proj_phase1_a"), + _session_doc("l2", reference="proj_phase1_b"), + ] + # Real leaves: proj_phase1_a, proj_phase1_b (nothing extends them) + # Parents: proj (proj_phase1 extends it), proj_phase1 (the leaves extend it) + assert _filter_by_reference_prefix(docs) == 2 + + +def test_filter_by_reference_prefix_single_session() -> None: + """Single session — no comparator possible, bail with None.""" + from backend.services.dataset_summary_service import _filter_by_reference_prefix + docs = [_session_doc("a", reference="haley_2025")] + assert _filter_by_reference_prefix(docs) is None + + +def test_session_reference_extracts_from_session_block() -> None: + """Extractor reads `data.session.reference` from doc body.""" + from backend.services.dataset_summary_service import _session_reference + doc = _session_doc("a", reference="x_y_z") + assert _session_reference(doc) == "x_y_z" + + +def test_session_reference_falls_back_through_alternates() -> None: + """Falls back to `session_reference` then `name` if `reference` + is absent — robust to minor schema variation across labs.""" + from backend.services.dataset_summary_service import _session_reference + # Synthesize a doc with session_reference key + doc = { + "id": "x", + "ndiId": "n", + "data": { + "base": {"id": "n"}, + "session": {"session_reference": "alt_form"}, + }, + } + assert _session_reference(doc) == "alt_form" + + +def test_session_reference_returns_none_when_missing() -> None: + """No `session` block at all → None.""" + from backend.services.dataset_summary_service import _session_reference + doc = {"id": "x", "ndiId": "n", "data": {"base": {"id": "n"}}} + assert _session_reference(doc) is None + + +# --------------------------------------------------------------------------- +# Composition: prefix-heuristic refines depends_on when both apply +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_prefix_refines_depends_on_when_parent_has_admin_ref( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """The Haley production case: all 3 sessions return >0 via + depends_on (the parent is referenced by an admin doc like + `dataset_session_info`), so depends_on alone wouldn't filter. + But the reference-prefix heuristic returns 2 (the parent's name + is a prefix of both leaves'). The composition policy prefers + the more selective prefix result. + """ + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=3, element_epoch=10, dataset_session_info=1), + ) + + leaf_a = "ndi-leaf-a" + leaf_b = "ndi-leaf-b" + parent = "ndi-parent" + session_ids = [leaf_a, leaf_b, parent] + + def _ndiquery_handler(request: Request) -> Response: + body = json.loads(request.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response( + 200, + json={ + "documents": [{"id": s} for s in session_ids], + "number_matches": len(session_ids), + }, + ) + if ss and ss[0].get("operation") == "depends_on": + # Every session has ≥1 downstream — parent gets a hit + # from the dataset_session_info admin doc, leaves get + # hits from element_epochs. + return Response( + 200, + json={ + "documents": [{"id": "dep-x"}], + "number_matches": 1, + "totalItems": 1, + }, + ) + return Response(200, json={"documents": [], "number_matches": 0}) + + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + + def _bulk_handler(req: Request) -> Response: + body = json.loads(req.content.decode("utf-8")) + ids = body.get("documentIds", []) + ref_for = { + leaf_a: "haley_2025_Celegans", + leaf_b: "haley_2025_Ecoli", + parent: "haley_2025", + } + return Response( + 200, + json={ + "documents": [ + _session_doc(i, reference=ref_for.get(i, "")) for i in ids + ], + }, + ) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).mock(side_effect=_bulk_handler) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + # Prefix heuristic wins: 2 leaves, parent filtered out, despite + # depends_on saying all 3 are "referenced". + assert summary.counts.sessions == 2 + + +@pytest.mark.asyncio +async def test_depends_on_used_when_prefix_inconclusive( + cloud: NdiCloudClient, ontology_service: OntologyService, +) -> None: + """When prefix returns None (no reference fields, ambiguous names), + fall back to depends_on as the canonical signal. Two leaves have + downstream refs; parent doesn't → filter to 2.""" + async with respx.mock( + base_url="https://api.example.test/v1", assert_all_called=False, + ) as router: + router.get(f"/datasets/{DATASET_ID}").respond( + 200, json=_dataset_raw(), + ) + router.get(f"/datasets/{DATASET_ID}/document-class-counts").respond( + 200, json=_counts_raw(session=3, element_epoch=10), + ) + leaf_a, leaf_b, parent = "n-a", "n-b", "n-p" + session_ids = [leaf_a, leaf_b, parent] + + def _ndiquery_handler(request: Request) -> Response: + body = json.loads(request.content.decode("utf-8")) + ss = body.get("searchstructure", []) + if ss and ss[0].get("operation") == "isa" and ss[0].get("param1") == "session": + return Response( + 200, + json={"documents": [{"id": s} for s in session_ids], + "number_matches": len(session_ids)}, + ) + if ss and ss[0].get("operation") == "depends_on": + target = ss[0].get("param2", [None])[0] + n = 1 if target in (leaf_a, leaf_b) else 0 + return Response( + 200, + json={"documents": [{"id": "x"}] * n, + "number_matches": n, "totalItems": n}, + ) + return Response(200, json={"documents": [], "number_matches": 0}) + router.post("/ndiquery").mock(side_effect=_ndiquery_handler) + + # Sessions don't have `reference` field → prefix returns None + def _bulk_handler(req: Request) -> Response: + ids = json.loads(req.content.decode("utf-8")).get("documentIds", []) + return Response(200, json={ + "documents": [_session_doc(i) for i in ids], # no reference + }) + router.post( + f"/datasets/{DATASET_ID}/documents/bulk-fetch", + ).mock(side_effect=_bulk_handler) + + svc = DatasetSummaryService(cloud, ontology_service) + summary = await svc.build_summary(DATASET_ID, session=None) + + assert summary.counts.sessions == 2 # via depends_on diff --git a/backend/tests/unit/test_dataset_summary_service.py b/backend/tests/unit/test_dataset_summary_service.py index 2c96fe2..5bd60a0 100644 --- a/backend/tests/unit/test_dataset_summary_service.py +++ b/backend/tests/unit/test_dataset_summary_service.py @@ -501,10 +501,10 @@ async def test_user_cache_keys_are_isolated() -> None: bob = _minimal_session("bob") assert summary_cache_key(dataset_id, alice) != summary_cache_key(dataset_id, bob) assert summary_cache_key(dataset_id, None) == ( - f"summary:v1:{dataset_id}:public" + f"summary:v8:{dataset_id}:public" ) assert summary_cache_key(dataset_id, alice) == ( - f"summary:v1:{dataset_id}:{user_scope_for(alice)}" + f"summary:v8:{dataset_id}:{user_scope_for(alice)}" ) @@ -890,6 +890,102 @@ async def producer() -> dict[str, Any]: assert kwargs["ex"] == 24 * 60 * 60 +def test_counts_from_raw_uses_element_epoch_when_present() -> None: + """`element_epoch` is the primary epoch class for newer NDI datasets.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(subject=2, element_epoch=42)) + assert counts.epochs == 42 + + +def test_counts_from_raw_falls_back_to_plain_epoch() -> None: + """Legacy `epoch` class is the second choice.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(subject=2, epoch=7)) + assert counts.epochs == 7 + + +def test_counts_from_raw_falls_back_to_epochfiles_ingested() -> None: + """Francesconi pattern: Phase-A ingest emits `epochfiles_ingested` + per epoch file, neither `element_epoch` nor `epoch`. Without this + fallback the EPOCHS chip read 0 on the workspace even though + the dataset has thousands of epochs (the 2026-05-14 parity bug).""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(subject=215, epochfiles_ingested=1604)) + assert counts.epochs == 1604 + + +def test_counts_from_raw_falls_back_to_daqreader_mfdaq_epochdata() -> None: + """Some Van Hooser lab datasets emit `daqreader_mfdaq_epochdata_ingested` + instead. Should also resolve.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw( + _counts_raw(subject=10, daqreader_mfdaq_epochdata_ingested=33), + ) + assert counts.epochs == 33 + + +def test_counts_from_raw_picks_first_non_zero_no_double_count() -> None: + """When both `epochfiles_ingested` and `daqreader_mfdaq_epochdata_ingested` + are present (1:1 with the same set of epochs in some datasets), the + fallback chain picks the first hit only — summing them would + double-count. Priority: element_epoch > epoch > epochfiles_ingested + > daqreader_mfdaq_epochdata_ingested.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw( + _counts_raw( + subject=10, + epochfiles_ingested=1604, + daqreader_mfdaq_epochdata_ingested=1605, + ), + ) + # First in chain wins, not the sum. + assert counts.epochs == 1604 + + +def test_counts_from_raw_returns_zero_when_no_epoch_class_present() -> None: + """Bhar (C. elegans behavior) has no epoch documents at all — + EPOCHS=0 is the correct chip value for that dataset.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw( + _counts_raw(subject=5314, subject_group=235, treatment_drug=24466), + ) + assert counts.epochs == 0 + + +# F-1c (2026-05-19) — `counts.probes` mirrors `counts.elements` when +# the literal `probe` class is zero. Python runtime treats `probe` +# as an alias for `element` (services/class_aliases.py). Three pins: +# - literal probe non-zero → use literal +# - literal probe zero + element non-zero → use element (alias hit) +# - both zero → 0 +def test_counts_from_raw_probes_uses_literal_when_non_zero() -> None: + """When the dataset emits literal `probe` documents, `counts.probes` + reports that count even if there are also `element` documents.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(probe=3, element=10)) + assert counts.probes == 3 + assert counts.elements == 10 + + +def test_counts_from_raw_probes_alias_to_elements_when_literal_zero() -> None: + """Francesconi pattern: 0 literal `probe` + 606 `element` documents. + `counts.probes` should resolve via the alias chain to 606 so the + snapshot tile matches the catalog's probeTypes facet.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(element=606)) + assert counts.probes == 606 + assert counts.elements == 606 + + +def test_counts_from_raw_probes_zero_when_both_classes_zero() -> None: + """Datasets with neither `probe` nor `element` documents (pure + behavioral, e.g. Bhar C. elegans) report `counts.probes: 0`.""" + from backend.services.dataset_summary_service import _counts_from_raw + counts = _counts_from_raw(_counts_raw(subject=5314, subject_group=235)) + assert counts.probes == 0 + assert counts.elements == 0 + + def test_summary_schema_version_literal() -> None: with pytest.raises(Exception): # noqa: B017 — any pydantic ValidationError variant DatasetSummary.model_validate({ diff --git a/backend/tests/unit/test_dependencies.py b/backend/tests/unit/test_dependencies.py index f8dd055..cd73bf5 100644 --- a/backend/tests/unit/test_dependencies.py +++ b/backend/tests/unit/test_dependencies.py @@ -197,13 +197,19 @@ async def test_ip_change_logs_warning_allows_request(fake_redis: Any) -> None: ip_events = [e for e in logs if e.get("event") == "session.ip_changed"] assert len(ip_events) == 1 e = ip_events[0] - assert e["session_id"] == session.session_id + # session_id is truncated to 8 chars (the full id is the session + # secret — see dependencies.py for the rationale). The prefix is + # still enough to correlate log lines for one session. + assert e["session_id"] == session.session_id[:8] assert e["stored_ip_hash"] == session.ip_addr_hash assert e["current_ip_hash"] != session.ip_addr_hash # No raw IPs in the log line. payload = str(e) assert "192.168.1.10" not in payload assert "10.0.0.5" not in payload + # And the full session id MUST NOT appear anywhere in the + # captured event payload — otherwise log-readers could replay it. + assert session.session_id not in payload @pytest.mark.asyncio diff --git a/backend/tests/unit/test_dependency_direction_filter.py b/backend/tests/unit/test_dependency_direction_filter.py new file mode 100644 index 0000000..d33b825 --- /dev/null +++ b/backend/tests/unit/test_dependency_direction_filter.py @@ -0,0 +1,115 @@ +"""F-3 (2026-05-19) — direction filter post-walk. + +The full bidirectional walk is preserved; the new ``direction`` query +param post-filters the cached result. These tests target the pure +``_filter_graph_by_direction`` helper plus the assertion that +``get_graph`` accepts the parameter. +""" +from __future__ import annotations + +import pytest + +from backend.services.dependency_graph_service import ( + VALID_DEPENDENCY_DIRECTIONS, + _filter_graph_by_direction, +) + + +def _graph_fixture(): + """A canonical bidirectional graph for a target ``T`` with: + - 2 upstream ancestors (A1, A2) + - 2 downstream consumers (D1, D2) + - target T itself + + Edges: + A1 → T (upstream) # A1 produced T + A2 → A1 (upstream) # A2 produced A1 (transitive ancestor) + D1 → T (downstream) # D1 consumed T + D2 → D1 (downstream) # D2 consumed D1 (transitive consumer) + """ + nodes = [ + {"ndi_id": "T", "name": "target"}, + {"ndi_id": "A1", "name": "ancestor1"}, + {"ndi_id": "A2", "name": "ancestor2"}, + {"ndi_id": "D1", "name": "consumer1"}, + {"ndi_id": "D2", "name": "consumer2"}, + ] + edges = [ + {"source": "T", "target": "A1", "direction": "upstream", "label": "depends_on"}, + {"source": "A1", "target": "A2", "direction": "upstream", "label": "depends_on"}, + {"source": "D1", "target": "T", "direction": "downstream", "label": "depends_on"}, + {"source": "D2", "target": "D1", "direction": "downstream", "label": "depends_on"}, + ] + return { + "target_id": "T", + "target_ndi_id": "T", + "nodes": nodes, + "edges": edges, + "node_count": len(nodes), + "edge_count": len(edges), + "truncated": False, + "max_depth": 3, + } + + +def test_both_passes_full_graph_through_unchanged(): + g = _graph_fixture() + out = _filter_graph_by_direction(g, "both") + assert out is g # passthrough — no allocation + assert out["node_count"] == 5 + assert out["edge_count"] == 4 + assert "direction_filter" not in out + + +def test_upstream_drops_downstream_edges(): + g = _graph_fixture() + out = _filter_graph_by_direction(g, "upstream") + assert out["direction_filter"] == "upstream" + edge_dirs = [e["direction"] for e in out["edges"]] + assert edge_dirs == ["upstream", "upstream"] + # Target + 2 ancestors retained; D1 + D2 dropped. + ndi_ids = {n["ndi_id"] for n in out["nodes"]} + assert ndi_ids == {"T", "A1", "A2"} + assert out["edge_count"] == 2 + assert out["node_count"] == 3 + + +def test_downstream_drops_upstream_edges(): + g = _graph_fixture() + out = _filter_graph_by_direction(g, "downstream") + assert out["direction_filter"] == "downstream" + edge_dirs = [e["direction"] for e in out["edges"]] + assert edge_dirs == ["downstream", "downstream"] + ndi_ids = {n["ndi_id"] for n in out["nodes"]} + assert ndi_ids == {"T", "D1", "D2"} + assert out["edge_count"] == 2 + assert out["node_count"] == 3 + + +def test_invalid_direction_defaults_to_both(): + g = _graph_fixture() + out = _filter_graph_by_direction(g, "sideways") + # Same passthrough as 'both'. + assert out is g + + +def test_empty_target_id_still_works(): + g = _graph_fixture() + g["target_ndi_id"] = None + out = _filter_graph_by_direction(g, "upstream") + # Without a target the reachability seed is empty; only edge + # endpoints survive. + ndi_ids = {n["ndi_id"] for n in out["nodes"]} + assert ndi_ids == {"T", "A1", "A2"} + + +def test_valid_directions_constant(): + assert VALID_DEPENDENCY_DIRECTIONS == {"both", "upstream", "downstream"} + + +@pytest.mark.parametrize("direction", ["upstream", "downstream"]) +def test_target_always_retained(direction: str): + g = _graph_fixture() + out = _filter_graph_by_direction(g, direction) + ndi_ids = {n["ndi_id"] for n in out["nodes"]} + assert "T" in ndi_ids diff --git a/backend/tests/unit/test_document_service_class_alias.py b/backend/tests/unit/test_document_service_class_alias.py new file mode 100644 index 0000000..3b88bd2 --- /dev/null +++ b/backend/tests/unit/test_document_service_class_alias.py @@ -0,0 +1,234 @@ +"""DocumentService.list_by_class — class-alias resolution (B2, 2026-05-18). + +Workspace picker repro: `/api/datasets/682e7772cdf3f24938176fac/documents +?class=probe` returned 0 documents on Haley (`682e7772…`) despite the +dataset carrying 4156 ``element`` docs. The summary-table service's +``_build_single_class`` already walked the canonical ``probe → element`` +alias chain; this test file pins the same behaviour into +``DocumentService.list_by_class`` so the workspace's class-filtered +picker queries follow the chain too. + +These tests pin: + - Literal-class hit: ndiquery returns rows → no alias walk. + - Alias resolution: literal=0, first chain entry non-empty → use it, + surface total + log the alias_hit. + - Alias chain exhausted: literal=0 AND every alias=0 → return empty. + - Class without an alias entry (e.g. `subject`): no chain walk + regardless of literal-zero result. + - Total comes from the alias body's ``number_matches`` (not from + the page-length fallback). +""" +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from backend.services.document_service import DocumentService + +DATASET_ID = "682e7772cdf3f24938176fac" # Haley (real-world repro target) + + +def _stub_cloud() -> Any: + class _Stub: + ndiquery = AsyncMock() + get_dataset = AsyncMock() + bulk_fetch = AsyncMock() + list_documents_by_dataset = AsyncMock(return_value=[]) + get_dataset_document_count = AsyncMock(return_value=0) + + return _Stub() + + +@pytest.mark.asyncio +async def test_literal_class_hit_skips_alias_walk() -> None: + """When the literal class returns rows, the alias chain MUST NOT + fire — back-compat for datasets that DO carry literal `probe` docs + (Dabrowska).""" + cloud = _stub_cloud() + cloud.ndiquery.return_value = { + "documents": [{"id": "id-001"}, {"id": "id-002"}], + "number_matches": 2, + } + cloud.bulk_fetch.return_value = [ + {"id": "id-001", "className": "probe"}, + {"id": "id-002", "className": "probe"}, + ] + + svc = DocumentService(cloud) + result = await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="probe", + page=1, + page_size=50, + access_token=None, + ) + + assert result["total"] == 2 + assert len(result["documents"]) == 2 + # Exactly ONE ndiquery: the literal lookup. No alias retry. + assert cloud.ndiquery.await_count == 1 + + +@pytest.mark.asyncio +async def test_probe_alias_resolves_to_element() -> None: + """Haley-style repro: literal `probe` returns 0, alias `element` + returns 4156. ``slice_ids`` must come from the alias body, and the + total must be the alias body's ``number_matches`` (not 0).""" + cloud = _stub_cloud() + # First call (probe): empty. + # Second call (element alias): 50 ids + 4156 total. + element_ids = [f"el-{i:04d}" for i in range(50)] + cloud.ndiquery.side_effect = [ + {"documents": [], "number_matches": 0}, + { + "documents": [{"id": x} for x in element_ids], + "number_matches": 4156, + }, + ] + cloud.bulk_fetch.return_value = [ + {"id": x, "className": "element"} for x in element_ids + ] + + svc = DocumentService(cloud) + result = await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="probe", + page=1, + page_size=50, + access_token=None, + ) + + assert result["total"] == 4156 + assert result["page"] == 1 + assert result["pageSize"] == 50 + assert len(result["documents"]) == 50 + # Two ndiquery calls: literal then alias. + assert cloud.ndiquery.await_count == 2 + # Bulk-fetch should target the alias ids. + cloud.bulk_fetch.assert_awaited_once_with( + DATASET_ID, element_ids, access_token=None, + ) + + +@pytest.mark.asyncio +async def test_alias_chain_exhausted_returns_empty() -> None: + """`epoch` has three aliases (element_epoch, epochfiles_ingested, + daqreader_mfdaq_epochdata_ingested). When ALL of them return 0, + the service should return the original empty result rather than + falling through to the inline/Mongo fallbacks (which are + class-filtered out by design).""" + cloud = _stub_cloud() + # Literal + 3 alias probes = 4 calls, all empty. + cloud.ndiquery.return_value = {"documents": [], "number_matches": 0} + + svc = DocumentService(cloud) + result = await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="epoch", + page=1, + page_size=50, + access_token=None, + ) + + assert result["total"] == 0 + assert result["documents"] == [] + # Literal + 3 alias retries = 4 ndiquery calls. + assert cloud.ndiquery.await_count == 4 + # Inline/Mongo fallbacks must NOT fire for class-filtered queries. + cloud.get_dataset.assert_not_awaited() + cloud.list_documents_by_dataset.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_first_non_empty_alias_wins() -> None: + """For `epoch`, if `element_epoch` is empty but + `epochfiles_ingested` has rows, we should stop after the second + alias (don't keep probing the third).""" + cloud = _stub_cloud() + legacy_ids = ["legacy-001", "legacy-002"] + cloud.ndiquery.side_effect = [ + # Literal epoch — empty. + {"documents": [], "number_matches": 0}, + # First alias element_epoch — empty. + {"documents": [], "number_matches": 0}, + # Second alias epochfiles_ingested — non-empty, WIN. + { + "documents": [{"id": x} for x in legacy_ids], + "number_matches": 7, + }, + ] + cloud.bulk_fetch.return_value = [ + {"id": x, "className": "epochfiles_ingested"} for x in legacy_ids + ] + + svc = DocumentService(cloud) + result = await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="epoch", + page=1, + page_size=50, + access_token=None, + ) + + assert result["total"] == 7 + assert len(result["documents"]) == 2 + # Three calls — should not probe the third alias. + assert cloud.ndiquery.await_count == 3 + + +@pytest.mark.asyncio +async def test_class_without_alias_entry_no_chain_walk() -> None: + """`subject` is not in CLASS_ALIASES. Literal=0 should NOT trigger + any alias retry — the chain is opt-in per class.""" + cloud = _stub_cloud() + cloud.ndiquery.return_value = {"documents": [], "number_matches": 0} + + svc = DocumentService(cloud) + result = await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="subject", + page=1, + page_size=50, + access_token=None, + ) + + assert result["total"] == 0 + assert result["documents"] == [] + # Exactly ONE ndiquery: the literal lookup. No alias retries. + assert cloud.ndiquery.await_count == 1 + + +@pytest.mark.asyncio +async def test_alias_resolution_uses_page_params() -> None: + """The alias retry must respect the caller's page + page_size — we + don't want page 2 to silently reset to page 1 when alias-resolved.""" + cloud = _stub_cloud() + ids = [f"id-{i:03d}" for i in range(50)] + cloud.ndiquery.side_effect = [ + {"documents": [], "number_matches": 0}, + { + "documents": [{"id": x} for x in ids], + "number_matches": 4156, + }, + ] + cloud.bulk_fetch.return_value = [{"id": x} for x in ids] + + svc = DocumentService(cloud) + await svc.list_by_class( + dataset_id=DATASET_ID, + class_name="probe", + page=3, + page_size=50, + access_token=None, + ) + + # Inspect the call args of the second ndiquery (the alias probe). + alias_call = cloud.ndiquery.await_args_list[1] + assert alias_call.kwargs["page"] == 3 + assert alias_call.kwargs["page_size"] == 50 + # The alias's searchstructure must use the alias name, not the + # original class. + structure = alias_call.kwargs["searchstructure"] + assert structure == [{"operation": "isa", "param1": "element"}] diff --git a/backend/tests/unit/test_facet_service_dedupe.py b/backend/tests/unit/test_facet_service_dedupe.py index 0b47f77..237a156 100644 --- a/backend/tests/unit/test_facet_service_dedupe.py +++ b/backend/tests/unit/test_facet_service_dedupe.py @@ -159,6 +159,68 @@ async def test_two_case_identical_species_collapse_to_one_entry() -> None: ) +@pytest.mark.asyncio +async def test_labeled_and_unlabeled_species_with_same_label_merge() -> None: + """Visual-UX audit row #6 / a395 (2026-05-12): ``/datasets`` and + ``/query`` showed two ``Caenorhabditis elegans`` chips because one + contributing dataset reported the species with + ``ontologyId=NCBITaxon:6239`` and another reported it with + ``ontologyId=None``. Pre-fix the two dedupe keys (``oid::NCBITaxon:6239`` + and ``norm::caenorhabditis elegans``) were disjoint so both + surfaced. Post-fix the asymmetric label-alias merge collapses + them into one chip; the merged entry keeps the ontologyId. + """ + ds1 = _make_summary( + "ds1", species=[("Caenorhabditis elegans", "NCBITaxon:6239")], + ) + ds2 = _make_summary( + "ds2", species=[("Caenorhabditis elegans", None)], + ) + rows = [ + _make_row("ds1", CompactDatasetSummary.from_full(ds1)), + _make_row("ds2", CompactDatasetSummary.from_full(ds2)), + ] + ds_svc = _fake_dataset_service({1: rows}, total_number=2) + sum_svc = _fake_summary_service({"ds1": ds1, "ds2": ds2}) + + svc = FacetService(ds_svc, sum_svc) + facets = await svc.build_facets() + + assert len(facets.species) == 1, ( + f"Labeled + unlabeled same-name species must merge. Got: " + f"{[(t.label, t.ontologyId) for t in facets.species]}" + ) + assert facets.species[0].label == "Caenorhabditis elegans" + # The ontologyId from the labeled side wins — more authoritative. + assert facets.species[0].ontologyId == "NCBITaxon:6239" + + +@pytest.mark.asyncio +async def test_unlabeled_then_labeled_species_merge_promotes_ontology_id() -> None: + """Reverse-order variant of the audit bug: the unlabeled entry + arrives first, then the labeled one. The merge must still collapse + them and promote the ontologyId onto the surviving entry.""" + ds1 = _make_summary( + "ds1", species=[("Caenorhabditis elegans", None)], + ) + ds2 = _make_summary( + "ds2", species=[("Caenorhabditis elegans", "NCBITaxon:6239")], + ) + rows = [ + _make_row("ds1", CompactDatasetSummary.from_full(ds1)), + _make_row("ds2", CompactDatasetSummary.from_full(ds2)), + ] + ds_svc = _fake_dataset_service({1: rows}, total_number=2) + sum_svc = _fake_summary_service({"ds1": ds1, "ds2": ds2}) + + svc = FacetService(ds_svc, sum_svc) + facets = await svc.build_facets() + + assert len(facets.species) == 1 + assert facets.species[0].label == "Caenorhabditis elegans" + assert facets.species[0].ontologyId == "NCBITaxon:6239" + + @pytest.mark.asyncio async def test_whitespace_drift_in_species_collapses() -> None: """Trivial whitespace differences (trailing space, internal diff --git a/backend/tests/unit/test_image_service.py b/backend/tests/unit/test_image_service.py new file mode 100644 index 0000000..46f5500 --- /dev/null +++ b/backend/tests/unit/test_image_service.py @@ -0,0 +1,339 @@ +"""Unit tests for image_service. + +Coverage targets: + - Happy path: PNG / TIFF / JPEG payloads decode to a 2D float array + with the right (width, height, min, max, format) envelope. + - Downsampling: images larger than MAX_DIMENSION get thumbnailed and + the response sets `downsampled: True`. + - Multi-channel input: RGB / RGBA images convert to grayscale. + - Missing document file refs: returns errorKind="notfound". + - Pillow can't open the bytes: returns errorKind="unsupported" + (covers raw .nim and other NDI-native formats). + - Empty payload: returns errorKind="notfound". + +The cloud-download path is stubbed via AsyncMock so we only exercise +the decode pipeline. Pure helpers (`_decode_image`, `_source_block`, +`_image_error`) are also exercised directly with fixture bytes. +""" +from __future__ import annotations + +import io +from typing import Any +from unittest.mock import AsyncMock + +import numpy as np +import pytest +from PIL import Image + +from backend.services.image_service import ( + MAX_DIMENSION, + ImageService, + _decode_image, + _image_error, + _source_block, +) + +# --------------------------------------------------------------------------- +# Fixture builders — produce raw image bytes Pillow can decode. +# --------------------------------------------------------------------------- + + +def _make_png_bytes(width: int, height: int, mode: str = "L") -> bytes: + """Build a PNG payload. Default mode `L` is single-channel grayscale. + + The pixel values are a deterministic gradient so tests can assert + min/max bracket the expected range. + """ + img = Image.new(mode, (width, height)) + pixels = img.load() + assert pixels is not None # narrow Pillow's PixelAccess|None type + for y in range(height): + for x in range(width): + # 0..255 ramp diagonally so min < max for any non-trivial size. + value = (x + y) % 256 + if mode == "L": + pixels[x, y] = value + else: + # Multi-channel: write the ramp across all channels so the + # grayscale conversion is well-defined and predictable. + pixels[x, y] = tuple([value] * len(mode)) + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def _make_tiff_bytes(width: int, height: int) -> bytes: + """Build a single-frame TIFF — TIFF is the common scientific format.""" + img = Image.new("L", (width, height)) + pixels = img.load() + assert pixels is not None # narrow Pillow's PixelAccess|None type + for y in range(height): + for x in range(width): + pixels[x, y] = (x + y) % 256 + buf = io.BytesIO() + img.save(buf, format="TIFF") + return buf.getvalue() + + +def _doc_with_file(url: str, filename: str = "image.png") -> dict[str, Any]: + """Build a document shape that matches the cloud's file_info envelope.""" + return { + "id": "doc-abc", + "datasetId": "ds-xyz", + "className": "image", + "data": { + "files": { + "file_list": [filename], + "file_info": { + "name": filename, + "locations": {"location": url}, + }, + }, + "base": {"name": "Test image"}, + "document_class": {"classname": "image"}, + }, + } + + +# --------------------------------------------------------------------------- +# _decode_image — pure-function tests over raw payloads +# --------------------------------------------------------------------------- + + +class TestDecodeImage: + def test_decodes_png_to_2d_array(self) -> None: + payload = _make_png_bytes(8, 4) + result = _decode_image( + payload, frame=0, filename="image.png", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": "image", "doc_name": "x", "filename": "image.png"}, + ) + assert "error" not in result + assert result["width"] == 8 + assert result["height"] == 4 + assert result["format"] == "png" + assert isinstance(result["data"], list) + assert len(result["data"]) == 4 + assert len(result["data"][0]) == 8 + # min < max because the gradient covers a non-trivial range. + assert result["min"] < result["max"] + assert result["downsampled"] is False + + def test_decodes_tiff(self) -> None: + payload = _make_tiff_bytes(16, 16) + result = _decode_image( + payload, frame=0, filename="image.tiff", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": "image", "doc_name": "x", "filename": "image.tiff"}, + ) + assert "error" not in result + assert result["format"] == "tiff" + assert result["width"] == 16 + assert result["height"] == 16 + + def test_downsamples_when_above_max_dimension(self) -> None: + # Use a non-square image to verify both dimensions get scaled. + big_w = MAX_DIMENSION + 200 + big_h = MAX_DIMENSION + 100 + payload = _make_png_bytes(big_w, big_h) + result = _decode_image( + payload, frame=0, filename="big.png", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": "image", "doc_name": "x", "filename": "big.png"}, + ) + assert "error" not in result + assert result["downsampled"] is True + # Thumbnail preserves aspect ratio; the longer side must be at + # MAX_DIMENSION and the other proportionally smaller. + assert max(result["width"], result["height"]) == MAX_DIMENSION + assert result["width"] <= MAX_DIMENSION + assert result["height"] <= MAX_DIMENSION + + def test_does_not_downsample_when_within_bounds(self) -> None: + payload = _make_png_bytes(MAX_DIMENSION, MAX_DIMENSION) + result = _decode_image( + payload, frame=0, filename="ok.png", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": "image", "doc_name": "x", "filename": "ok.png"}, + ) + assert result["downsampled"] is False + assert result["width"] == MAX_DIMENSION + assert result["height"] == MAX_DIMENSION + + def test_rgb_converts_to_grayscale(self) -> None: + """A 3-channel RGB image should come back as a single-channel + 2D array (not a 3D RGB array). Plotly heatmaps expect 2D.""" + payload = _make_png_bytes(8, 8, mode="RGB") + result = _decode_image( + payload, frame=0, filename="color.png", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": "image", "doc_name": "x", "filename": "color.png"}, + ) + assert "error" not in result + # 2D — each row is a list of scalars, not a list of triples. + assert isinstance(result["data"][0][0], float) + + def test_empty_payload_returns_notfound(self) -> None: + result = _decode_image( + b"", frame=0, filename="x", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": None, "doc_name": None, "filename": "x"}, + ) + assert result["errorKind"] == "notfound" + + def test_unrecognized_bytes_return_unsupported(self) -> None: + """Raw NDI .nim payloads (or any non-image bytes) should surface + as `unsupported` so the LLM can communicate it cleanly.""" + # Random bytes that don't match any image magic Pillow knows. + payload = b"\x00\x01\x02\x03not a real image\xff\xfe" * 8 + result = _decode_image( + payload, frame=0, filename="weird.nim", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": None, "doc_name": None, "filename": "weird.nim"}, + ) + assert result["errorKind"] == "unsupported" + assert "not yet supported" in result["error"] or "not recognized" in result["error"] + + def test_min_max_match_array_extremes(self) -> None: + """min/max should be the actual array extremes (used as Plotly + zmin/zmax). Manufactured ramp guarantees min=0, max approaches + the modulus.""" + payload = _make_png_bytes(16, 16) + result = _decode_image( + payload, frame=0, filename="ramp.png", + source={"dataset_id": "d", "document_id": "doc", + "doc_class": None, "doc_name": None, "filename": "ramp.png"}, + ) + # Reconstruct from the response to verify + arr = np.asarray(result["data"], dtype=np.float32) + assert result["min"] == float(arr.min()) + assert result["max"] == float(arr.max()) + + +# --------------------------------------------------------------------------- +# _image_error — sanity check the envelope shape +# --------------------------------------------------------------------------- + + +class TestImageError: + def test_envelope_shape(self) -> None: + env = _image_error("decode", "Bad bytes") + assert env == {"error": "Bad bytes", "errorKind": "decode"} + + def test_all_three_kinds_recognized(self) -> None: + for kind in ("notfound", "decode", "unsupported"): + env = _image_error(kind, "msg") + assert env["errorKind"] == kind + + +# --------------------------------------------------------------------------- +# _source_block — citation provenance for the chat reference chip +# --------------------------------------------------------------------------- + + +class TestSourceBlock: + def test_extracts_document_metadata(self) -> None: + doc = { + "id": "doc-abc", + "datasetId": "ds-xyz", + "className": "image", + "data": { + "base": {"name": "Patch encounter map S1"}, + "document_class": {"classname": "image"}, + }, + "base": {"name": "Patch encounter map S1"}, + "document_class": {"classname": "image"}, + } + block = _source_block(doc, filename="cell_image.tiff") + assert block["dataset_id"] == "ds-xyz" + assert block["document_id"] == "doc-abc" + assert block["doc_class"] == "image" + assert block["doc_name"] == "Patch encounter map S1" + assert block["filename"] == "cell_image.tiff" + + def test_handles_missing_fields(self) -> None: + """A bare document shouldn't crash _source_block assembly.""" + block = _source_block({}, filename=None) + assert block["doc_class"] is None + assert block["doc_name"] is None + assert block["filename"] is None + + +# --------------------------------------------------------------------------- +# ImageService — end-to-end with the cloud client stubbed +# --------------------------------------------------------------------------- + + +class TestImageServiceFetchImage: + @pytest.mark.asyncio + async def test_happy_path_png(self) -> None: + png_bytes = _make_png_bytes(8, 8) + cloud = AsyncMock() + cloud.download_file.return_value = png_bytes + svc = ImageService(cloud) + doc = _doc_with_file("https://signed.example/image.png", "image.png") + result = await svc.fetch_image(doc, frame=0, session=None) + assert "error" not in result + assert result["width"] == 8 + assert result["height"] == 8 + assert result["format"] == "png" + cloud.download_file.assert_awaited_once_with( + "https://signed.example/image.png", access_token=None, + ) + + @pytest.mark.asyncio + async def test_no_file_refs_returns_notfound(self) -> None: + """An empty file_info on the document should not reach the cloud.""" + cloud = AsyncMock() + svc = ImageService(cloud) + doc = {"id": "d", "datasetId": "ds", "data": {"files": {}}} + result = await svc.fetch_image(doc, frame=0, session=None) + assert result["errorKind"] == "notfound" + cloud.download_file.assert_not_awaited() + + @pytest.mark.asyncio + async def test_download_failure_returns_notfound(self) -> None: + cloud = AsyncMock() + cloud.download_file.side_effect = RuntimeError("403 from S3") + svc = ImageService(cloud) + doc = _doc_with_file("https://signed.example/image.png") + result = await svc.fetch_image(doc, frame=0, session=None) + assert result["errorKind"] == "notfound" + assert "Failed to download" in result["error"] + + @pytest.mark.asyncio + async def test_unsupported_bytes_return_unsupported(self) -> None: + """When the document file is downloaded but Pillow can't decode it + (e.g. raw .nim payload), the service surfaces `unsupported` so the + LLM can tell the user without trying to render a chart.""" + cloud = AsyncMock() + cloud.download_file.return_value = b"NOT AN IMAGE" * 32 + svc = ImageService(cloud) + doc = _doc_with_file("https://signed.example/weird.nim", "weird.nim") + result = await svc.fetch_image(doc, frame=0, session=None) + assert result["errorKind"] == "unsupported" + + @pytest.mark.asyncio + async def test_downsamples_oversized_image(self) -> None: + big_payload = _make_png_bytes(MAX_DIMENSION + 256, MAX_DIMENSION + 256) + cloud = AsyncMock() + cloud.download_file.return_value = big_payload + svc = ImageService(cloud) + doc = _doc_with_file("https://signed.example/big.png") + result = await svc.fetch_image(doc, frame=0, session=None) + assert result["downsampled"] is True + assert max(result["width"], result["height"]) == MAX_DIMENSION + + @pytest.mark.asyncio + async def test_source_block_propagates_filename(self) -> None: + """The source block returned to the chat should include the + underlying filename so the LLM can name the file in its answer.""" + cloud = AsyncMock() + cloud.download_file.return_value = _make_png_bytes(4, 4) + svc = ImageService(cloud) + doc = _doc_with_file("https://signed.example/cell_image.tiff", "cell_image.tiff") + result = await svc.fetch_image(doc, frame=0, session=None) + assert "error" not in result + assert result["source"]["filename"] == "cell_image.tiff" + assert result["source"]["document_id"] == "doc-abc" + assert result["source"]["dataset_id"] == "ds-xyz" diff --git a/backend/tests/unit/test_ndi_python_service.py b/backend/tests/unit/test_ndi_python_service.py new file mode 100644 index 0000000..57eb0ba --- /dev/null +++ b/backend/tests/unit/test_ndi_python_service.py @@ -0,0 +1,228 @@ +"""Unit tests for the NDI-python service wrappers. + +These tests don't require the NDI-python stack to be installed (CI may +not have it). The service is designed to degrade gracefully when the +imports fail — and that's exactly what these tests pin down. When the +stack IS available, additional integration tests in the experimental +Railway env will exercise the real decoder paths against the production +Haley / Dabrowska binaries. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from backend.services import ndi_python_service + + +@pytest.fixture(autouse=True) +def reset_available_cache(): + """Ensure each test starts with a fresh NDI-availability probe. + + The service caches the result of its first import attempt to avoid + re-paying the cost; tests need to clear that cache so they can + independently force-on or force-off the stack. + """ + ndi_python_service._NDI_AVAILABLE = None + yield + ndi_python_service._NDI_AVAILABLE = None + + +# --------------------------------------------------------------------------- +# is_ndi_compressed — pure byte-prefix check, no NDI dependency +# --------------------------------------------------------------------------- + + +class TestIsNdiCompressed: + def test_detects_gzip_magic(self): + assert ndi_python_service.is_ndi_compressed(b"\x1f\x8b\x08\x00") is True + + def test_rejects_short_payload(self): + assert ndi_python_service.is_ndi_compressed(b"") is False + assert ndi_python_service.is_ndi_compressed(b"\x1f") is False + + def test_rejects_non_gzip_payloads(self): + assert ndi_python_service.is_ndi_compressed(b"VHSB") is False + assert ndi_python_service.is_ndi_compressed(b"This is a VHSB file") is False + assert ndi_python_service.is_ndi_compressed(b"\x00\x00\x00\x00") is False + + def test_only_inspects_first_two_bytes(self): + # Gzip-magic prefix followed by garbage. Detection passes; the + # downstream expand call would surface the real format issue. + assert ndi_python_service.is_ndi_compressed(b"\x1f\x8b" + b"junk" * 100) is True + + +# --------------------------------------------------------------------------- +# read_vhsb_from_bytes — graceful degradation when NDI unavailable +# --------------------------------------------------------------------------- + + +class TestReadVhsbFromBytes: + def test_returns_none_when_ndi_unavailable(self): + ndi_python_service._NDI_AVAILABLE = False + result = ndi_python_service.read_vhsb_from_bytes(b"This is a VHSB file" + b"\x00" * 2100) + assert result is None + + def test_returns_none_on_short_payload(self): + ndi_python_service._NDI_AVAILABLE = True + # Minimum VHSB payload is 200 (text tag) + 1836 (header) = 2036 bytes + result = ndi_python_service.read_vhsb_from_bytes(b"This is a VHSB file") + assert result is None + + def test_returns_none_on_empty_payload(self): + ndi_python_service._NDI_AVAILABLE = True + assert ndi_python_service.read_vhsb_from_bytes(b"") is None + + def test_returns_none_when_vhsb_read_raises(self): + """When the real vlt call raises (malformed payload, etc.), we + swallow and return None so callers can fall through to their + legacy soft-error path. No exception escapes the service.""" + ndi_python_service._NDI_AVAILABLE = True + with patch.dict( + "sys.modules", + {"vlt.file.custom_file_formats": None}, + ): + # Module-set-to-None forces ImportError on `from vlt.file...` + result = ndi_python_service.read_vhsb_from_bytes(b"x" * 3000) + assert result is None + + +# --------------------------------------------------------------------------- +# expand_ephys_from_bytes — graceful degradation +# --------------------------------------------------------------------------- + + +class TestExpandEphysFromBytes: + def test_returns_none_when_ndi_unavailable(self): + ndi_python_service._NDI_AVAILABLE = False + result = ndi_python_service.expand_ephys_from_bytes(b"\x1f\x8b" + b"x" * 100) + assert result is None + + def test_returns_none_on_non_compressed_payload(self): + # Caller is supposed to gate on is_ndi_compressed first, but the + # wrapper double-checks defensively. + ndi_python_service._NDI_AVAILABLE = True + result = ndi_python_service.expand_ephys_from_bytes(b"VHSB" + b"x" * 100) + assert result is None + + def test_returns_none_when_ndicompress_fails(self): + ndi_python_service._NDI_AVAILABLE = True + with patch.dict("sys.modules", {"ndicompress": None}): + result = ndi_python_service.expand_ephys_from_bytes(b"\x1f\x8b" + b"x" * 100) + assert result is None + + +# --------------------------------------------------------------------------- +# lookup_ontology — never raises, returns None on miss +# --------------------------------------------------------------------------- + + +class TestLookupOntology: + def test_returns_none_on_malformed_curie(self): + # No `:` separator → not a CURIE → don't even probe. + assert ndi_python_service.lookup_ontology("WBStrain00000001") is None + + def test_returns_none_on_empty_input(self): + assert ndi_python_service.lookup_ontology("") is None + + def test_returns_none_when_ndi_unavailable(self): + ndi_python_service._NDI_AVAILABLE = False + result = ndi_python_service.lookup_ontology("CL:0000540") + assert result is None + + def test_returns_none_on_ndi_miss(self): + """NDI's lookup is documented to never raise — it returns a + falsy OntologyResult on miss. Make sure we surface None upward, + not an empty dict.""" + + class _FakeResult: + id = "" + name = "" + + def __bool__(self): + return False + + def to_dict(self): + return {} + + ndi_python_service._NDI_AVAILABLE = True + # ndi isn't installed in the test env, so we inject a fake module + # via sys.modules. The wrapper imports lazily via `from ndi.ontology + # import lookup` so monkey-patching sys.modules is the cleanest way. + fake_module = type("M", (), {"lookup": lambda _curie: _FakeResult()}) + with patch.dict("sys.modules", {"ndi.ontology": fake_module}): + result = ndi_python_service.lookup_ontology("CL:0000540") + assert result is None + + def test_returns_dict_on_ndi_hit(self): + class _FakeResult: + id = "0000540" + name = "T cell" + prefix = "CL" + definition = "Mature T cell." + synonyms = [] + short_name = "T cell" + + def __bool__(self): + return True + + def to_dict(self): + return { + "id": self.id, + "name": self.name, + "prefix": self.prefix, + "definition": self.definition, + "synonyms": self.synonyms, + "short_name": self.short_name, + } + + ndi_python_service._NDI_AVAILABLE = True + fake_module = type("M", (), {"lookup": lambda _curie: _FakeResult()}) + with patch.dict("sys.modules", {"ndi.ontology": fake_module}): + result = ndi_python_service.lookup_ontology("CL:0000540") + assert result is not None + assert result["name"] == "T cell" + assert result["prefix"] == "CL" + + def test_swallows_ndi_exception(self): + """Defensive: even though NDI is documented not to raise, if it + does, we swallow + return None so callers don't see exceptions.""" + ndi_python_service._NDI_AVAILABLE = True + + def _boom(_curie): + raise RuntimeError("boom") + + fake_module = type("M", (), {"lookup": _boom}) + with patch.dict("sys.modules", {"ndi.ontology": fake_module}): + result = ndi_python_service.lookup_ontology("CL:0000540") + assert result is None + + +# --------------------------------------------------------------------------- +# is_ndi_available — caches result, doesn't crash on missing imports +# --------------------------------------------------------------------------- + + +class TestIsNdiAvailable: + def test_caches_first_result(self): + ndi_python_service._NDI_AVAILABLE = True + # Without resetting, subsequent calls should not re-import. + assert ndi_python_service.is_ndi_available() is True + ndi_python_service._NDI_AVAILABLE = False + assert ndi_python_service.is_ndi_available() is False + + def test_returns_false_when_imports_fail(self): + ndi_python_service._NDI_AVAILABLE = None + with patch.dict( + "sys.modules", + { + "vlt.file.custom_file_formats": None, + "ndicompress": None, + "ndi.ontology": None, + }, + ): + assert ndi_python_service.is_ndi_available() is False + # And the cache survives: + assert ndi_python_service._NDI_AVAILABLE is False diff --git a/backend/tests/unit/test_no_phi_in_logs.py b/backend/tests/unit/test_no_phi_in_logs.py new file mode 100644 index 0000000..4b87ab1 --- /dev/null +++ b/backend/tests/unit/test_no_phi_in_logs.py @@ -0,0 +1,218 @@ +"""Static regression test for the audit-log-policy promise. + +The public `/security` page promises: + + > Every API call is logged with user, timestamp, action, and outcome. + > Request bodies and response payloads are explicitly excluded — so + > PHI cannot leak into logs by accident. + +This test enforces that promise structurally: every `log.X(...)` call in +the backend is parsed with `ast`, and the keyword-argument names are +checked against a denylist of PHI / secret-shaped names. A new log line +introducing `password=`, `email=` (unhashed), `request_body=`, etc. fails +the build before it ships. + +The denylist is conservative — it catches both "obvious leak" names and +"surface that could carry PHI" names like `body` / `payload`. Adding a +new logger that legitimately needs one of these names (e.g. ALLOWLISTED +debug logging of a structured response shape that has been audited) +should be done by adding an explicit `# noqa: phi-in-logs` marker on the +log line + an entry in `ALLOWED_LINE_MARKERS` below, with a brief comment +explaining why the audit is OK. + +Doc reference: `apps/web/docs/operations/hipaa-technical-safeguards.md` +§164.312(b) Audit controls, Verification test row. +""" +from __future__ import annotations + +import ast +from pathlib import Path + +import pytest + +# Names that suggest a log line carries plaintext PHI or a secret. +# Hashes and IDs are fine: `user_id_hash`, `session_id`, `request_id`, +# `email_hash` all pass. The names below are the bare versions. +PHI_DENYLIST: frozenset[str] = frozenset({ + # Authentication secrets + "password", + "passwd", + "pwd", + "secret", + "raw_password", + "plain_password", + "access_token", + "refresh_token", + "bearer_token", + "csrf_raw", + "csrf_cookie", + # Request / response surface (these are the PHI-bearing fields) + "body", + "request_body", + "req_body", + "response_body", + "resp_body", + "payload", + "request_payload", + "response_payload", + # PII surface — must be hashed before logging + "email", + "email_raw", + "phone", + "phone_number", + "ssn", + "dob", + "date_of_birth", + "raw_user_agent", + "raw_ip", + "ip_address", + "user_agent", # use `user_agent_hash` instead +}) + +# Names that have hashes / truncations and are SAFE despite looking +# similar to denylisted names. Tracked here to make the safe-vs-unsafe +# boundary explicit rather than implicit in the denylist. +SAFE_NAME_PATTERNS: tuple[str, ...] = ( + "_hash", + "_hashed", + "_digest", + "_short", + "_truncated", +) + +# Lines that have been audited and exempted, as ``<rel-path>:<line>`` +# strings (e.g. ``auth/login.py:105``). Empty by design — every entry +# represents a documented exception. Add only with an accompanying +# audit note explaining why the log call is safe despite using one +# of the PHI_DENYLIST names. +ALLOWED_LINE_MARKERS: frozenset[str] = frozenset() + + +def _backend_root() -> Path: + """Resolve the `backend/` package root from this test file.""" + return Path(__file__).resolve().parents[2] + + +def _python_source_files(root: Path) -> list[Path]: + """Walk the `backend/` tree for .py files, skipping tests + caches.""" + paths: list[Path] = [] + for p in root.rglob("*.py"): + rel = p.relative_to(root).as_posix() + if rel.startswith("tests/") or rel.startswith("__pycache__/"): + continue + if "/__pycache__/" in rel: + continue + paths.append(p) + return paths + + +def _is_logger_call(node: ast.Call) -> bool: + """Match `log.X(...)`, `logger.X(...)`, `LOG.X(...)`, etc. + + Heuristic: attribute access where the method name is one of the + structlog levels and the receiver's lowercased name contains `log`. + Tolerates dotted receivers like `self.log.info(...)`. + """ + func = node.func + if not isinstance(func, ast.Attribute): + return False + if func.attr not in { + "debug", + "info", + "warning", + "warn", + "error", + "exception", + "critical", + "msg", + }: + return False + # Walk the receiver chain looking for a part whose lowercased name + # contains `log`. Skip the first attr (the level name). + receiver: ast.AST = func.value + while isinstance(receiver, ast.Attribute): + if "log" in receiver.attr.lower(): + return True + receiver = receiver.value + return isinstance(receiver, ast.Name) and "log" in receiver.id.lower() + + +def _safe_by_pattern(kw_name: str) -> bool: + """The `_hash`/`_truncated`/etc.-suffixed names are safe by convention.""" + return any(kw_name.endswith(p) for p in SAFE_NAME_PATTERNS) + + +def _scan_file(path: Path) -> list[tuple[int, str]]: + """Return [(line_no, denylisted_kwarg_name), ...] for the given file. + + Empty list means no findings. + """ + try: + source = path.read_text(encoding="utf-8") + except OSError: + return [] + try: + tree = ast.parse(source, filename=str(path)) + except SyntaxError: + return [] + findings: list[tuple[int, str]] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + if not _is_logger_call(node): + continue + for kw in node.keywords: + if kw.arg is None: + continue + name = kw.arg + if _safe_by_pattern(name): + continue + if name in PHI_DENYLIST: + findings.append((node.lineno, name)) + return findings + + +@pytest.mark.parametrize("source_path", _python_source_files(_backend_root())) +def test_no_phi_in_log_calls(source_path: Path) -> None: + """Every log.X() call must avoid PHI / secret-shaped kwarg names. + + Failure means a new log line was introduced with a kwarg name from + the denylist. Either rename to a hashed/truncated form, OR add an + explicit `# noqa: phi-in-logs` comment in the source + an entry in + ALLOWED_LINE_MARKERS above with a brief audit justification. + """ + findings = _scan_file(source_path) + backend_root = _backend_root() + rel = source_path.relative_to(backend_root).as_posix() + findings = [ + f for f in findings if f"{rel}:{f[0]}" not in ALLOWED_LINE_MARKERS + ] + if findings: + details = "\n".join( + f" {rel}:{lineno} — kwarg `{name}` is in PHI_DENYLIST" + for lineno, name in findings + ) + pytest.fail( + f"PHI / secret-shaped kwargs found in log calls:\n{details}\n\n" + "Either:\n" + " (a) Rename the kwarg to the hashed / truncated form (e.g.\n" + " `email_hash` instead of `email`, `session_id[:8]` value\n" + " under the existing `session_id` key).\n" + " (b) If the value is genuinely safe to log (e.g. an audited\n" + " enum), add `# noqa: phi-in-logs` on the source line AND\n" + " an entry in `ALLOWED_LINE_MARKERS` in this test file with\n" + " a brief explanation." + ) + + +def test_phi_denylist_is_non_empty() -> None: + """Belt-and-suspenders: the denylist itself isn't empty. + + A future refactor that accidentally clears the set would silently + pass the parametrized test (zero findings on every file). This + sanity check catches that. + """ + assert PHI_DENYLIST, "PHI_DENYLIST must contain entries" + assert "password" in PHI_DENYLIST + assert "body" in PHI_DENYLIST + assert "email" in PHI_DENYLIST diff --git a/backend/tests/unit/test_ontology_service.py b/backend/tests/unit/test_ontology_service.py new file mode 100644 index 0000000..23acce9 --- /dev/null +++ b/backend/tests/unit/test_ontology_service.py @@ -0,0 +1,319 @@ +"""Unit tests for ``OntologyService`` — specifically the cache-stub +bypass behavior introduced as a fix for the granular-completeness +regression. + +Pre-fix bug: when a term like ``WBStrain:00000001`` was looked up +BEFORE Phase A wired ``ndi.ontology.lookup`` as a fallback, the +legacy provider returned a stub (``label=None``, ``definition=None``) +which was cached. ``ONTOLOGY_CACHE_TTL_DAYS`` defaults to 30, so for +~a month after Phase A shipped, every lookup of that term hit the +stale stub and short-circuited the NDI-python fallback. End result: +the data browser kept rendering ``WBStrain:00000001`` raw instead of +"N2 wild-type" even though the NDI-python integration knew the +answer. + +The fix: ``OntologyService.lookup`` now treats stubs as cache MISSES, +re-runs the fetch pipeline (legacy providers + NDI-python fallback), +and on success OVERWRITES the stub. So stuck stubs heal on first +use without waiting for the 30-day TTL to roll over. + +These tests cover the lookup pipeline's branching directly with +stubbed providers; the NDI-python integration itself has its own +boundary tests. +""" +from __future__ import annotations + +from unittest.mock import patch + +import httpx +import pytest +import respx + +from backend.services.ontology_cache import OntologyCache, OntologyTerm +from backend.services.ontology_service import OntologyService + + +@pytest.fixture +def cache(tmp_path) -> OntologyCache: + return OntologyCache(db_path=str(tmp_path / "ontology_test.db")) + + +@pytest.fixture +def service(cache: OntologyCache) -> OntologyService: + return OntologyService(cache=cache) + + +def _stub(provider: str, term_id: str) -> OntologyTerm: + """An empty cache entry — what the legacy path returns when a + provider doesn't know the term.""" + return OntologyTerm( + provider=provider, term_id=term_id, + label=None, definition=None, url=None, + ) + + +def _hit(provider: str, term_id: str, label: str) -> OntologyTerm: + return OntologyTerm( + provider=provider, term_id=term_id, + label=label, definition=f"{label} definition", url=None, + ) + + +@pytest.mark.asyncio +async def test_lookup_returns_real_cached_hit_without_refetching(service, cache): + """Real cache hits short-circuit the fetch path — no upstream calls.""" + cache.set(_hit("CL", "0000540", "neuron")) + with patch.object(service, "_fetch_from_provider") as fetch_mock, \ + patch.object(service, "_try_ndi_fallback") as ndi_mock: + result = await service.lookup("CL:0000540") + assert result.label == "neuron" + fetch_mock.assert_not_called() + ndi_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_lookup_treats_stub_as_cache_miss_and_retries(service, cache): + """STUB cache entries (label=None AND definition=None) must NOT + short-circuit. The fetch pipeline must run again so the + NDI-python fallback can fire.""" + # Seed a stub — simulates a pre-Phase-A cached miss for WBStrain. + cache.set(_stub("WBStrain", "00000001")) + async def fake_fetch(_p, _t): + # Legacy provider still doesn't know WBStrain. + return _stub("WBStrain", "00000001") + async def fake_ndi(_term, p, t): + return _hit(p, t, "N2 wild-type") + with patch.object(service, "_fetch_from_provider", side_effect=fake_fetch), \ + patch.object(service, "_try_ndi_fallback", side_effect=fake_ndi): + result = await service.lookup("WBStrain:00000001") + # NDI-python's result wins, and the cache stub is replaced with + # the real hit so subsequent lookups don't re-pay the cost. + assert result.label == "N2 wild-type" + # Cache now has the real entry. + cached_after = cache.get("WBStrain", "00000001") + assert cached_after is not None + assert cached_after.label == "N2 wild-type" + + +@pytest.mark.asyncio +async def test_stub_bypass_caches_new_stub_when_both_paths_fail(service, cache): + """When the stub-miss retry ALSO comes up empty (legacy + NDI-python + both unknown), we return the empty result without thrashing the + cache: we already have a stub for this term, no need to write + another. Subsequent lookups still bypass, but that's OK — the + extra cost is only when the term genuinely can't be resolved by + anyone, which is rare.""" + cache.set(_stub("UNKNOWN", "99999")) + async def fake_fetch(_p, _t): + return _stub("UNKNOWN", "99999") + async def fake_ndi(_term, _p, _t): + return None + with patch.object(service, "_fetch_from_provider", side_effect=fake_fetch), \ + patch.object(service, "_try_ndi_fallback", side_effect=fake_ndi): + result = await service.lookup("UNKNOWN:99999") + assert result.label is None + assert result.definition is None + # The pre-existing stub stays in the cache — no double-write. + cached_after = cache.get("UNKNOWN", "99999") + assert cached_after is not None + assert cached_after.label is None + + +@pytest.mark.asyncio +async def test_fresh_term_with_provider_hit_does_not_call_ndi(service, cache): + """When the legacy provider returns a REAL hit on first lookup, + NDI-python is NOT called (it's a fallback, not a co-resolver). + This is the original behavior; verify the stub fix didn't break it.""" + async def fake_fetch(p, t): + return _hit(p, t, "frontal cortex") + with patch.object(service, "_fetch_from_provider", side_effect=fake_fetch), \ + patch.object(service, "_try_ndi_fallback") as ndi_mock: + result = await service.lookup("UBERON:0001870") + assert result.label == "frontal cortex" + ndi_mock.assert_not_called() + # And the cache now has the real entry. + assert cache.get("UBERON", "0001870").label == "frontal cortex" + + +@pytest.mark.asyncio +async def test_fresh_term_falls_through_to_ndi_when_legacy_returns_stub( + service, cache, +): + """For terms the legacy providers can't resolve (e.g. NDIC, WBStrain), + the legacy path returns a stub and we fall through to NDI-python. + Same as test_lookup_treats_stub_as_cache_miss_and_retries but + without any prior cache state — covers the cold-start path.""" + async def fake_fetch(p, t): + return _stub(p, t) + async def fake_ndi(_term, p, t): + return _hit(p, t, "Purpose: Assessing spatial frequency tuning") + with patch.object(service, "_fetch_from_provider", side_effect=fake_fetch), \ + patch.object(service, "_try_ndi_fallback", side_effect=fake_ndi): + result = await service.lookup("NDIC:1") + assert result.label == "Purpose: Assessing spatial frequency tuning" + # And the result is now cached as a real hit. + assert cache.get("NDIC", "1").label == "Purpose: Assessing spatial frequency tuning" + + +@pytest.mark.asyncio +async def test_batch_lookup_unblocks_stale_stubs(service, cache): + """The batch path inherits stub-bypass automatically because it + delegates to ``self.lookup`` per term. Verify end-to-end so we + don't regress this.""" + # Seed two stubs (mix of providers) so a batch hits both. + cache.set(_stub("WBStrain", "00000001")) + cache.set(_stub("NDIC", "1")) + async def fake_fetch(p, t): + return _stub(p, t) + labels = { + "WBStrain:00000001": "N2 wild-type", + "NDIC:1": "Purpose: Assessing spatial frequency tuning", + } + async def fake_ndi(term, p, t): + return _hit(p, t, labels[term]) + with patch.object(service, "_fetch_from_provider", side_effect=fake_fetch), \ + patch.object(service, "_try_ndi_fallback", side_effect=fake_ndi): + results = await service.batch_lookup( + ["WBStrain:00000001", "NDIC:1"], + ) + assert len(results) == 2 + label_by_id = {f"{r.provider}:{r.term_id}": r.label for r in results} + assert label_by_id["WBStrain:00000001"] == "N2 wild-type" + assert label_by_id["NDIC:1"] == "Purpose: Assessing spatial frequency tuning" + + +# --------------------------------------------------------------------------- +# WBStrain scrape — `_fetch_wormbase` now resolves strain names from the +# canonical wormbase.org strain page so the lookup pipeline no longer has +# to depend on NDI-python's WBStrain provider (which only returns a URL). +# --------------------------------------------------------------------------- + + +# Minimal HTML fixture mirroring the real WormBase strain page. Captures +# the two parse targets (``<title>`` + page-title breadcrumb) and enough +# surrounding chrome that a future regex refactor can verify it still +# anchors on the right boundaries. +_WORMBASE_N2_HTML = """<!DOCTYPE html> +<html lang="en-US"> +<head> + <meta charset="utf-8"> + <title> N2 (strain) - WormBase : Nematode Information Resource + + +
+
+ + Species » + C. elegans + +

Strain » N2

+
+
+
+ + +""" + +# Cloudflare-interstitial response — the most common failure mode in +# practice when the backend's egress IP is on a datacenter range. +_CLOUDFLARE_JUST_A_MOMENT = ( + "" + "Just a moment..." + "" +) + + +@pytest.mark.asyncio +async def test_wormbase_scrape_resolves_strain_name(service): + """Happy path: the WBStrain page returns 200 with the canonical + title, and ``_fetch_wormbase`` extracts the strain name from + ````.""" + url = "https://wormbase.org/species/c_elegans/strain/00000001" + with respx.mock(assert_all_called=False) as router: + router.get(url).mock( + return_value=httpx.Response(200, text=_WORMBASE_N2_HTML), + ) + term = await service._fetch_wormbase("00000001") + assert term.provider == "WBStrain" + assert term.term_id == "00000001" + assert term.label == "N2" + assert term.url == url + + +@pytest.mark.asyncio +async def test_wormbase_scrape_falls_back_to_breadcrumb(service): + """If the ``<title>`` element is missing or mangled (older snapshot, + partial response), the page-title breadcrumb still resolves the + name.""" + body_no_title = _WORMBASE_N2_HTML.replace( + "<title> N2 (strain) - WormBase : Nematode Information Resource", + "", + ) + url = "https://wormbase.org/species/c_elegans/strain/00000001" + with respx.mock(assert_all_called=False) as router: + router.get(url).mock( + return_value=httpx.Response(200, text=body_no_title), + ) + term = await service._fetch_wormbase("00000001") + assert term.label == "N2" + + +@pytest.mark.asyncio +async def test_wormbase_scrape_returns_none_label_on_cloudflare_block(service): + """Cloudflare interstitials still return 200, but the ``(strain)`` + anchor in the title regex won't match. We must NOT leak the + interstitial body as a strain name.""" + url = "https://wormbase.org/species/c_elegans/strain/00000001" + with respx.mock(assert_all_called=False) as router: + router.get(url).mock( + return_value=httpx.Response(200, text=_CLOUDFLARE_JUST_A_MOMENT), + ) + term = await service._fetch_wormbase("00000001") + assert term.label is None + assert term.url == url + + +@pytest.mark.asyncio +async def test_wormbase_scrape_returns_none_label_on_404(service): + """Strain IDs that don't exist on WormBase return 404. The scrape + must NOT raise and must return ``label=None`` so the upstream + pipeline falls through to the NDI-python fallback.""" + url = "https://wormbase.org/species/c_elegans/strain/99999999" + with respx.mock(assert_all_called=False) as router: + router.get(url).mock(return_value=httpx.Response(404, text="")) + term = await service._fetch_wormbase("99999999") + assert term.label is None + + +@pytest.mark.asyncio +async def test_wormbase_scrape_returns_none_label_on_network_error(service): + """Network errors (timeouts, DNS, RST) must be swallowed; the + lookup pipeline degrades cleanly to ``label=None``.""" + url = "https://wormbase.org/species/c_elegans/strain/00000001" + with respx.mock(assert_all_called=False) as router: + router.get(url).mock(side_effect=httpx.ConnectTimeout("boom")) + term = await service._fetch_wormbase("00000001") + assert term.label is None + assert term.url == url + + +@pytest.mark.asyncio +async def test_wormbase_scrape_end_to_end_caches_result(service, cache): + """End-to-end: lookup of a WBStrain CURIE invokes the scrape, gets + the strain name, and the result is cached as a REAL hit (not a + stub). Second lookup must short-circuit without re-fetching.""" + url = "https://wormbase.org/species/c_elegans/strain/00000001" + with respx.mock(assert_all_called=False) as router: + route = router.get(url).mock( + return_value=httpx.Response(200, text=_WORMBASE_N2_HTML), + ) + # First call hits WormBase via the scrape. + first = await service.lookup("WBStrain:00000001") + # Second call must come from cache. + second = await service.lookup("WBStrain:00000001") + assert first.label == "N2" + assert second.label == "N2" + assert route.call_count == 1 + cached = cache.get("WBStrain", "00000001") + assert cached is not None + assert cached.label == "N2" diff --git a/backend/tests/unit/test_psth_service.py b/backend/tests/unit/test_psth_service.py new file mode 100644 index 0000000..a7c7b1c --- /dev/null +++ b/backend/tests/unit/test_psth_service.py @@ -0,0 +1,550 @@ +"""Unit tests for psth_service — peri-stimulus time histogram orchestration. + +Tests verify: + + * happy path: unit with deterministic spike train + stimulus with + multiple events produces bin_centers + counts arrays of consistent + length; mean_rate_hz normalization is correct + * empty path: zero spikes in window → zero-counts but valid bin + arrays + correct n_trials + * include_raster: per-trial relative-time arrays are surfaced + * cap enforcement: bin_size_ms < 1 ms → invalid_window envelope + * window cap: (t1 - t0) > 10 s → invalid_window envelope + * bin-count cap: too many bins → invalid_window envelope + * soft error: stimulus doc lacks event timestamps → no_events envelope + * soft error: spike doc binary fails to decode → decode_failed envelope + * event-extraction across NDI doc-class paths: + - stimulus_presentation.presentations[*].time_started + - stimulus_response.responses[*].stim_time + - data.events (top-level array) +""" +from __future__ import annotations + +from typing import Any + +from backend.services.psth_service import ( + DEFAULT_BIN_SIZE_MS, + DEFAULT_T0, + DEFAULT_T1, + MAX_WINDOW_SECONDS, + MIN_BIN_SIZE_MS, + PsthRequest, + _build_bin_arrays, + _cap_raster, + _extract_event_times, + _extract_spike_times_from_doc, + _validate_window, + compute_psth, +) + +# --------------------------------------------------------------------------- +# Fakes — mirror the shape the real services produce. The real +# document_service.detail() returns a normalized doc dict; we hand +# back canned dicts keyed by doc_id so the test orchestrator picks +# the right body per call. +# --------------------------------------------------------------------------- + + +class _FakeDocumentService: + """Stub for DocumentService.detail — canned responses per doc_id. + + The real signature is ``detail(dataset_id, document_id, *, + access_token)``; we mirror that here so the orchestrator's call + sites don't have to branch on test vs prod. + """ + + def __init__(self, docs_by_id: dict[str, dict[str, Any]]) -> None: + self._docs = docs_by_id + self.calls: list[tuple[str, str]] = [] + + async def detail( + self, + dataset_id: str, + document_id: str, + *, + access_token: str | None, # noqa: ARG002 — stub signature + ) -> dict[str, Any]: + self.calls.append((dataset_id, document_id)) + if document_id not in self._docs: + raise RuntimeError(f"no canned doc for {document_id}") + return self._docs[document_id] + + +class _FakeBinaryService: + """Stub for BinaryService.get_timeseries — canned response. + + Used for the binary-fallback path when the unit doc's JSON body + has no inlined spike-times array. The real service decodes + NBF/VHSB; tests use a pre-baked dict instead. + """ + + def __init__(self, response: dict[str, Any] | None = None) -> None: + self.response = response or {"timestamps": [], "channels": {}, "error": "no_data"} + self.calls: list[dict[str, Any]] = [] + + async def get_timeseries( + self, + document: dict[str, Any], + *, + access_token: str | None, # noqa: ARG002 — stub signature + filename: str | None = None, # noqa: ARG002 + ) -> dict[str, Any]: + self.calls.append(document) + return self.response + + +def _vmspikesummary_doc( + *, + name: str = "unit_001", + spike_times: list[float] | None = None, +) -> dict[str, Any]: + """Build a minimal vmspikesummary doc body. Spike times inline + under ``data.vmspikesummary.spike_times``. + """ + inner: dict[str, Any] = {"name": name} + if spike_times is not None: + inner["spike_times"] = spike_times + return { + "id": "u" * 24, + "data": {"vmspikesummary": inner}, + } + + +def _stim_presentation_doc(times: list[float]) -> dict[str, Any]: + """Stimulus doc using the ``stimulus_presentation`` schema.""" + return { + "id": "s" * 24, + "data": { + "stimulus_presentation": { + "presentations": [{"time_started": t} for t in times], + }, + }, + } + + +def _stim_response_doc(times: list[float]) -> dict[str, Any]: + """Stimulus doc using the ``stimulus_response`` schema.""" + return { + "id": "s" * 24, + "data": { + "stimulus_response": { + "responses": [{"stim_time": t} for t in times], + }, + }, + } + + +def _stim_events_doc(times: list[float]) -> dict[str, Any]: + """Stimulus doc with preprocessed top-level events array.""" + return { + "id": "s" * 24, + "data": {"events": list(times)}, + } + + +# --------------------------------------------------------------------------- +# Pure-helper tests +# --------------------------------------------------------------------------- + + +class TestExtractSpikeTimesFromDoc: + def test_pulls_spike_times_under_canonical_key(self) -> None: + doc = _vmspikesummary_doc(spike_times=[0.1, 0.2, 0.3]) + out = _extract_spike_times_from_doc(doc) + assert out == [0.1, 0.2, 0.3] + + def test_falls_back_to_alternate_keys(self) -> None: + doc = {"data": {"vmspikesummary": {"sample_times": [1.0, 2.0]}}} + out = _extract_spike_times_from_doc(doc) + assert out == [1.0, 2.0] + + def test_returns_none_when_no_array(self) -> None: + doc = {"data": {"vmspikesummary": {"name": "u"}}} + assert _extract_spike_times_from_doc(doc) is None + + def test_skips_non_numeric_entries(self) -> None: + doc = {"data": {"vmspikesummary": {"spike_times": [1.0, "bad", 2.0, None, True, False]}}} + out = _extract_spike_times_from_doc(doc) + # bool/None excluded; numeric strings accepted by _coerce_numeric_list + assert out == [1.0, 2.0] + + +class TestExtractEventTimes: + def test_stimulus_presentation_path(self) -> None: + doc = _stim_presentation_doc([1.0, 2.0, 3.0]) + assert _extract_event_times(doc) == [1.0, 2.0, 3.0] + + def test_stimulus_response_path(self) -> None: + doc = _stim_response_doc([0.5, 1.5]) + assert _extract_event_times(doc) == [0.5, 1.5] + + def test_data_events_array(self) -> None: + doc = _stim_events_doc([10.0, 20.0]) + assert _extract_event_times(doc) == [10.0, 20.0] + + def test_data_events_list_of_dicts(self) -> None: + doc = {"data": {"events": [{"time": 1.0}, {"t": 2.0}]}} + assert _extract_event_times(doc) == [1.0, 2.0] + + def test_top_level_events_fallback(self) -> None: + doc = {"events": [5.0, 6.0]} + assert _extract_event_times(doc) == [5.0, 6.0] + + def test_returns_empty_when_no_path_matches(self) -> None: + assert _extract_event_times({"data": {}}) == [] + assert _extract_event_times({}) == [] + # Wrong-shape entries (no recognized key) + assert _extract_event_times({"data": {"events": [{"foo": 1.0}]}}) == [] + + +class TestValidateWindow: + def _req(self, **kwargs: Any) -> PsthRequest: + defaults: dict[str, Any] = { + "unit_doc_id": "u" * 24, + "stimulus_doc_id": "s" * 24, + "t0": DEFAULT_T0, + "t1": DEFAULT_T1, + "bin_size_ms": DEFAULT_BIN_SIZE_MS, + } + defaults.update(kwargs) + return PsthRequest(**defaults) + + def test_defaults_are_valid(self) -> None: + _, _, _, err = _validate_window(self._req()) + assert err is None + + def test_bin_size_below_minimum_rejected(self) -> None: + _, _, _, err = _validate_window(self._req(bin_size_ms=0.5)) + assert err is not None + assert "minimum" in err.lower() + + def test_window_too_wide_rejected(self) -> None: + _, _, _, err = _validate_window( + self._req(t0=-5.0, t1=5.0 + MAX_WINDOW_SECONDS), + ) + assert err is not None + assert "window" in err.lower() or "exceeds" in err.lower() + + def test_t1_must_exceed_t0(self) -> None: + _, _, _, err = _validate_window(self._req(t0=1.0, t1=0.5)) + assert err is not None + + def test_too_many_bins_rejected(self) -> None: + # (1 s) / (0.5 ms) = 2000 bins → over MAX_BINS=1000. But 0.5 ms + # fails the bin_size_ms floor first, so we need a wider window. + # 2 s / 1 ms = 2000 bins → over cap. + _, _, _, err = _validate_window( + self._req(t0=0.0, t1=2.0, bin_size_ms=MIN_BIN_SIZE_MS), + ) + assert err is not None + assert "bin" in err.lower() + + +class TestBuildBinArrays: + def test_bin_count_matches_window_and_size(self) -> None: + edges, centers = _build_bin_arrays(t0=0.0, t1=1.0, bin_size_ms=10.0) + # 1 s / 10 ms = 100 bins + assert len(centers) == 100 + assert len(edges) == 101 + + def test_centers_are_midpoints(self) -> None: + _, centers = _build_bin_arrays(t0=0.0, t1=1.0, bin_size_ms=100.0) + # 10 bins, centers at 0.05, 0.15, ..., 0.95 + assert abs(centers[0] - 0.05) < 1e-9 + assert abs(centers[-1] - 0.95) < 1e-9 + + +class TestCapRaster: + def test_under_cap_returns_verbatim(self) -> None: + per_trial = [[1.0, 2.0], [3.0]] + out = _cap_raster(per_trial, total_cap=100) + assert out == per_trial + + def test_over_cap_strides_proportionally(self) -> None: + # 3 trials with 100 spikes each = 300 total; cap at 30 → ratio 0.1 + per_trial = [[float(i) for i in range(100)] for _ in range(3)] + out = _cap_raster(per_trial, total_cap=30) + # Each trial gets ~10 spikes (max keep), preserving endpoints + assert len(out) == 3 + for trial in out: + assert len(trial) <= 11 + assert trial[0] == 0.0 + assert trial[-1] == 99.0 + + +# --------------------------------------------------------------------------- +# compute_psth integration — service-level happy paths + soft errors +# --------------------------------------------------------------------------- + + +def _spike_train(n: int, stride: float = 0.01) -> list[float]: + """Build a deterministic spike train of n spikes at stride seconds.""" + return [i * stride for i in range(n)] + + +def _req(**kwargs: Any) -> PsthRequest: + defaults: dict[str, Any] = { + "unit_doc_id": "a" * 24, + "stimulus_doc_id": "b" * 24, + } + defaults.update(kwargs) + return PsthRequest(**defaults) + + +async def test_happy_path_consistent_arrays() -> None: + """Unit with 100 spikes + stimulus with 10 events. + + Each event sits inside the spike train so the [-0.5, 1.5] window + captures spikes around it. Verifies: + - bin_centers, counts, mean_rate_hz are parallel arrays + - n_trials matches the event count + - n_spikes is non-zero + - error/error_kind are None on the happy path + """ + # Unit: 100 spikes spaced 0.01 s apart, covering [0, 1] s + spike_times = _spike_train(100, stride=0.01) + unit_doc = _vmspikesummary_doc(spike_times=spike_times) + # Stimulus: 10 events at 0.05, 0.10, ..., 0.50 s — every event has + # spikes both before (t0 = -0.5) and after (t1 = 1.5) it. + stim_doc = _stim_presentation_doc([0.05 + i * 0.05 for i in range(10)]) + + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + + # Parallel-array consistency. + n_bins = len(resp.bin_centers) + assert n_bins > 0 + assert len(resp.counts) == n_bins + assert len(resp.mean_rate_hz) == n_bins + + # Default bin layout: (-0.5, 1.5) s @ 20 ms = 100 bins. + assert n_bins == 100 + + # Trials + spikes were counted. + assert resp.n_trials == 10 + assert resp.n_spikes > 0 + assert resp.error is None + assert resp.error_kind is None + + # Rate normalization sanity: mean_rate_hz[i] == counts[i] / (n_trials * bin_size_s) + bin_size_s = resp.bin_size_ms / 1000.0 + for c, r in zip(resp.counts, resp.mean_rate_hz, strict=True): + assert abs(r - c / (resp.n_trials * bin_size_s)) < 1e-9 + + +async def test_empty_window_returns_zero_counts_envelope() -> None: + """Events present, but their windows contain zero spikes. + + Spike train is far from the events; n_trials still matches event + count; counts are all zero; valid bin arrays returned; + error_kind='empty_window'. + """ + # Unit: spikes at t = 100, 100.01, ... (way after the events) + spike_times = [100.0 + i * 0.01 for i in range(20)] + unit_doc = _vmspikesummary_doc(spike_times=spike_times) + # Stimulus events at t = 0, 1, 2 — windows are all near zero; + # no spikes overlap. + stim_doc = _stim_presentation_doc([0.0, 1.0, 2.0]) + + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + + assert resp.n_trials == 3 + assert resp.n_spikes == 0 + assert resp.error_kind == "empty_window" + # Bin arrays are still populated (chart can render flat trace). + assert len(resp.bin_centers) > 0 + assert len(resp.counts) == len(resp.bin_centers) + assert all(c == 0 for c in resp.counts) + assert all(r == 0.0 for r in resp.mean_rate_hz) + + +async def test_include_raster_returns_per_trial_arrays() -> None: + """include_raster=True surfaces per-trial relative spike times.""" + spike_times = _spike_train(50, stride=0.02) + unit_doc = _vmspikesummary_doc(spike_times=spike_times) + stim_doc = _stim_presentation_doc([0.0, 0.5]) + + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(include_raster=True), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + + assert resp.per_trial_raster is not None + assert len(resp.per_trial_raster) == 2 # two events + # Every value must fall within [t0, t1] (relative-time bounds) + for trial in resp.per_trial_raster: + for t in trial: + assert resp.t0 <= t <= resp.t1 + + +async def test_raster_default_off() -> None: + """include_raster defaults to False → per_trial_raster=None.""" + unit_doc = _vmspikesummary_doc(spike_times=_spike_train(10)) + stim_doc = _stim_presentation_doc([0.0]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.per_trial_raster is None + + +async def test_cap_enforcement_rejects_tiny_bins() -> None: + """bin_size_ms below MIN_BIN_SIZE_MS surfaces invalid_window envelope. + + The service returns a soft envelope rather than raising so the + chat tool can render the explanation. n_trials=0 because we never + fetched events. + """ + unit_doc = _vmspikesummary_doc(spike_times=_spike_train(10)) + stim_doc = _stim_presentation_doc([0.0]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(bin_size_ms=0.5), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.error_kind == "invalid_window" + assert resp.n_trials == 0 + assert resp.bin_centers == [] + + +async def test_soft_error_no_event_timestamps() -> None: + """Stimulus doc with no extractable timestamps → no_events envelope.""" + unit_doc = _vmspikesummary_doc(spike_times=_spike_train(10)) + # Stimulus doc shape that doesn't match any extraction path. + stim_doc = {"data": {"stimulus_presentation": {"name": "no events here"}}} + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.error_kind == "no_events" + # Diagnostics: caller can echo which doc failed + assert resp.stimulus_doc_id == "b" * 24 + + +async def test_soft_error_decode_failed_when_no_spike_times() -> None: + """Unit doc with no inlined spike times + binary fallback empty → + decode_failed envelope. + """ + unit_doc = _vmspikesummary_doc(spike_times=None) # no spike_times key + stim_doc = _stim_presentation_doc([0.0, 1.0]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService( + response={"timestamps": [], "channels": {}, "error": "no_file"}, + ) + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.error_kind == "decode_failed" + assert resp.n_trials == 0 + + +async def test_binary_fallback_supplies_spike_times() -> None: + """When the JSON body has no spike_times, the binary fallback's + timestamps array is used as the spike train. + """ + unit_doc = _vmspikesummary_doc(spike_times=None) + stim_doc = _stim_presentation_doc([0.0]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + # Binary returns 5 timestamps within the [-0.5, 1.5] window. + bs = _FakeBinaryService( + response={ + "timestamps": [0.1, 0.2, 0.3, 0.4, 0.5], + "channels": {"ch0": [1.0] * 5}, + "error": None, + }, + ) + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.error_kind is None + assert resp.n_spikes == 5 + assert resp.n_trials == 1 + + +async def test_stimulus_response_path_works_end_to_end() -> None: + """Verifies the stimulus_response.responses[*].stim_time path.""" + unit_doc = _vmspikesummary_doc(spike_times=_spike_train(50)) + stim_doc = _stim_response_doc([0.1, 0.2, 0.3]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.n_trials == 3 + assert resp.error_kind is None + + +async def test_unit_name_extracted_from_vmspikesummary_name() -> None: + """unit_name is propagated from data.vmspikesummary.name.""" + unit_doc = _vmspikesummary_doc(name="MUA_ch3_unit5", spike_times=[0.1]) + stim_doc = _stim_presentation_doc([0.0]) + docs = _FakeDocumentService({"a" * 24: unit_doc, "b" * 24: stim_doc}) + bs = _FakeBinaryService() + resp = await compute_psth( + _req(), + document_service=docs, # type: ignore[arg-type] + binary_service=bs, # type: ignore[arg-type] + session=None, + dataset_id="ds_test", + ) + assert resp.unit_name == "MUA_ch3_unit5" + + +async def test_camelcase_alias_accepted() -> None: + """PsthRequest accepts camelCase aliases from the TS chat proxy.""" + req = PsthRequest.model_validate({ + "unitDocId": "u" * 24, + "stimulusDocId": "s" * 24, + "binSizeMs": 50, + "includeRaster": True, + }) + assert req.unit_doc_id == "u" * 24 + assert req.stimulus_doc_id == "s" * 24 + assert req.bin_size_ms == 50 + assert req.include_raster is True diff --git a/backend/tests/unit/test_redis_table_cache.py b/backend/tests/unit/test_redis_table_cache.py index c5af9ea..d9559cd 100644 --- a/backend/tests/unit/test_redis_table_cache.py +++ b/backend/tests/unit/test_redis_table_cache.py @@ -17,10 +17,13 @@ def test_table_key_shape() -> None: assert k2 == f"table:{RedisTableCache.SCHEMA_VERSION}:DS1:subject:u:deadbeefdeadbeef" -def test_table_key_schema_version_is_v4() -> None: +def test_table_key_schema_version_is_v7() -> None: """Pinned so any projection-shape or cache-semantics change forces a - conscious bump. Current = v4 (post-PR-3: per-user cache scoping).""" - assert RedisTableCache.SCHEMA_VERSION == "v4" + conscious bump. Current = v7 (F-1b follow-up: subject enrichment now + fetches treatment_drug + treatment_transfer in addition to literal + treatment, so the broadcast actually fires for subclass-only + datasets like Bhar).""" + assert RedisTableCache.SCHEMA_VERSION == "v7" def test_default_ttl_is_one_hour() -> None: diff --git a/backend/tests/unit/test_spike_summary_service.py b/backend/tests/unit/test_spike_summary_service.py new file mode 100644 index 0000000..2c64b17 --- /dev/null +++ b/backend/tests/unit/test_spike_summary_service.py @@ -0,0 +1,523 @@ +"""Unit tests for SpikeSummaryService. + +Mocks the cloud HTTP layer via respx (mirrors the pattern in +``test_pivot_service.py``). Exercises: + +- ``unit_doc_id`` single-doc fetch path → returns one unit with + spikes + ISIs. +- ``unit_name_match`` query path → returns N units in name order. +- Bare scan with N > ``max_units`` → returns capped slice with + ``total_matching`` reflecting full count. +- Stride-sample cap: a doc with > MAX_SPIKES_PER_UNIT spikes returns + the capped count. +- Empty: zero matching docs → ``units=[]`` and ``total_matching=0`` + with the empty-reason envelope populated. +- Soft error: matched doc with no parseable spike_times → unit entry + with ``error_kind='decode_failed'`` instead of crashing. +- ``kind`` gating: ``raster`` omits ISIs, ``isi_histogram`` omits + spike_times, ``both`` returns both. +- ``t_window`` filter trims spikes before stride-sampling. +""" +from __future__ import annotations + +from typing import Any + +import pytest +import respx +from cryptography.fernet import Fernet + +from backend.clients.ndi_cloud import NdiCloudClient +from backend.services.document_service import DocumentService +from backend.services.spike_summary_service import ( + DEFAULT_MAX_UNITS, + MAX_SPIKES_PER_UNIT, + SpikeSummaryRequest, + SpikeSummaryUnit, + _build_isi_field, + _build_spike_field, + _extract_spike_times, + _pick_doc_id, + _pick_unit_name, + _stride_sample, + compute_spike_summary, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def cloud() -> NdiCloudClient: # type: ignore[no-untyped-def] + """Shared cloud client. Mirrors test_pivot_service.cloud.""" + import os + os.environ.setdefault("SESSION_ENCRYPTION_KEY", Fernet.generate_key().decode()) + client = NdiCloudClient() + await client.start() + try: + yield client + finally: + await client.close() + + +def _make_doc( + doc_id: str, + name: str, + spike_times: list[float], + *, + key: str = "spike_times", +) -> dict[str, Any]: + """Build a vmspikesummary document body in the shape the cloud + returns from ndiquery / bulk-fetch. + + Defaults to ``data.vmspikesummary.spike_times`` (the most-common + field path). ``key`` overrides it for tests probing the + ``spiketimes`` / ``sample_times`` fallbacks. + """ + return { + "id": doc_id, + "ndiId": f"ndi-{doc_id}", + "name": name, + "data": { + "base": {"id": f"ndi-{doc_id}", "name": name}, + "vmspikesummary": { + "name": name, + key: spike_times, + }, + }, + } + + +def _detail_body(doc: dict[str, Any]) -> dict[str, Any]: + """The single-doc endpoint hoists the body to top-level (see + ``DocumentService._normalize_document``). We mirror that here so + the DocumentService's normalizer roundtrips into the bulk-fetch + shape. + """ + out: dict[str, Any] = {k: v for k, v in doc.items() if k != "data"} + out.update(doc.get("data", {})) + return out + + +def _ndiquery_body(docs: list[dict[str, Any]]) -> dict[str, Any]: + return { + "number_matches": len(docs), + "pageSize": 1000, + "page": 1, + "documents": docs, + } + + +# --------------------------------------------------------------------------- +# Pure helper unit tests — no HTTP, no fixtures. +# --------------------------------------------------------------------------- + + +class TestExtractSpikeTimes: + def test_canonical_spike_times_field(self) -> None: + doc = _make_doc("d1", "Unit 1", [0.1, 0.2, 0.5]) + assert _extract_spike_times(doc) == [0.1, 0.2, 0.5] + + def test_spiketimes_fallback(self) -> None: + doc = _make_doc("d1", "U", [1.0, 2.0], key="spiketimes") + assert _extract_spike_times(doc) == [1.0, 2.0] + + def test_sample_times_canonical_schema_name(self) -> None: + # Schema-canonical fallback. Used by older NDI versions. + doc = _make_doc("d1", "U", [3.0], key="sample_times") + assert _extract_spike_times(doc) == [3.0] + + def test_returns_none_when_no_data(self) -> None: + assert _extract_spike_times({}) is None + assert _extract_spike_times({"data": {}}) is None + assert _extract_spike_times({"data": {"vmspikesummary": {}}}) is None + + def test_parses_stringified_numbers(self) -> None: + # Some NDI exports stringify floats. Matches the TS handler. + doc = { + "data": { + "vmspikesummary": {"spike_times": ["0.1", "0.2", "bogus", "0.5"]}, + }, + } + assert _extract_spike_times(doc) == [0.1, 0.2, 0.5] + + def test_skips_non_finite_values(self) -> None: + doc = { + "data": { + "vmspikesummary": { + "spike_times": [ + 0.1, float("nan"), 0.2, float("inf"), 0.3, float("-inf"), + ], + }, + }, + } + assert _extract_spike_times(doc) == [0.1, 0.2, 0.3] + + def test_empty_array_returns_none(self) -> None: + doc = {"data": {"vmspikesummary": {"spike_times": []}}} + assert _extract_spike_times(doc) is None + + +class TestPickIds: + def test_pick_doc_id_prefers_id(self) -> None: + assert _pick_doc_id({"id": "A", "_id": "B"}) == "A" + assert _pick_doc_id({"_id": "B"}) == "B" + assert _pick_doc_id({"ndiId": "C"}) == "C" + assert _pick_doc_id({}) == "" + + def test_pick_unit_name_prefers_inner_name(self) -> None: + doc = { + "name": "outer", + "data": {"vmspikesummary": {"name": "inner"}}, + } + assert _pick_unit_name(doc, "did") == "inner" + + def test_pick_unit_name_falls_back_to_top_level(self) -> None: + assert _pick_unit_name({"name": "outer"}, "did") == "outer" + + def test_pick_unit_name_falls_back_to_id_tail(self) -> None: + assert _pick_unit_name({}, "abc1234567") == "Unit 234567" + + +class TestStrideSample: + def test_under_cap_returns_all(self) -> None: + assert _stride_sample([1.0, 2.0, 3.0], cap=10) == [1.0, 2.0, 3.0] + + def test_over_cap_preserves_endpoints(self) -> None: + vals = [float(i) for i in range(1000)] + out = _stride_sample(vals, cap=50) + assert len(out) == 50 + assert out[0] == 0.0 + assert out[-1] == 999.0 + + +class TestKindGating: + def test_raster_kind_omits_isi(self) -> None: + spikes = [0.0, 0.1, 0.2] + assert _build_spike_field(spikes, "raster") == spikes + assert _build_isi_field(spikes, "raster") is None + + def test_isi_histogram_kind_omits_spike_times(self) -> None: + spikes = [0.0, 0.1, 0.2] + assert _build_spike_field(spikes, "isi_histogram") is None + intervals = _build_isi_field(spikes, "isi_histogram") + assert intervals is not None + # diff of [0, 0.1, 0.2] = [0.1, 0.1]; ms = [100, 100]. + assert intervals == pytest.approx([100.0, 100.0]) + + def test_both_kind_returns_both(self) -> None: + spikes = [0.0, 0.05] + s = _build_spike_field(spikes, "both") + isi = _build_isi_field(spikes, "both") + assert s == spikes + assert isi == pytest.approx([50.0]) + + def test_isi_with_single_spike_returns_empty(self) -> None: + # 1 spike → no intervals possible. + assert _build_isi_field([0.5], "both") == [] + + def test_isi_drops_zero_and_negative_intervals(self) -> None: + # Duplicate timestamp would produce a 0-interval; we drop it. + intervals = _build_isi_field([0.0, 0.0, 0.1], "both") + assert intervals == pytest.approx([100.0]) + + +# --------------------------------------------------------------------------- +# Service-level tests with respx-mocked cloud +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_unit_doc_id_single_doc_path(cloud: NdiCloudClient) -> None: + """Direct fetch by unit_doc_id returns one unit.""" + # 24-char Mongo id so the DocumentService doesn't try to resolve it + # via ndiquery first (that's a separate code path tested elsewhere). + dataset_id = "DS_SPIKE_1" + doc_id = "a" * 24 + doc = _make_doc(doc_id, "Unit Saline", [0.1, 0.2, 0.3, 0.4]) + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.get(f"/datasets/{dataset_id}/documents/{doc_id}").respond( + 200, json=_detail_body(doc), + ) + docs = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + unitDocId=doc_id, + kind="both", + ) + response = await compute_spike_summary( + request, + document_service=docs, + cloud=cloud, + session=None, + ) + + assert response.total_matching == 1 + assert response.kind == "both" + assert len(response.units) == 1 + unit = response.units[0] + assert unit.name == "Unit Saline" + assert unit.doc_id == doc_id + assert unit.spike_times == [0.1, 0.2, 0.3, 0.4] + # ISIs in ms: diff of [0.1, 0.2, 0.3, 0.4] = [0.1, 0.1, 0.1] → [100, 100, 100]. + assert unit.isi_intervals == pytest.approx([100.0, 100.0, 100.0]) + assert unit.error is None + + +@pytest.mark.asyncio +async def test_unit_name_match_query_returns_ordered_units( + cloud: NdiCloudClient, +) -> None: + """Query path with substring filter returns N units sorted by name.""" + dataset_id = "DS_SPIKE_2" + docs_in = [ + _make_doc("d1", "Unit 3 (Saline)", [0.0, 0.1]), + _make_doc("d2", "Unit 1 (Saline)", [0.0, 0.2]), + _make_doc("d3", "Unit 2 (Saline)", [0.0, 0.3]), + ] + async with respx.mock(base_url="https://api.example.test/v1") as router: + # ndiquery is auto-paginated by the cloud client; one page is enough. + router.post("/ndiquery").respond(200, json=_ndiquery_body(docs_in)) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + unitNameMatch="Saline", + kind="raster", + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert response.total_matching == 3 + assert len(response.units) == 3 + # Sorted by name → Unit 1 < Unit 2 < Unit 3. + assert [u.name for u in response.units] == [ + "Unit 1 (Saline)", + "Unit 2 (Saline)", + "Unit 3 (Saline)", + ] + # kind='raster' → spike_times populated, isi_intervals omitted. + for unit in response.units: + assert unit.spike_times is not None + assert unit.isi_intervals is None + + +@pytest.mark.asyncio +async def test_bare_scan_caps_at_max_units(cloud: NdiCloudClient) -> None: + """Bare dataset scan honors max_units while surfacing total_matching.""" + dataset_id = "DS_SPIKE_3" + # 15 docs but max_units=5 → response.units has 5 entries; total_matching=15. + docs_in = [ + _make_doc(f"d{i}", f"Unit {i:02d}", [float(i), float(i) + 0.1]) + for i in range(15) + ] + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.post("/ndiquery").respond(200, json=_ndiquery_body(docs_in)) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + kind="both", + maxUnits=5, + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert response.total_matching == 15 + assert len(response.units) == 5 + + +@pytest.mark.asyncio +async def test_stride_sample_caps_high_spike_count_unit( + cloud: NdiCloudClient, +) -> None: + """A doc with > MAX_SPIKES_PER_UNIT spikes returns the capped count.""" + dataset_id = "DS_SPIKE_4" + # 10_000 spikes → cap at MAX_SPIKES_PER_UNIT (5000). + big_spikes = [i * 0.001 for i in range(10_000)] + doc = _make_doc("a" * 24, "Big Unit", big_spikes) + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.get(f"/datasets/{dataset_id}/documents/{'a' * 24}").respond( + 200, json=_detail_body(doc), + ) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + unitDocId="a" * 24, + kind="raster", + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert len(response.units) == 1 + unit = response.units[0] + assert unit.spike_times is not None + assert len(unit.spike_times) <= MAX_SPIKES_PER_UNIT + # First + last preserved by stride-sample. + assert unit.spike_times[0] == pytest.approx(0.0) + assert unit.spike_times[-1] == pytest.approx(9.999) + + +@pytest.mark.asyncio +async def test_empty_match_returns_empty_units_with_error_envelope( + cloud: NdiCloudClient, +) -> None: + """Zero matching docs → empty units + total_matching=0 + reason.""" + dataset_id = "DS_SPIKE_5" + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.post("/ndiquery").respond(200, json=_ndiquery_body([])) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + unitNameMatch="NonexistentLabel", + kind="both", + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert response.units == [] + assert response.total_matching == 0 + assert response.error_kind == "no_matches" + assert response.error is not None + assert "NonexistentLabel" in response.error + + +@pytest.mark.asyncio +async def test_decode_failure_yields_per_unit_soft_error( + cloud: NdiCloudClient, +) -> None: + """A matched doc with no parseable spike_times surfaces as a unit + entry with error_kind='decode_failed' instead of crashing. + """ + dataset_id = "DS_SPIKE_6" + # Doc body where the vmspikesummary subtree exists but spike_times + # is missing. Mirrors a malformed export the chat tool used to crash on. + doc = { + "id": "b" * 24, + "ndiId": "ndi-b", + "name": "Broken Unit", + "data": { + "base": {"id": "ndi-b", "name": "Broken Unit"}, + "vmspikesummary": {"name": "Broken Unit"}, + }, + } + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.post("/ndiquery").respond(200, json=_ndiquery_body([doc])) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + kind="both", + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert response.total_matching == 1 + assert len(response.units) == 1 + unit = response.units[0] + assert isinstance(unit, SpikeSummaryUnit) + assert unit.name == "Broken Unit" + assert unit.error_kind == "decode_failed" + assert unit.error is not None + assert "no parseable spike_times" in unit.error + # No data fields populated when decode failed. + assert unit.spike_times is None + assert unit.isi_intervals is None + + +@pytest.mark.asyncio +async def test_t_window_filters_spikes(cloud: NdiCloudClient) -> None: + """Spikes outside [t0, t1] are filtered before stride-sampling.""" + dataset_id = "DS_SPIKE_7" + # 0..9, with t_window=(2, 5) → keeps [2.0, 3.0, 4.0, 5.0]. + spikes = [float(i) for i in range(10)] + doc = _make_doc("c" * 24, "Windowed Unit", spikes) + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.get(f"/datasets/{dataset_id}/documents/{'c' * 24}").respond( + 200, json=_detail_body(doc), + ) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + unitDocId="c" * 24, + kind="raster", + tWindow=(2.0, 5.0), + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert len(response.units) == 1 + assert response.units[0].spike_times == [2.0, 3.0, 4.0, 5.0] + + +@pytest.mark.asyncio +async def test_camelcase_aliases_round_trip(cloud: NdiCloudClient) -> None: + """The TS handler's camelCase fields (``datasetId``, ``unitDocId``, + ``unitNameMatch``, ``tWindow``, ``maxUnits``) must populate the + snake_case Python fields without translation. + """ + request = SpikeSummaryRequest.model_validate({ + "datasetId": "DS_X", + "unitDocId": "d" * 24, + "unitNameMatch": "Saline", + "kind": "both", + "tWindow": [0.0, 10.0], + "maxUnits": 7, + "title": "Test", + }) + assert request.dataset_id == "DS_X" + assert request.unit_doc_id == "d" * 24 + assert request.unit_name_match == "Saline" + assert request.t_window == (0.0, 10.0) + assert request.max_units == 7 + assert request.title == "Test" + + +@pytest.mark.asyncio +async def test_default_max_units_when_unset(cloud: NdiCloudClient) -> None: + """When max_units isn't provided, the service falls back to + DEFAULT_MAX_UNITS (10) so callers don't accidentally pull the whole + dataset's vmspikesummary set. + """ + dataset_id = "DS_DEFAULT_CAP" + docs_in = [ + _make_doc(f"d{i}", f"Unit {i:02d}", [float(i), float(i) + 0.1]) + for i in range(DEFAULT_MAX_UNITS + 3) + ] + async with respx.mock(base_url="https://api.example.test/v1") as router: + router.post("/ndiquery").respond(200, json=_ndiquery_body(docs_in)) + ds = DocumentService(cloud) + request = SpikeSummaryRequest( + datasetId=dataset_id, + kind="both", + ) + response = await compute_spike_summary( + request, + document_service=ds, + cloud=cloud, + session=None, + ) + + assert response.total_matching == DEFAULT_MAX_UNITS + 3 + assert len(response.units) == DEFAULT_MAX_UNITS diff --git a/backend/tests/unit/test_subject_treatment_broadcast.py b/backend/tests/unit/test_subject_treatment_broadcast.py new file mode 100644 index 0000000..7842340 --- /dev/null +++ b/backend/tests/unit/test_subject_treatment_broadcast.py @@ -0,0 +1,570 @@ +"""F-1b — server-side broadcast of treatment columns onto subject table. + +Ports the cloud-app's frontend ``joinTreatmentsToSubjects`` pivot from +``apps/web/app/(app)/datasets/[id]/tables/[className]/table-shell.tsx`` +(lines 768-923) into the backend. Tests pin both helpers: + +- ``_pascal_case_from_treatment_name`` — string normalisation that + yields the dynamic column-key prefix. +- ``_broadcast_treatments_onto_subjects`` — the join itself. + +Both helpers are pure functions; no fixtures from the cloud are +required. Tests run against synthetic ``subject_rows`` and +``treatment_rows`` mirroring the shape ``_row_subject`` and +``_row_treatment`` would emit. +""" +from __future__ import annotations + +from typing import Any + +import pytest + +from backend.services.summary_table_service import ( + SUBJECT_COLUMNS, + _broadcast_treatments_onto_subjects, + _pascal_case_from_treatment_name, +) + +# --------------------------------------------------------------------------- +# _pascal_case_from_treatment_name +# --------------------------------------------------------------------------- + + +class TestPascalCaseFromTreatmentName: + def test_canonical_dabrowska_target_location(self) -> None: + """Real Dabrowska treatment name → real expected key prefix.""" + assert _pascal_case_from_treatment_name( + "Optogenetic Tetanus Stimulation Target Location", + ) == "OptogeneticTetanusStimulationTargetLocation" + + def test_simple_two_word(self) -> None: + assert _pascal_case_from_treatment_name("Optogenetic Tetanus") \ + == "OptogeneticTetanus" + + def test_four_word(self) -> None: + assert _pascal_case_from_treatment_name("Foo Bar Baz Quux") \ + == "FooBarBazQuux" + + def test_hyphens_stripped_per_word(self) -> None: + """Hyphen is non-alphanumeric, stripped from the word. Word stays + lowercase-inside since only the first character is upper-cased.""" + assert _pascal_case_from_treatment_name("with-hyphens here") \ + == "WithhyphensHere" + + def test_collapses_repeated_whitespace(self) -> None: + """``str.split()`` with no arg collapses any whitespace run.""" + assert _pascal_case_from_treatment_name("foo bar\tbaz\nqux") \ + == "FooBarBazQux" + + def test_leading_trailing_whitespace_trimmed(self) -> None: + assert _pascal_case_from_treatment_name(" leading trailing ") \ + == "LeadingTrailing" + + def test_empty_string_returns_none(self) -> None: + assert _pascal_case_from_treatment_name("") is None + + def test_whitespace_only_returns_none(self) -> None: + assert _pascal_case_from_treatment_name(" \t\n ") is None + + def test_none_returns_none(self) -> None: + assert _pascal_case_from_treatment_name(None) is None + + @pytest.mark.parametrize("non_string", [42, 3.14, ["foo"], {"k": "v"}, True]) + def test_non_string_inputs_return_none(self, non_string: Any) -> None: + assert _pascal_case_from_treatment_name(non_string) is None + + def test_all_non_alphanumeric_returns_none(self) -> None: + """A word that is entirely punctuation collapses to nothing; the + function returns ``None`` rather than emitting an illegal key.""" + assert _pascal_case_from_treatment_name("--- ___ +++") is None + + def test_digits_preserved(self) -> None: + """isalnum keeps digits; only the first char is upper-cased.""" + assert _pascal_case_from_treatment_name("dose 5mg") == "Dose5mg" + + def test_single_word(self) -> None: + assert _pascal_case_from_treatment_name("treatment") == "Treatment" + + +# --------------------------------------------------------------------------- +# _broadcast_treatments_onto_subjects +# --------------------------------------------------------------------------- + + +def _subject_row(doc_id: str, **extras: Any) -> dict[str, Any]: + """Minimal subject row mirroring ``_row_subject`` output shape.""" + return { + "subjectIdentifier": doc_id, + "subjectDocumentIdentifier": doc_id, + **extras, + } + + +def _treatment_row( + *, + subject: str | None, + name: Any, + ontology: str | None = None, + string_value: Any = None, + numeric_value: Any = None, +) -> dict[str, Any]: + """Treatment row mirroring ``_row_treatment`` output shape.""" + return { + "subjectDocumentIdentifier": subject, + "treatmentName": name, + "treatmentOntology": ontology, + "numericValue": numeric_value, + "stringValue": string_value, + } + + +class TestBroadcastTreatmentsOntoSubjects: + def test_no_treatments_returns_originals_unchanged(self) -> None: + """Empty treatments list → columns identity preserved, rows + carry no extra keys.""" + subjects = [_subject_row("S1")] + columns = list(SUBJECT_COLUMNS) + rows, cols = _broadcast_treatments_onto_subjects(subjects, columns, []) + assert cols == columns + # Same shape on each row (no dynamic keys added). + assert all(set(r.keys()) == set(subjects[0].keys()) for r in rows) + + def test_subjects_with_no_matching_treatments_get_null_cells(self) -> None: + """Subject A has a treatment; subject B has none. Columns are + extended once (both subjects see the keys), B's cells are + ``None`` (not broadcast from A).""" + subjects = [_subject_row("S_A"), _subject_row("S_B")] + treatments = [ + _treatment_row( + subject="S_A", + name="Optogenetic Tetanus", + ontology="EMPTY:tetanus", + string_value="left CA1", + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # Columns extended by 2 (Name + Ontology). + assert len(cols) == len(SUBJECT_COLUMNS) + 2 + keys = {c["key"] for c in cols} + assert "OptogeneticTetanusName" in keys + assert "OptogeneticTetanusOntology" in keys + + row_a = next(r for r in rows if r["subjectDocumentIdentifier"] == "S_A") + row_b = next(r for r in rows if r["subjectDocumentIdentifier"] == "S_B") + # A populated. + assert row_a["OptogeneticTetanusName"] == "left CA1" + assert row_a["OptogeneticTetanusOntology"] == "EMPTY:tetanus" + # B explicitly None — broadcast does NOT spread A's values to B. + assert row_b["OptogeneticTetanusName"] is None + assert row_b["OptogeneticTetanusOntology"] is None + + def test_one_subject_one_treatment_columns_extended_by_two(self) -> None: + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", + name="Optogenetic Tetanus", + ontology="EMPTY:1", + string_value="left", + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert len(cols) == len(SUBJECT_COLUMNS) + 2 + # Labels track the original treatmentName, not the PascalCase key. + labels_by_key = {c["key"]: c["label"] for c in cols} + assert labels_by_key["OptogeneticTetanusName"] \ + == "Optogenetic Tetanus Name" + assert labels_by_key["OptogeneticTetanusOntology"] \ + == "Optogenetic Tetanus Ontology" + assert rows[0]["OptogeneticTetanusName"] == "left" + + def test_two_treatments_same_kind_collect_into_array(self) -> None: + """Two treatments of the SAME kind on the same subject → cell + becomes an array (csvJoinFormatter renders ``"a, b"``).""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Dose", ontology="CHEBI:1", + string_value="aspirin", + ), + _treatment_row( + subject="S1", name="Dose", ontology="CHEBI:2", + string_value="ibuprofen", + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # Columns extended only by 2 (same kind collapses to one pair). + assert len(cols) == len(SUBJECT_COLUMNS) + 2 + assert rows[0]["DoseName"] == ["aspirin", "ibuprofen"] + assert rows[0]["DoseOntology"] == ["CHEBI:1", "CHEBI:2"] + + def test_two_treatments_different_kinds_both_populated(self) -> None: + """Two treatments of DIFFERENT kinds on the same subject → + columns extended by 4 (two pairs); both populated.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Optogenetic Tetanus", + ontology="EMPTY:t", string_value="left", + ), + _treatment_row( + subject="S1", name="Dose", + ontology="CHEBI:1", string_value="aspirin", + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert len(cols) == len(SUBJECT_COLUMNS) + 4 + assert rows[0]["OptogeneticTetanusName"] == "left" + assert rows[0]["OptogeneticTetanusOntology"] == "EMPTY:t" + assert rows[0]["DoseName"] == "aspirin" + assert rows[0]["DoseOntology"] == "CHEBI:1" + + def test_treatment_with_no_subject_skipped_silently(self) -> None: + """``subjectDocumentIdentifier=None`` → drop the row, columns + unchanged.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject=None, name="Optogenetic Tetanus", + ontology="EMPTY:t", string_value="left", + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # No columns added — that treatment row never broadcast. + assert cols == list(SUBJECT_COLUMNS) + assert "OptogeneticTetanusName" not in rows[0] + + def test_treatment_with_empty_string_subject_skipped(self) -> None: + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="", name="Optogenetic Tetanus", + ontology="EMPTY:t", string_value="left", + ), + ] + _, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert cols == list(SUBJECT_COLUMNS) + + def test_treatment_with_non_string_name_skipped_silently(self) -> None: + """A non-string treatmentName collapses to ``None`` in the + pascal helper → caller drops the row, columns unchanged.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name=42, + ontology="EMPTY:t", string_value="left", + ), + _treatment_row( + subject="S1", name=None, + ontology="EMPTY:t", string_value="left", + ), + _treatment_row( + subject="S1", name=["list"], + ontology="EMPTY:t", string_value="left", + ), + ] + _, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # None of those treatment rows yielded a valid prefix. + assert cols == list(SUBJECT_COLUMNS) + + def test_value_priority_string_over_numeric(self) -> None: + """When stringValue is a non-empty string AND numericValue is + also set, stringValue wins (the cloud-app's priority chain).""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Dose", + string_value="aspirin", numeric_value=5.0, + ), + ] + rows, _ = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert rows[0]["DoseName"] == "aspirin" + + def test_value_falls_through_to_numeric_when_string_empty(self) -> None: + """Empty stringValue ('') is falsy → numericValue picked up.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Dose", + string_value="", numeric_value=5.0, + ), + ] + rows, _ = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert rows[0]["DoseName"] == 5.0 + + def test_value_falls_through_to_numeric_list(self) -> None: + """When stringValue + scalar numericValue both fall through, a + non-empty numericValue LIST (treatment_drug timing pair) wins.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Stim", + string_value=None, numeric_value=[-21600.0, 0.0], + ), + ] + rows, _ = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert rows[0]["StimName"] == [-21600.0, 0.0] + + def test_value_empty_numeric_list_skipped(self) -> None: + """Empty numericValue list → cell stays empty (no value to broadcast).""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Stim", + ontology="EMPTY:1", + string_value=None, numeric_value=[], + ), + ] + rows, cols = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # Ontology DID broadcast — the row still discovers the column pair. + assert any(c["key"] == "StimName" for c in cols) + # But the Name cell is None because value chain produced nothing. + assert rows[0]["StimName"] is None + assert rows[0]["StimOntology"] == "EMPTY:1" + + def test_input_rows_not_mutated(self) -> None: + """The helper is pure — caller's subject_rows + columns are + not modified in place.""" + original_subject = _subject_row("S1") + subjects = [original_subject] + original_columns = list(SUBJECT_COLUMNS) + columns = list(SUBJECT_COLUMNS) + treatments = [ + _treatment_row( + subject="S1", name="Dose", + ontology="EMPTY:1", string_value="aspirin", + ), + ] + _broadcast_treatments_onto_subjects(subjects, columns, treatments) + assert original_subject == {"subjectIdentifier": "S1", + "subjectDocumentIdentifier": "S1"} + assert columns == original_columns + + def test_non_string_subject_row_skipped(self) -> None: + """If a subject row's subjectDocumentIdentifier is None (bad + data) the dynamic cells default to None on that row.""" + subjects = [_subject_row("S1"), + {"subjectIdentifier": "S2", + "subjectDocumentIdentifier": None}] + treatments = [ + _treatment_row( + subject="S1", name="Dose", + ontology="EMPTY:1", string_value="aspirin", + ), + ] + rows, _ = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + # S1 row populated. + assert rows[0]["DoseName"] == "aspirin" + # S2 row (no valid subject id) gets None for the dynamic cell. + assert rows[1]["DoseName"] is None + + def test_ontology_not_string_skipped(self) -> None: + """treatmentOntology must be a non-empty string to broadcast. + Non-strings / empty strings → ontology cell stays None.""" + subjects = [_subject_row("S1")] + treatments = [ + _treatment_row( + subject="S1", name="Dose", + ontology="", # empty string — skipped + string_value="aspirin", + ), + ] + rows, _ = _broadcast_treatments_onto_subjects( + subjects, list(SUBJECT_COLUMNS), treatments, + ) + assert rows[0]["DoseName"] == "aspirin" + assert rows[0]["DoseOntology"] is None + + +# --------------------------------------------------------------------------- +# F-1b end-to-end via _project_for_class — exercises the subject branch +# wiring (treatment_drug + treatment_transfer merge from enriched dict) +# that the broadcast helper depends on. Caught the 2026-05-19 follow-up +# bug where the subject branch only read ``enriched["treatment"]`` and +# missed Bhar's subclass-only datasets. +# --------------------------------------------------------------------------- + + +class TestProjectSubjectMergesTreatmentSubclasses: + """Integration tests for _project_for_class("subject", ...) that + exercise the subclass-merge path. These would have caught the + bug where Bhar (0 literal treatment + 24466 treatment_drug + + 1675 treatment_transfer) saw no broadcast columns post-F-1b + because the subject branch only consulted enriched["treatment"].""" + + def _make_subject_doc(self, subject_id: str) -> dict[str, Any]: + """Minimal subject doc whose `_row_subject` emits a non-empty + `subjectDocumentIdentifier == subject_id`. Both `data.base.id` + and `ndiId` map to the join key per `_ndi_id`.""" + return { + "ndiId": subject_id, + "data": { + "base": {"id": subject_id, "name": f"subject-{subject_id}"}, + "document_class": {"class_name": "subject"}, + "subject": { + "local_identifier": f"local-{subject_id}", + "description": "test subject", + }, + }, + } + + def _make_treatment_drug_doc( + self, subject_id: str, mixture_name: str, + ) -> dict[str, Any]: + """A treatment_drug doc shape matching F-1e's projection input. + `_row_treatment` reads `data.treatment_drug.mixture_table` for + the name (CSV header + data rows; column 2 carries the name) + and `data.depends_on[subject_id]` for the join key.""" + return { + "data": { + "document_class": {"class_name": "treatment_drug"}, + "depends_on": [{"name": "subject_id", "value": subject_id}], + "treatment_drug": { + # CSV: "ontologyName,name\nNCBITaxon:000,{name}\n". + "mixture_table": ( + f"ontologyName,name\nNCBITaxon:000,{mixture_name}\n" + ), + "administration_onset_time": -12600.0, + "administration_offset_time": 0.0, + }, + }, + } + + def _make_treatment_transfer_doc( + self, recipient_id: str, entity: str, + ) -> dict[str, Any]: + """A treatment_transfer doc shape — uses recipient_id (NOT + subject_id) per F-1e.""" + return { + "data": { + "document_class": {"class_name": "treatment_transfer"}, + "depends_on": [ + {"name": "recipient_id", "value": recipient_id}, + ], + "treatment_transfer": { + "entity_name": entity, + "timestamp": -3600.0, + }, + }, + } + + def test_subject_branch_merges_treatment_drug_into_broadcast( + self, + ) -> None: + """When enriched["treatment_drug"] has rows but enriched["treatment"] + is empty, the F-1b subject branch must still broadcast the drug + treatments onto the subjects. This is the Bhar parity scenario.""" + from backend.services.summary_table_service import _project_for_class + + subject_docs = [self._make_subject_doc("S1")] + enriched = { + "openminds_subject": [], + "treatment": [], + "treatment_drug": [ + self._make_treatment_drug_doc("S1", "Eschericia coli OP50"), + ], + "treatment_transfer": [], + } + columns, rows = _project_for_class("subject", subject_docs, enriched) + col_keys = [c["key"] for c in columns] + # Dynamic broadcast columns are present: + assert any( + k.startswith("EschericiaColiOP50") and k.endswith("Name") + for k in col_keys + ), f"missing EschericiaColiOP50* in {col_keys}" + # Row 0 carries the broadcast cell — value isn't None. + broadcast_name_key = next( + k for k in col_keys + if k.startswith("EschericiaColiOP50") and k.endswith("Name") + ) + assert rows[0][broadcast_name_key] is not None + + def test_subject_branch_merges_treatment_transfer_into_broadcast( + self, + ) -> None: + """treatment_transfer uses recipient_id as the depends_on key but + _row_treatment maps it to subjectDocumentIdentifier just like + treatment_drug — so the broadcast should pick it up too.""" + from backend.services.summary_table_service import _project_for_class + + subject_docs = [self._make_subject_doc("S1")] + enriched = { + "openminds_subject": [], + "treatment": [], + "treatment_drug": [], + "treatment_transfer": [ + self._make_treatment_transfer_doc("S1", "Bacteria"), + ], + } + columns, _ = _project_for_class("subject", subject_docs, enriched) + col_keys = [c["key"] for c in columns] + assert any( + k.startswith("Bacteria") and k.endswith("Name") + for k in col_keys + ), f"missing Bacteria* in {col_keys}" + + def test_subject_branch_all_three_classes_merge(self) -> None: + """All three treatment-class enrichments populated → all three + kinds appear as broadcast columns.""" + from backend.services.summary_table_service import _project_for_class + + # We can't easily construct a "literal treatment" doc with the + # same projector shape (it goes through the legacy path), so + # cover the two subclasses for the merge assertion. The merge + # logic concatenates all three lists into _row_treatment input. + subject_docs = [self._make_subject_doc("S1")] + enriched = { + "openminds_subject": [], + "treatment": [], + "treatment_drug": [ + self._make_treatment_drug_doc("S1", "Drug A"), + ], + "treatment_transfer": [ + self._make_treatment_transfer_doc("S1", "Transfer B"), + ], + } + columns, _ = _project_for_class("subject", subject_docs, enriched) + col_keys = [c["key"] for c in columns] + assert any(k.startswith("DrugA") for k in col_keys) + assert any(k.startswith("TransferB") for k in col_keys) + + def test_subject_branch_empty_when_no_treatments(self) -> None: + """Backwards-compat: subject branch with no treatment enrichments + at all returns SUBJECT_COLUMNS verbatim (no dynamic broadcast).""" + from backend.services.summary_table_service import _project_for_class + + subject_docs = [self._make_subject_doc("S1")] + enriched = { + "openminds_subject": [], + "treatment": [], + "treatment_drug": [], + "treatment_transfer": [], + } + columns, rows = _project_for_class("subject", subject_docs, enriched) + # Column set equals SUBJECT_COLUMNS exactly. + assert [c["key"] for c in columns] == [c["key"] for c in SUBJECT_COLUMNS] + assert len(rows) == 1 diff --git a/backend/tests/unit/test_summary_table_class_alias.py b/backend/tests/unit/test_summary_table_class_alias.py new file mode 100644 index 0000000..00af989 --- /dev/null +++ b/backend/tests/unit/test_summary_table_class_alias.py @@ -0,0 +1,218 @@ +"""Regression coverage for the ``probe → element`` class-name alias added +2026-05-14 to fix the chat tool's ``query_documents(className="probe")`` +misshit on Dabrowska BNST (and every other dataset published under the +modern schema, where ``element`` is the canonical class name and +``probe`` returns 0 docs). + +Behavior under test: + +- :class:`SummaryTableService` accepts the user-friendly literal + ``"probe"`` even when the underlying dataset stores its probe-class + docs as ``"element"``. The cloud's ``isa probe`` query returns 0 IDs + for those datasets — but a second ``isa element`` query succeeds, and + the projection emits ``PROBE_COLUMNS`` rows (matching what the chat + tool's ``query_documents`` consumer expects). + +- The alias is logged so observability sees ``resolved_class=element`` + on a request for ``class_name=probe``. + +- ``epoch → element_epoch`` follows the same pattern. + +- When the literal class DOES return IDs (legacy datasets that emit + ``probe`` directly), the alias is never invoked — i.e. zero behavior + change on the happy path. +""" +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from backend.services.summary_table_service import ( + _CLASS_ALIASES, + SummaryTableService, +) + + +def _make_service(*, ndiquery_responses: dict[str, dict[str, Any]]) -> tuple[ + SummaryTableService, MagicMock, +]: + """Build a ``SummaryTableService`` whose cloud client returns the + canned ``ndiquery`` payload for each class name in + ``ndiquery_responses`` (keyed by ``param1``). Any class not in the + map returns an empty document list (matching the cloud's behavior + for missing classes). + """ + cloud = MagicMock() + + async def _ndiquery(*, searchstructure, scope, access_token, **kwargs): + class_name = searchstructure[0]["param1"] + return ndiquery_responses.get( + class_name, + {"documents": [], "totalItems": 0, "page": 1, "pageSize": 1000}, + ) + + cloud.ndiquery = _ndiquery + cloud.bulk_fetch = AsyncMock(return_value=[]) + + svc = SummaryTableService(cloud=cloud, cache=None) + return svc, cloud + + +@pytest.mark.asyncio +async def test_probe_alias_falls_back_to_element_when_probe_returns_zero(): + """Dabrowska-shape: 0 probe docs, 2 element docs. ``isa probe`` returns + empty so the service must retry ``isa element`` and surface those. + """ + element_doc = { + "id": "el1", + "ndiId": "ndi-el1", + "data": { + "base": {"id": "ndi-el1", "name": "patch-Vm-01"}, + "element": {"name": "patch-Vm-01", "type": "patch-Vm"}, + }, + } + element_doc_2 = { + "id": "el2", + "ndiId": "ndi-el2", + "data": { + "base": {"id": "ndi-el2", "name": "stim-01"}, + "element": {"name": "stim-01", "type": "stimulator"}, + }, + } + + svc, cloud = _make_service( + ndiquery_responses={ + # `isa probe` returns nothing — modern dataset. + "probe": {"documents": [], "totalItems": 0, "page": 1, "pageSize": 1000}, + # `isa element` returns 2 docs — the alias hit. + "element": { + "documents": [{"id": "el1"}, {"id": "el2"}], + "totalItems": 2, + "page": 1, + "pageSize": 1000, + }, + }, + ) + cloud.bulk_fetch = AsyncMock(return_value=[element_doc, element_doc_2]) + + result = await svc.single_class("DS_DABROWSKA", "probe", session=None) + rows = result["rows"] + # Both element docs projected as probe rows under PROBE_COLUMNS. + assert len(rows) == 2, f"expected probe→element alias to return 2 rows, got {rows!r}" + # Probe-column shape: probeName + probeType present. + assert rows[0]["probeName"] in {"patch-Vm-01", "stim-01"} + assert rows[0]["probeType"] in {"patch-Vm", "stimulator"} + types = {r["probeType"] for r in rows} + assert types == {"patch-Vm", "stimulator"}, ( + f"probe alias must surface element.type values; got {types!r}" + ) + + +@pytest.mark.asyncio +async def test_probe_alias_not_invoked_when_probe_returns_docs(): + """Legacy datasets (Van Hooser): ``isa probe`` returns docs; the alias + must NOT fire, and the resolved class stays ``probe`` (logged for + observability). Behavior is byte-identical to the pre-alias build. + """ + probe_doc = { + "id": "p1", + "ndiId": "ndi-p1", + "data": { + "base": {"id": "ndi-p1", "name": "n-trode-01"}, + "probe": {"name": "n-trode-01", "type": "n-trode"}, + }, + } + + svc, cloud = _make_service( + ndiquery_responses={ + "probe": { + "documents": [{"id": "p1"}], + "totalItems": 1, + "page": 1, + "pageSize": 1000, + }, + # `element` would also return data but the alias path must + # not consult it. Assert by giving `element` a poison value + # — if the service queried it, the probeType field would + # show "POISON" instead of "n-trode". + "element": { + "documents": [{"id": "POISON"}], + "totalItems": 1, + "page": 1, + "pageSize": 1000, + }, + }, + ) + cloud.bulk_fetch = AsyncMock(return_value=[probe_doc]) + + result = await svc.single_class("DS_VANHOOSER", "probe", session=None) + rows = result["rows"] + assert len(rows) == 1 + assert rows[0]["probeType"] == "n-trode", ( + "alias must not fire when literal class returns docs" + ) + + +@pytest.mark.asyncio +async def test_epoch_alias_falls_back_to_element_epoch(): + """``isa epoch`` returns zero on modern datasets; ``isa element_epoch`` + is the canonical class name. Same alias pattern as probe→element. + """ + element_epoch_doc = { + "id": "ee1", + "ndiId": "ndi-ee1", + "data": { + "base": {"id": "ndi-ee1", "name": "epoch-1"}, + "element_epoch": { + "name": "epoch-1", + "t0_t1": [0.0, 100.0], + "epoch_clock": "dev_local_time", + }, + }, + } + + svc, cloud = _make_service( + ndiquery_responses={ + "epoch": {"documents": [], "totalItems": 0, "page": 1, "pageSize": 1000}, + "element_epoch": { + "documents": [{"id": "ee1"}], + "totalItems": 1, + "page": 1, + "pageSize": 1000, + }, + }, + ) + cloud.bulk_fetch = AsyncMock(return_value=[element_epoch_doc]) + + result = await svc.single_class("DS_MODERN", "epoch", session=None) + rows = result["rows"] + assert len(rows) == 1, "epoch→element_epoch alias must surface the row" + # EPOCH_COLUMNS shape: epochNumber + t0_t1 normalized. + assert rows[0]["epochNumber"] == "epoch-1" + + +def test_class_aliases_table_has_expected_entries(): + """Snapshot of the alias map — additions are intentional, removals + require updating this test + the chat tool's system prompt. + + 2026-05-19 — F-1d extension. Added legacy-class chains for + `epoch` and `element_epoch` so Phase-A ingested datasets + (Francesconi: `epochfiles_ingested` + `daqreader_mfdaq_epochdata_ingested`) + surface via the same /tables/element_epoch route the workspace + Sessions cascade uses. + """ + assert _CLASS_ALIASES == { + "probe": ["element"], + "epoch": [ + "element_epoch", + "epochfiles_ingested", + "daqreader_mfdaq_epochdata_ingested", + ], + "element_epoch": [ + "epochfiles_ingested", + "daqreader_mfdaq_epochdata_ingested", + ], + "stimulus": ["stimulus_presentation"], + } diff --git a/backend/tests/unit/test_summary_table_pagination.py b/backend/tests/unit/test_summary_table_pagination.py new file mode 100644 index 0000000..b7ba3e5 --- /dev/null +++ b/backend/tests/unit/test_summary_table_pagination.py @@ -0,0 +1,236 @@ +"""Server-side pagination on /tables/{class} (Stream 5.8, 2026-05-16). + +Locks the new pagination contract: + +* When neither ``page`` nor ``page_size`` is supplied, the response keeps + the legacy unpaged envelope (BC for the Document Explorer + cron). +* When either is supplied, the response gains ``{page, pageSize, totalRows, + hasMore}`` and ``rows`` is sliced server-side. +* Pagination happens AFTER the cache layer so the cache stays keyed by + ``(dataset_id, class_name, user_scope)`` only — every page reads from + the same cached full envelope. + +The unit test exercises the pure ``_paginate`` helper plus the +``single_class`` flow with a stubbed cloud client so the cache + slice +chain is end-to-end testable without a live Railway env. +""" +from __future__ import annotations + +from typing import Any + +import pytest + +from backend.services.summary_table_service import ( + SummaryTableService, + _paginate, +) + +# --------------------------------------------------------------------------- +# Pure helper: _paginate +# --------------------------------------------------------------------------- + +def _envelope(n: int) -> dict[str, Any]: + """Build a synthetic full-table envelope with ``n`` rows.""" + return { + "columns": [{"key": "x", "label": "X"}], + "rows": [{"x": i} for i in range(n)], + "distinct_summary": {"x": {"distinct_count": n, "top_values": []}}, + } + + +class TestPaginateHelper: + def test_first_page_with_more_to_come(self) -> None: + out = _paginate(_envelope(500), page=1, page_size=200) + assert out["page"] == 1 + assert out["pageSize"] == 200 + assert out["totalRows"] == 500 + assert out["hasMore"] is True + assert len(out["rows"]) == 200 + # First row should be index 0; last index 199. + assert out["rows"][0]["x"] == 0 + assert out["rows"][-1]["x"] == 199 + + def test_middle_page(self) -> None: + out = _paginate(_envelope(500), page=2, page_size=200) + assert out["page"] == 2 + assert out["totalRows"] == 500 + assert out["hasMore"] is True + assert len(out["rows"]) == 200 + assert out["rows"][0]["x"] == 200 + assert out["rows"][-1]["x"] == 399 + + def test_last_page_partial(self) -> None: + out = _paginate(_envelope(500), page=3, page_size=200) + assert out["page"] == 3 + assert out["totalRows"] == 500 + # 500 rows / 200 page_size = page 3 has rows 400-499 (100 rows). + assert out["hasMore"] is False + assert len(out["rows"]) == 100 + assert out["rows"][0]["x"] == 400 + assert out["rows"][-1]["x"] == 499 + + def test_page_beyond_total_yields_empty_rows(self) -> None: + out = _paginate(_envelope(50), page=2, page_size=200) + # Page 2 of a 50-row table is past the end. Don't error; return + # an empty rows array so callers can still inspect totalRows + + # hasMore. + assert out["rows"] == [] + assert out["totalRows"] == 50 + assert out["hasMore"] is False + + def test_carries_distinct_summary_verbatim(self) -> None: + full = _envelope(500) + out = _paginate(full, page=1, page_size=10) + # distinct_summary is full-table — should be unchanged regardless + # of how the rows are sliced. + assert out["distinct_summary"] == full["distinct_summary"] + + def test_carries_columns_verbatim(self) -> None: + full = _envelope(50) + out = _paginate(full, page=1, page_size=200) + assert out["columns"] == full["columns"] + + def test_empty_full_table(self) -> None: + out = _paginate({"columns": [], "rows": []}, page=1, page_size=200) + assert out["rows"] == [] + assert out["totalRows"] == 0 + assert out["hasMore"] is False + + +# --------------------------------------------------------------------------- +# single_class flow — verify BC unpaged path + paged path +# --------------------------------------------------------------------------- + +class _FakeCache: + """In-memory cache that mimics RedisTableCache's get_or_compute API.""" + + def __init__(self) -> None: + self._store: dict[str, dict[str, Any]] = {} + self.compute_count = 0 + + async def get_or_compute(self, key: str, compute: Any) -> dict[str, Any]: + if key in self._store: + return self._store[key] + self.compute_count += 1 + value = await compute() + self._store[key] = value + return value + + +class _StubService(SummaryTableService): + """SummaryTableService that bypasses the real cloud + projection so + pagination unit tests don't need fixture docs. ``_build_single_class`` + is mocked to return a fixed envelope. + """ + + def __init__(self, full_envelope: dict[str, Any]) -> None: + self._envelope = full_envelope + self.cache = _FakeCache() + self.cloud = None # type: ignore[assignment] + + async def _build_single_class( # type: ignore[override] + self, + dataset_id: str, # noqa: ARG002 — args required by parent signature; test ignores them + class_name: str, # noqa: ARG002 + *, + access_token: str | None, # noqa: ARG002 + ) -> dict[str, Any]: + return self._envelope + + +@pytest.mark.asyncio +async def test_single_class_unpaged_returns_full_envelope() -> None: + """When page/page_size are both None the response keeps the legacy shape.""" + full = _envelope(300) + svc = _StubService(full) + + result = await svc.single_class("DS1", "subject", session=None) + + assert "page" not in result + assert "pageSize" not in result + assert "totalRows" not in result + assert "hasMore" not in result + assert len(result["rows"]) == 300 + assert result["columns"] == full["columns"] + + +@pytest.mark.asyncio +async def test_single_class_paged_slices_server_side() -> None: + """Passing page+page_size returns the paged envelope.""" + full = _envelope(750) + svc = _StubService(full) + + page1 = await svc.single_class( + "DS1", "subject", session=None, page=1, page_size=200, + ) + assert page1["page"] == 1 + assert page1["pageSize"] == 200 + assert page1["totalRows"] == 750 + assert page1["hasMore"] is True + assert len(page1["rows"]) == 200 + + page2 = await svc.single_class( + "DS1", "subject", session=None, page=2, page_size=200, + ) + assert page2["page"] == 2 + assert page2["rows"][0]["x"] == 200 + + +@pytest.mark.asyncio +async def test_pagination_shares_one_cached_full_envelope() -> None: + """The cache is keyed by (dataset, class) — not by page. Asking for + pages 1, 2, 3 should compute the full envelope ONCE; subsequent pages + hit the cache and slice in-memory. + + This is THE egress-saving invariant: the cloud-fetch + projection work + happens once per dataset/class regardless of how many pages a viewer + requests.""" + full = _envelope(1000) + svc = _StubService(full) + + # First request — populates cache. + await svc.single_class( + "DS1", "subject", session=None, page=1, page_size=200, + ) + # Three more requests at different pages should all hit cache. + await svc.single_class( + "DS1", "subject", session=None, page=2, page_size=200, + ) + await svc.single_class( + "DS1", "subject", session=None, page=3, page_size=200, + ) + # An unpaged request from a different consumer also hits the same cache. + await svc.single_class("DS1", "subject", session=None) + + cache = svc.cache + assert isinstance(cache, _FakeCache) + assert cache.compute_count == 1 + + +@pytest.mark.asyncio +async def test_single_class_only_page_defaults_page_size() -> None: + """If only ``page`` is supplied, page_size defaults to 200.""" + full = _envelope(500) + svc = _StubService(full) + + result = await svc.single_class( + "DS1", "subject", session=None, page=1, + ) + assert result["page"] == 1 + assert result["pageSize"] == 200 + assert len(result["rows"]) == 200 + + +@pytest.mark.asyncio +async def test_single_class_only_page_size_defaults_page() -> None: + """If only ``page_size`` is supplied, page defaults to 1.""" + full = _envelope(500) + svc = _StubService(full) + + result = await svc.single_class( + "DS1", "subject", session=None, page_size=100, + ) + assert result["page"] == 1 + assert result["pageSize"] == 100 + assert len(result["rows"]) == 100 + assert result["rows"][0]["x"] == 0 diff --git a/backend/tests/unit/test_summary_table_projection.py b/backend/tests/unit/test_summary_table_projection.py index 23c385e..9d54593 100644 --- a/backend/tests/unit/test_summary_table_projection.py +++ b/backend/tests/unit/test_summary_table_projection.py @@ -13,15 +13,19 @@ import pytest from backend.services.summary_table_service import ( + DISTINCT_SUMMARY_MAX_ROWS, + DISTINCT_SUMMARY_TOP_K, SUBJECT_COLUMNS, _attach_openminds_enrichment, _background_strain_from_strain, + _build_distinct_summary, _clean, _clock_indices, _depends_on_value_by_name, _depends_on_values, _element_subject_ndi, _epoch_element_ndi, + _hashable, _index_by_ndi_id, _ndi_id, _normalize_t0_t1, @@ -621,6 +625,217 @@ def test_basic(self) -> None: assert row["stringValue"] is None # "" normalized to None assert row["subjectDocumentIdentifier"] == "SUBJ_X" + def test_treatment_drug_subclass_projection(self) -> None: + """F-1e (2026-05-19) — treatment_drug docs project under + TREATMENT_COLUMNS via the auto-detect branch in _row_treatment. + Bhar treatment_drug shape: mixture_table CSV holds the name, + administration_onset_time / administration_offset_time hold + timing (with both numeric AND HH:MM:SS string forms). + """ + doc = {"data": { + "depends_on": [{"name": "subject_id", "value": "SUBJ_BHAR"}], + "treatment_drug": { + "location_ontologyName": "MICRO:0000480", + "location_name": "agar plate medium", + "mixture_table": "ontologyName,name\nNCBITaxon:637912,Eschericia coli OP50\n", + "administration_onset_time": -21600, # numeric seconds + "administration_offset_time": 0, + "administration_duration": 6, + }, + }} + row = _row_treatment(doc) + assert row["treatmentName"] == "Eschericia coli OP50" + assert row["treatmentOntology"] == "MICRO:0000480" + assert row["numericValue"] == [-21600.0, 0.0] + assert row["stringValue"] == "agar plate medium" + assert row["subjectDocumentIdentifier"] == "SUBJ_BHAR" + + def test_treatment_drug_hhmmss_timing_parses(self) -> None: + """administration_onset_time / offset_time can also be + HH:MM:SS strings (Bhar emits the negative form for + pre-experiment treatments: "-06:00:00" = 6h before). + """ + doc = {"data": { + "depends_on": [{"name": "subject_id", "value": "SUBJ_Y"}], + "treatment_drug": { + "mixture_table": "ontologyName,name\nFOO:1,Some Drug\n", + "administration_onset_time": "-06:00:00", + "administration_offset_time": "00:00:00", + }, + }} + row = _row_treatment(doc) + assert row["numericValue"] == [-21600.0, 0.0] + assert row["treatmentName"] == "Some Drug" + + def test_treatment_drug_mixture_table_multi_row_joins_with_plus(self) -> None: + """Multi-row mixtures join names with ' + ' so a 2-drug cocktail + renders 'NameA + NameB' on the Gantt label.""" + doc = {"data": { + "treatment_drug": { + "mixture_table": ( + "ontologyName,name\n" + "FOO:1,DrugA\n" + "FOO:2,DrugB\n" + ), + }, + }} + row = _row_treatment(doc) + assert row["treatmentName"] == "DrugA + DrugB" + + def test_treatment_transfer_subclass_projection(self) -> None: + """F-1e (2026-05-19) — treatment_transfer docs project under + TREATMENT_COLUMNS. Different sub-block, different depends_on + edge name (recipient_id), single-point timing. + """ + doc = {"data": { + "depends_on": [ + {"name": "recipient_id", "value": "SUBJ_RECIP"}, + {"name": "donor_id", "value": "SUBJ_DONOR"}, + ], + "treatment_transfer": { + "timestamp": -72000, + "clocktype": "dev_local_time", + "entity_name": "agar plate medium", + "entity_ontologyNode": "MICRO:0000480", + "method_name": "C. elegans transfer method: titanium pick", + "method_ontologyNode": "EMPTY:0000256", + }, + }} + row = _row_treatment(doc) + assert row["treatmentName"] == "agar plate medium" + assert row["treatmentOntology"] == "MICRO:0000480" + assert row["numericValue"] == [-72000] # length-1 → single-point Gantt tick + assert row["stringValue"] == "C. elegans transfer method: titanium pick" + assert row["subjectDocumentIdentifier"] == "SUBJ_RECIP" + + def test_treatment_drug_missing_timing_returns_null(self) -> None: + """If onset/offset are missing/empty, numericValue is None + (downstream falls back to ordinal Gantt timing).""" + doc = {"data": { + "treatment_drug": { + "mixture_table": "ontologyName,name\nFOO:1,X\n", + # no administration_*_time + }, + }} + row = _row_treatment(doc) + assert row["numericValue"] is None + assert row["treatmentName"] == "X" + + def test_dispatch_routes_treatment_subclasses(self) -> None: + """The /tables/{class} dispatcher routes treatment_drug + treatment_transfer + to TREATMENT_COLUMNS, same as `treatment`.""" + from backend.services.summary_table_service import ( + TREATMENT_COLUMNS, + _project_for_class, + ) + for cls in ("treatment", "treatment_drug", "treatment_transfer"): + doc = {"data": { + cls: ({"name": "x"} if cls == "treatment" else + {"mixture_table": "h,n\nO:1,X\n"} if cls == "treatment_drug" else + {"entity_name": "X", "timestamp": 0}), + }} + columns, rows = _project_for_class(cls, [doc], {}) + assert columns == TREATMENT_COLUMNS + assert len(rows) == 1 + + +class TestStimulusRow: + """F-1 (2026-05-19) — stimulus_presentation row projection.""" + + def test_basic_presentation_extracts_count_and_timing(self) -> None: + from backend.services.summary_table_service import _row_stimulus + doc = { + "data": { + "base": {"id": "STIM_DOC_42"}, + "depends_on": [ + {"name": "element_id", "value": "EL_STIM_7"}, + ], + "stimulus_presentation": { + "name": "Visual Grating", + "presentations": [ + {"time_started": 1.5, "time_stopped": 2.5}, + {"time_started": 11.5, "time_stopped": 12.5}, + {"time_started": 21.5, "time_stopped": 22.5}, + ], + }, + }, + } + row = _row_stimulus(doc) + assert row["stimulusName"] == "Visual Grating" + assert row["elementDocumentIdentifier"] == "EL_STIM_7" + assert row["presentationCount"] == 3 + assert row["firstPresentationTime"] == 1.5 + assert row["lastPresentationTime"] == 21.5 + + def test_empty_presentations_is_zero_count(self) -> None: + from backend.services.summary_table_service import _row_stimulus + doc = { + "data": { + "stimulus_presentation": {"name": "Empty", "presentations": []}, + }, + } + row = _row_stimulus(doc) + assert row["presentationCount"] == 0 + assert row["firstPresentationTime"] is None + assert row["lastPresentationTime"] is None + assert row["stimulusName"] == "Empty" + + def test_missing_presentations_field_handled(self) -> None: + from backend.services.summary_table_service import _row_stimulus + doc = {"data": {"stimulus_presentation": {"name": "Sparse"}}} + row = _row_stimulus(doc) + assert row["presentationCount"] == 0 + assert row["firstPresentationTime"] is None + assert row["lastPresentationTime"] is None + + def test_falls_back_to_stimulus_element_id_depends_on(self) -> None: + from backend.services.summary_table_service import _row_stimulus + doc = { + "data": { + "depends_on": [ + {"name": "stimulus_element_id", "value": "EL_OLD_5"}, + ], + "stimulus_presentation": {"name": "Old shape", "presentations": []}, + }, + } + row = _row_stimulus(doc) + assert row["elementDocumentIdentifier"] == "EL_OLD_5" + + def test_no_name_falls_back_to_base_name(self) -> None: + from backend.services.summary_table_service import _row_stimulus + doc = { + "data": { + "base": {"id": "STIM_X", "name": "from-base"}, + "stimulus_presentation": {"presentations": []}, + }, + } + row = _row_stimulus(doc) + assert row["stimulusName"] == "from-base" + + def test_class_alias_short_form(self) -> None: + """Calling /tables/stimulus should resolve to stimulus_presentation.""" + from backend.services.summary_table_service import _CLASS_ALIASES + assert "stimulus" in _CLASS_ALIASES + assert "stimulus_presentation" in _CLASS_ALIASES["stimulus"] + + def test_project_for_class_dispatches_to_stimulus(self) -> None: + """The dispatcher routes both `stimulus` and `stimulus_presentation`.""" + from backend.services.summary_table_service import ( + STIMULUS_COLUMNS, + _project_for_class, + ) + doc = { + "data": { + "base": {"id": "X"}, + "stimulus_presentation": {"name": "X", "presentations": []}, + }, + } + for cls in ("stimulus", "stimulus_presentation"): + columns, rows = _project_for_class(cls, [doc], {}) + assert columns == STIMULUS_COLUMNS + assert len(rows) == 1 + assert rows[0]["stimulusName"] == "X" + # --------------------------------------------------------------------------- # Top-level dispatcher @@ -728,3 +943,164 @@ def test_falls_back_to_age_category(self) -> None: ]) def test_clean_table(v: object, expected: object) -> None: assert _clean(v) == expected + + +# --------------------------------------------------------------------------- +# distinct_summary — per-column cardinality + top values +# --------------------------------------------------------------------------- + +class TestHashable: + """`_hashable` makes any cell value usable as a dict key for counting.""" + + def test_scalars_pass_through(self) -> None: + assert _hashable("x") == "x" + assert _hashable(1) == 1 + assert _hashable(1.5) == 1.5 + assert _hashable(True) is True + assert _hashable(None) is None + + def test_dict_stringifies_deterministically(self) -> None: + # Same dict, different insertion order → same key. + a = _hashable({"devTime": 0, "globalTime": None}) + b = _hashable({"globalTime": None, "devTime": 0}) + assert a == b + assert isinstance(a, str) + + def test_list_is_hashable(self) -> None: + # Lists aren't hashable in Python; we stringify. + h = _hashable([1, 2, 3]) + assert isinstance(h, str) + + +class TestBuildDistinctSummary: + """Per-column cardinality + top-K values across ALL rows.""" + + def test_dabrowska_optogenetic_treatment_collapse(self) -> None: + """Smoke-tested 2026-05-13 case that motivated this feature. + + Dabrowska BNST has 49 treatment rows all named + "Optogenetic Tetanus Stimulation Target Location". The LLM + assumed only optogenetic treatments existed because all rows + looked identical; distinct_summary surfaces the collapse so + the model knows to pivot to ontologyTableRow for Saline/CNO. + """ + columns = [ + {"key": "treatmentName", "label": "Treatment"}, + {"key": "treatmentOntology", "label": "Treatment Ontology"}, + ] + rows = [ + {"treatmentName": "Optogenetic Tetanus Stimulation Target Location", + "treatmentOntology": "UBERON:0001234"} + for _ in range(49) + ] + summary = _build_distinct_summary(columns, rows) + assert summary["treatmentName"]["distinct_count"] == 1 + assert summary["treatmentName"]["top_values"] == [ + {"value": "Optogenetic Tetanus Stimulation Target Location", + "count": 49}, + ] + assert summary["treatmentOntology"]["distinct_count"] == 1 + + def test_multi_value_top_k(self) -> None: + columns = [{"key": "speciesName", "label": "Species"}] + rows = ( + [{"speciesName": "Mus musculus"}] * 10 + + [{"speciesName": "Rattus norvegicus"}] * 5 + + [{"speciesName": "Macaca mulatta"}] * 3 + + [{"speciesName": "Drosophila"}] * 2 + + [{"speciesName": "Danio rerio"}] * 1 + + [{"speciesName": "Mustela putorius furo"}] * 1 + ) + summary = _build_distinct_summary(columns, rows) + assert summary["speciesName"]["distinct_count"] == 6 + # top-K cap at DISTINCT_SUMMARY_TOP_K (default 5). + top = summary["speciesName"]["top_values"] + assert len(top) == DISTINCT_SUMMARY_TOP_K + # Descending by count. + assert top[0] == {"value": "Mus musculus", "count": 10} + assert top[1] == {"value": "Rattus norvegicus", "count": 5} + assert top[2] == {"value": "Macaca mulatta", "count": 3} + + def test_none_cells_are_counted(self) -> None: + """Missing/None values are tallied so the LLM sees row-count gaps.""" + columns = [{"key": "ageAtRecording", "label": "Age"}] + rows = [ + {"ageAtRecording": "P30"}, + {"ageAtRecording": None}, + {"ageAtRecording": None}, + ] + summary = _build_distinct_summary(columns, rows) + assert summary["ageAtRecording"]["distinct_count"] == 2 + # None counts as a value so the LLM can see "2/3 had no age". + none_entry = next( + e for e in summary["ageAtRecording"]["top_values"] + if e["value"] is None + ) + assert none_entry["count"] == 2 + + def test_empty_rows_returns_empty_per_column_entries(self) -> None: + columns = [{"key": "a", "label": "A"}, {"key": "b", "label": "B"}] + summary = _build_distinct_summary(columns, []) + assert summary == { + "a": {"distinct_count": 0, "top_values": []}, + "b": {"distinct_count": 0, "top_values": []}, + } + + def test_skipped_when_too_many_rows(self) -> None: + columns = [{"key": "x", "label": "X"}] + rows = [{"x": i} for i in range(DISTINCT_SUMMARY_MAX_ROWS + 1)] + summary = _build_distinct_summary(columns, rows) + assert summary == {"_meta": "skipped due to large row count"} + + def test_handles_dict_cell_values(self) -> None: + """epochStart projects as {devTime, globalTime} — must still tally.""" + columns = [{"key": "epochStart", "label": "Start"}] + rows = [ + {"epochStart": {"devTime": 0, "globalTime": None}}, + {"epochStart": {"devTime": 0, "globalTime": None}}, + {"epochStart": {"devTime": 100, "globalTime": None}}, + ] + summary = _build_distinct_summary(columns, rows) + assert summary["epochStart"]["distinct_count"] == 2 + # Top value count==2. + assert summary["epochStart"]["top_values"][0]["count"] == 2 + + def test_ignores_columns_with_non_string_keys(self) -> None: + """Defensive: a malformed columns entry shouldn't crash the build.""" + columns = [ + {"key": "good", "label": "Good"}, + {"label": "no key here"}, # type: ignore[typeddict-item] + ] + rows = [{"good": "v"}] + summary = _build_distinct_summary(columns, rows) + assert "good" in summary + # The malformed column is silently skipped — no extra key surfaced. + assert len(summary) == 1 + + +class TestBuildSingleClassResponseShape: + """Smoke-level: _build_distinct_summary is folded into the response.""" + + def test_response_includes_distinct_summary_key(self) -> None: + from backend.services.summary_table_service import _project_for_class + + # Project a tiny subject set and verify _project_for_class + + # _build_distinct_summary compose into the expected response shape. + subject = { + "data": { + "base": {"id": "S1", "session_id": "sess"}, + "subject": {"local_identifier": "A1"}, + }, + } + columns, rows = _project_for_class( + "subject", [subject], + {"openminds_subject": [], "subject": [subject], "treatment": []}, + ) + summary = _build_distinct_summary(columns, rows) + # Every projected column appears in the summary. + for col in columns: + assert col["key"] in summary + entry = summary[col["key"]] + assert "distinct_count" in entry + assert "top_values" in entry + assert isinstance(entry["top_values"], list) diff --git a/backend/tests/unit/test_summary_table_subject_filter.py b/backend/tests/unit/test_summary_table_subject_filter.py new file mode 100644 index 0000000..3630cbc --- /dev/null +++ b/backend/tests/unit/test_summary_table_subject_filter.py @@ -0,0 +1,93 @@ +"""F-2 (2026-05-19) — ``?subject=`` filter post-cache, pre-paginate. + +Tests target the pure ``_filter_rows_by_subject`` helper that the +``/tables/{class}?subject=`` route uses to narrow workspace +SessionsBrowser cascade traffic. +""" +from __future__ import annotations + +from backend.services.summary_table_service import _filter_rows_by_subject + + +def _table(rows): + return { + "columns": [ + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + {"key": "epochNumber", "label": "Epoch"}, + ], + "rows": rows, + "distinct_summary": {"subjectDocumentIdentifier": 3}, + } + + +def test_filter_keeps_matching_subject_only(): + table = _table([ + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 1}, + {"subjectDocumentIdentifier": "subj_B", "epochNumber": 2}, + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 3}, + {"subjectDocumentIdentifier": "subj_C", "epochNumber": 4}, + ]) + out = _filter_rows_by_subject(table, "subj_A") + assert len(out["rows"]) == 2 + assert all(r["subjectDocumentIdentifier"] == "subj_A" for r in out["rows"]) + + +def test_filter_returns_empty_when_no_matches(): + table = _table([ + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 1}, + {"subjectDocumentIdentifier": "subj_B", "epochNumber": 2}, + ]) + out = _filter_rows_by_subject(table, "subj_NOT_PRESENT") + assert out["rows"] == [] + # Columns + distinct_summary preserved verbatim. + assert out["columns"] == table["columns"] + assert out["distinct_summary"] == table["distinct_summary"] + + +def test_filter_handles_empty_table(): + table = _table([]) + out = _filter_rows_by_subject(table, "anything") + assert out is table # passthrough when no rows + + +def test_filter_falls_back_to_subject_id_key(): + """Older ingest paths used ``subjectId`` instead of + ``subjectDocumentIdentifier``; the filter accepts either.""" + table = _table([ + {"subjectId": "subj_A", "epochNumber": 1}, + {"subjectId": "subj_B", "epochNumber": 2}, + ]) + out = _filter_rows_by_subject(table, "subj_A") + assert len(out["rows"]) == 1 + assert out["rows"][0]["subjectId"] == "subj_A" + + +def test_filter_ignores_non_string_subject_values(): + table = _table([ + {"subjectDocumentIdentifier": None, "epochNumber": 1}, + {"subjectDocumentIdentifier": 42, "epochNumber": 2}, + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 3}, + ]) + out = _filter_rows_by_subject(table, "subj_A") + assert len(out["rows"]) == 1 + + +def test_filter_does_not_mutate_input(): + rows = [ + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 1}, + {"subjectDocumentIdentifier": "subj_B", "epochNumber": 2}, + ] + table = _table(rows) + _filter_rows_by_subject(table, "subj_A") + # Original table's rows untouched — filter returns a new envelope. + assert len(table["rows"]) == 2 + + +def test_filter_returns_full_envelope_with_filtered_rows(): + table = _table([ + {"subjectDocumentIdentifier": "subj_A", "epochNumber": 1}, + {"subjectDocumentIdentifier": "subj_B", "epochNumber": 2}, + ]) + out = _filter_rows_by_subject(table, "subj_A") + # Top-level shape preserved. + assert set(out.keys()) >= {"columns", "rows", "distinct_summary"} diff --git a/backend/tests/unit/test_tabular_query_service.py b/backend/tests/unit/test_tabular_query_service.py new file mode 100644 index 0000000..5e94e24 --- /dev/null +++ b/backend/tests/unit/test_tabular_query_service.py @@ -0,0 +1,589 @@ +"""Unit tests for TabularQueryService. + +Tests focus on the aggregation math + edge cases. The SummaryTableService +dependency is stubbed — its own tests cover the ontologyTableRow +projection logic. +""" +from __future__ import annotations + +from typing import Any + +import pytest + +from backend.services.tabular_query_service import ( + MAX_GROUPS, + MAX_VALUES_PER_GROUP, + TabularQueryService, + _percentile, + _stride_sample, + _summary_stats, +) + +# --------------------------------------------------------------------------- +# Stat-helper unit tests — pure functions, no IO +# --------------------------------------------------------------------------- + + +class TestSummaryStats: + def test_basic_stats(self): + vals = [1.0, 2.0, 3.0, 4.0, 5.0] + s = _summary_stats(vals) + assert s["count"] == 5 + assert s["mean"] == 3.0 + assert s["median"] == 3.0 + assert s["min"] == 1.0 + assert s["max"] == 5.0 + assert abs(s["std"] - 1.5811) < 0.001 + assert s["q1"] == 2.0 + assert s["q3"] == 4.0 + + def test_single_value_zero_std(self): + s = _summary_stats([7.0]) + assert s["count"] == 1 + assert s["std"] == 0.0 + assert s["mean"] == 7.0 + + def test_two_values(self): + s = _summary_stats([10.0, 20.0]) + assert s["count"] == 2 + assert s["mean"] == 15.0 + assert s["median"] == 15.0 + + +class TestPercentile: + def test_quartiles(self): + assert _percentile([1, 2, 3, 4, 5], 25) == 2.0 + assert _percentile([1, 2, 3, 4, 5], 50) == 3.0 + assert _percentile([1, 2, 3, 4, 5], 75) == 4.0 + + def test_endpoints(self): + assert _percentile([1, 2, 3, 4, 5], 0) == 1.0 + assert _percentile([1, 2, 3, 4, 5], 100) == 5.0 + + def test_empty_returns_zero(self): + assert _percentile([], 50) == 0.0 + + def test_single_value(self): + assert _percentile([42.0], 50) == 42.0 + + +class TestStrideSample: + def test_under_cap_returns_all(self): + assert _stride_sample([1.0, 2.0, 3.0], cap=10) == [1.0, 2.0, 3.0] + + def test_over_cap_preserves_endpoints(self): + vals = [float(i) for i in range(100)] + out = _stride_sample(vals, cap=10) + assert len(out) == 10 + assert out[0] == 0.0 + assert out[-1] == 99.0 + + +# --------------------------------------------------------------------------- +# Service-level: stub SummaryTableService with the real ontology_tables +# response shape (one group per distinct variableNames schema, rows are +# dicts keyed by variableName). +# --------------------------------------------------------------------------- + + +def _make_ontology_response( + columns: list[dict[str, Any]], + rows: list[dict[str, Any]], + *, + doc_ids: list[str] | None = None, +) -> dict[str, Any]: + """Build a one-group ontology_tables response matching the real + shape returned by SummaryTableService.ontology_tables. + """ + return { + "groups": [ + { + "variableNames": [c["key"] for c in columns], + "names": [c.get("label", c["key"]) for c in columns], + "ontologyNodes": [c.get("ontologyTerm") for c in columns], + "table": {"columns": columns, "rows": rows}, + "docIds": doc_ids or [], + "rowCount": len(rows), + }, + ], + } + + +class _FakeSummaryService: + """Stub for SummaryTableService — returns a canned ontology_tables payload.""" + + def __init__(self, response: dict[str, Any]) -> None: + self._response = response + + async def ontology_tables( + self, + dataset_id: str, # noqa: ARG002 — stub mirrors the real signature + *, + session: Any, # noqa: ARG002 — stub mirrors the real signature + ) -> dict[str, Any]: + return self._response + + +@pytest.mark.asyncio +async def test_violin_groups_basic(): + """Two-group violin keyed on a column label substring.""" + columns = [ + {"key": "treatment_group", "label": "treatment_group"}, + {"key": "EPM_OpenArm_Entries", "label": "EPM Open Arm Entries"}, + ] + rows = [ + {"treatment_group": "Saline", "EPM_OpenArm_Entries": 5.0}, + {"treatment_group": "Saline", "EPM_OpenArm_Entries": 7.0}, + {"treatment_group": "Saline", "EPM_OpenArm_Entries": 6.0}, + {"treatment_group": "CNO", "EPM_OpenArm_Entries": 2.0}, + {"treatment_group": "CNO", "EPM_OpenArm_Entries": 3.0}, + {"treatment_group": "CNO", "EPM_OpenArm_Entries": 1.0}, + ] + # One docId per row, parallel to `rows` per the ontology_tables + # projection contract. This lets the service route each row's + # docId to its bucket so the frontend can build per-group + # sample-row references. + doc_ids = [ + "doc_saline_1", + "doc_saline_2", + "doc_saline_3", + "doc_cno_1", + "doc_cno_2", + "doc_cno_3", + ] + response = _make_ontology_response(columns, rows, doc_ids=doc_ids) + svc = TabularQueryService(_FakeSummaryService(response)) # type: ignore[arg-type] + result = await svc.violin_groups( + "dataset_xyz", + "OpenArm", + group_by="treatment_group", + group_order=None, + session=None, + ) + assert len(result["groups"]) == 2 + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["Saline"]["mean"] == 6.0 + assert by_name["CNO"]["mean"] == 2.0 + assert by_name["Saline"]["count"] == 3 + # Per-group docIds are surfaced so the frontend can build + # per-bucket sample-row references — capped at 3 per group. + assert by_name["Saline"]["docIds"] == [ + "doc_saline_1", + "doc_saline_2", + "doc_saline_3", + ] + assert by_name["CNO"]["docIds"] == [ + "doc_cno_1", + "doc_cno_2", + "doc_cno_3", + ] + assert by_name["Saline"]["totalRows"] == 3 + assert by_name["CNO"]["totalRows"] == 3 + # `source` is preserved for backwards-compat but is no longer the + # primary citation path. + assert result["source"]["document_id"] == "doc_saline_1" + assert result["xLabel"] == "treatment_group" + # Label comes from the human-readable column label, not the raw key. + assert "Open Arm Entries" in result["yLabel"] + + +@pytest.mark.asyncio +async def test_violin_groups_per_group_doc_id_cap(): + """Per-group docId list capped at MAX_DOC_IDS_PER_GROUP (3) even + when the underlying group has dozens of contributing rows.""" + columns = [ + {"key": "treatment_group", "label": "treatment_group"}, + {"key": "EPM_OpenArm_Entries", "label": "EPM Open Arm Entries"}, + ] + # 10 Saline rows, 2 CNO rows — the cap should clip Saline to 3 docs + # while CNO is left as-is. + rows = ( + [{"treatment_group": "Saline", "EPM_OpenArm_Entries": float(i)} for i in range(10)] + + [{"treatment_group": "CNO", "EPM_OpenArm_Entries": float(i)} for i in range(2)] + ) + doc_ids = [f"doc_saline_{i}" for i in range(10)] + ["doc_cno_0", "doc_cno_1"] + response = _make_ontology_response(columns, rows, doc_ids=doc_ids) + svc = TabularQueryService(_FakeSummaryService(response)) # type: ignore[arg-type] + result = await svc.violin_groups( + "dataset_xyz", + "OpenArm", + group_by="treatment_group", + group_order=None, + session=None, + ) + by_name = {g["name"]: g for g in result["groups"]} + # Saline contributes 10 rows but only 3 docIds surface — the + # frontend doesn't need all 10 as chips; the table-view citation + # already covers the full set. + assert by_name["Saline"]["docIds"] == [ + "doc_saline_0", + "doc_saline_1", + "doc_saline_2", + ] + assert by_name["Saline"]["totalRows"] == 10 + # CNO has only 2 — no cap kicks in. + assert by_name["CNO"]["docIds"] == ["doc_cno_0", "doc_cno_1"] + assert by_name["CNO"]["totalRows"] == 2 + + +@pytest.mark.asyncio +async def test_violin_groups_missing_doc_ids_tolerated(): + """When the projection desynchronizes (rows longer than docIds), + surface what's available without faking ids.""" + columns = [ + {"key": "treatment_group", "label": "treatment_group"}, + {"key": "EPM_OpenArm_Entries", "label": "EPM Open Arm Entries"}, + ] + rows = [ + {"treatment_group": "Saline", "EPM_OpenArm_Entries": 5.0}, + {"treatment_group": "Saline", "EPM_OpenArm_Entries": 7.0}, + {"treatment_group": "CNO", "EPM_OpenArm_Entries": 2.0}, + ] + # Only one docId for three rows. The service must NOT crash or + # invent IDs — Saline gets its one real id, CNO gets nothing. + response = _make_ontology_response(columns, rows, doc_ids=["doc_only_one"]) + svc = TabularQueryService(_FakeSummaryService(response)) # type: ignore[arg-type] + result = await svc.violin_groups( + "dataset_xyz", + "OpenArm", + group_by="treatment_group", + group_order=None, + session=None, + ) + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["Saline"]["docIds"] == ["doc_only_one"] + assert by_name["CNO"]["docIds"] == [] + # Values + stats still computed correctly on the full row set. + assert by_name["Saline"]["totalRows"] == 2 + assert by_name["CNO"]["totalRows"] == 1 + + +@pytest.mark.asyncio +async def test_violin_groups_no_match_returns_empty_with_meta(): + columns = [{"key": "unrelated", "label": "Unrelated Variable"}] + rows = [{"unrelated": 1.0}] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "ElevatedPlusMaze", group_by="g", group_order=None, session=None, + ) + assert result["groups"] == [] + assert "no ontologyTableRow column matched" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_violin_groups_respects_group_order(): + columns = [ + {"key": "group", "label": "group"}, + {"key": "y", "label": "y"}, + ] + rows = [ + {"group": "A", "y": 1.0}, + {"group": "B", "y": 2.0}, + {"group": "C", "y": 3.0}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "y", group_by="group", group_order=["C", "A"], session=None, + ) + names = [g["name"] for g in result["groups"]] + # C and A specified first; B (unspecified) appears after. + assert names == ["C", "A", "B"] + + +@pytest.mark.asyncio +async def test_violin_groups_no_group_by_makes_single_group(): + columns = [{"key": "y", "label": "Value"}] + rows = [{"y": 1.0}, {"y": 2.0}, {"y": 3.0}, {"y": 4.0}] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "Value", group_by=None, group_order=None, session=None, + ) + assert len(result["groups"]) == 1 + assert result["groups"][0]["name"] == "all" + assert result["groups"][0]["count"] == 4 + + +@pytest.mark.asyncio +async def test_violin_groups_caps_group_count(): + columns = [{"key": "g", "label": "g"}, {"key": "y", "label": "y"}] + rows = [ + {"g": f"g{i}", "y": float(i)} for i in range(MAX_GROUPS + 5) + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "y", group_by="g", group_order=None, session=None, + ) + assert len(result["groups"]) == MAX_GROUPS + + +@pytest.mark.asyncio +async def test_violin_groups_caps_values_per_group_but_stats_use_full(): + """Stats are computed BEFORE the value-list sampling so they remain accurate.""" + columns = [{"key": "g", "label": "g"}, {"key": "y", "label": "Value"}] + n = MAX_VALUES_PER_GROUP + 200 + rows = [{"g": "all", "y": float(i)} for i in range(n)] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "Value", group_by="g", group_order=None, session=None, + ) + g = result["groups"][0] + assert len(g["values"]) <= MAX_VALUES_PER_GROUP + expected_mean = (n - 1) / 2 + assert abs(g["mean"] - expected_mean) < 0.001 + assert g["count"] == n + + +@pytest.mark.asyncio +async def test_violin_groups_skips_nonfinite_values(): + """NaN / inf rows shouldn't blow up the aggregation.""" + columns = [{"key": "g", "label": "g"}, {"key": "y", "label": "y"}] + rows = [ + {"g": "A", "y": 1.0}, + {"g": "A", "y": 2.0}, + {"g": "A", "y": float("nan")}, + {"g": "A", "y": float("inf")}, + {"g": "B", "y": 5.0}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "y", group_by="g", group_order=None, session=None, + ) + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["A"]["count"] == 2 + assert by_name["A"]["mean"] == 1.5 + + +@pytest.mark.asyncio +async def test_violin_groups_empty_substring_returns_empty(): + columns = [{"key": "y", "label": "y"}] + rows = [{"y": 1.0}] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "", group_by=None, group_order=None, session=None, + ) + assert result["groups"] == [] + assert "empty" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_violin_groups_prefers_numeric_column_over_identifier(): + """Real ontologyTableRow tables often have multiple columns sharing + a topic prefix (e.g. an identifier column + measurement columns). + The matcher must skip the non-numeric identifier and pick the + numeric measurement. + """ + columns = [ + {"key": "EPM_TestIdentifier", "label": "EPM: Test Identifier"}, + {"key": "EPM_OpenArmEntries", "label": "EPM: Open Arm Entries"}, + {"key": "treatment", "label": "treatment"}, + ] + rows = [ + {"EPM_TestIdentifier": "EPM-001", "EPM_OpenArmEntries": 5.0, "treatment": "Saline"}, + {"EPM_TestIdentifier": "EPM-002", "EPM_OpenArmEntries": 7.0, "treatment": "Saline"}, + {"EPM_TestIdentifier": "EPM-003", "EPM_OpenArmEntries": 3.0, "treatment": "CNO"}, + {"EPM_TestIdentifier": "EPM-004", "EPM_OpenArmEntries": 2.0, "treatment": "CNO"}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + # Search "EPM" matches BOTH identifier and entries; should pick + # the numeric one. + result = await svc.violin_groups( + "ds", "EPM", group_by="treatment", group_order=None, session=None, + ) + assert len(result["groups"]) == 2 # Saline + CNO + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["Saline"]["mean"] == 6.0 # (5+7)/2 + assert by_name["CNO"]["mean"] == 2.5 # (3+2)/2 + # And the label should be the numeric column's label. + assert "Open Arm Entries" in result["yLabel"] + + +@pytest.mark.asyncio +async def test_violin_groups_substring_groupby_resolves_to_column(): + """LLM rarely knows the exact column key. A `groupBy='Treatment'` + should resolve to `Treatment_CNOOrSalineAdministration` via + substring match. + """ + columns = [ + {"key": "value", "label": "Measurement"}, + {"key": "Treatment_CNOOrSalineAdministration", "label": "treatment: CNO or saline"}, + ] + rows = [ + {"value": 1.0, "Treatment_CNOOrSalineAdministration": "Saline"}, + {"value": 2.0, "Treatment_CNOOrSalineAdministration": "Saline"}, + {"value": 3.0, "Treatment_CNOOrSalineAdministration": "CNO"}, + {"value": 4.0, "Treatment_CNOOrSalineAdministration": "CNO"}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "Measurement", group_by="Treatment", group_order=None, session=None, + ) + assert len(result["groups"]) == 2 + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["Saline"]["mean"] == 1.5 + assert by_name["CNO"]["mean"] == 3.5 + + +@pytest.mark.asyncio +async def test_violin_groups_fuzzy_substring_matches_across_underscores(): + """Stream 5.1 (2026-05-15): the column-key matcher should ignore + underscores so a query for ``OpenArmNorthEntries`` resolves the real + column ``ElevatedPlusMaze_OpenArmNorth_Entries``. Pre-fix this + returned an empty result because the contiguous-substring match + didn't bridge the underscore between "North" and "Entries". + """ + columns = [ + { + "key": "ElevatedPlusMaze_OpenArmNorth_Entries", + "label": "Elevated Plus Maze: Open Arm (North) Entries", + }, + {"key": "Treatment_CNOOrSalineAdministration", "label": "treatment"}, + ] + rows = [ + { + "ElevatedPlusMaze_OpenArmNorth_Entries": 5.0, + "Treatment_CNOOrSalineAdministration": "Saline", + }, + { + "ElevatedPlusMaze_OpenArmNorth_Entries": 7.0, + "Treatment_CNOOrSalineAdministration": "Saline", + }, + { + "ElevatedPlusMaze_OpenArmNorth_Entries": 3.0, + "Treatment_CNOOrSalineAdministration": "CNO", + }, + { + "ElevatedPlusMaze_OpenArmNorth_Entries": 2.0, + "Treatment_CNOOrSalineAdministration": "CNO", + }, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + # The needle uses no underscores; the column has underscores + # between every word. Direct case-insensitive substring fails; + # alphanumeric-stripped fallback must catch it. + result = await svc.violin_groups( + "ds", + "OpenArmNorthEntries", + group_by="Treatment", + group_order=None, + session=None, + ) + assert len(result["groups"]) == 2 + by_name = {g["name"]: g for g in result["groups"]} + assert by_name["Saline"]["mean"] == 6.0 # (5+7)/2 + assert by_name["CNO"]["mean"] == 2.5 # (3+2)/2 + + +@pytest.mark.asyncio +async def test_violin_groups_groupby_fuzzy_matches_across_underscores(): + """Stream 5.1 (2026-05-15) — groupBy resolution also uses the + alphanumeric-stripped fallback. ``CNOorSaline`` (no underscore) + resolves to ``Treatment_CNOOrSalineAdministration``. + """ + columns = [ + {"key": "ElevatedPlusMaze_OpenArmEntries", "label": "EPM entries"}, + {"key": "Treatment_CNOOrSalineAdministration", "label": "treatment"}, + ] + rows = [ + {"ElevatedPlusMaze_OpenArmEntries": 1.0, "Treatment_CNOOrSalineAdministration": "Saline"}, + {"ElevatedPlusMaze_OpenArmEntries": 2.0, "Treatment_CNOOrSalineAdministration": "Saline"}, + {"ElevatedPlusMaze_OpenArmEntries": 3.0, "Treatment_CNOOrSalineAdministration": "CNO"}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", + "OpenArmEntries", + group_by="CNOorSaline", + group_order=None, + session=None, + ) + # Did we resolve to the right grouping column? Two groups land. + assert len(result["groups"]) == 2 + names = {g["name"] for g in result["groups"]} + assert names == {"Saline", "CNO"} + + +@pytest.mark.asyncio +async def test_violin_groups_precise_match_wins_over_fuzzy_when_both_present(): + """Stream 5.1 (2026-05-15) — when one column matches case- + insensitively AND another matches only via alphanumeric strip, the + precise match wins. Preserves existing semantics for direct hits. + """ + columns = [ + # Precise match for "EPM_Entries". + {"key": "EPM_Entries", "label": "EPM Entries"}, + # Fuzzy-only match — `EPMEntries` is a substring after stripping. + {"key": "EPM_Other_Entries", "label": "EPM Other Entries"}, + ] + rows = [ + {"EPM_Entries": 100.0, "EPM_Other_Entries": 999.0}, + {"EPM_Entries": 200.0, "EPM_Other_Entries": 999.0}, + ] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + # Query is the precise key — must resolve to it, not the fuzzy + # sibling. yLabel should be "EPM Entries" (precise label). + result = await svc.violin_groups( + "ds", + "EPM_Entries", + group_by=None, + group_order=None, + session=None, + ) + assert result["yLabel"] == "EPM Entries" + # Single 'all' group with the precise column's values. + assert result["groups"][0]["mean"] == 150.0 + + +@pytest.mark.asyncio +async def test_violin_groups_unresolvable_groupby_returns_empty_with_available(): + """When groupBy doesn't match any column, return empty + the list + of available columns so the caller can retry.""" + columns = [ + {"key": "value", "label": "Measurement"}, + {"key": "strain", "label": "strain"}, + ] + rows = [{"value": 1.0, "strain": "N2"}] + svc = TabularQueryService( + _FakeSummaryService(_make_ontology_response(columns, rows)), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "Measurement", group_by="NotAColumn", group_order=None, session=None, + ) + assert result["groups"] == [] + assert "no column matched groupBy" in result["_meta"]["reason"] + assert "strain" in result["_meta"]["columns"] + + +@pytest.mark.asyncio +async def test_violin_groups_no_ontology_docs_returns_empty(): + svc = TabularQueryService( + _FakeSummaryService({"groups": []}), # type: ignore[arg-type] + ) + result = await svc.violin_groups( + "ds", "anything", group_by=None, group_order=None, session=None, + ) + assert result["groups"] == [] + assert "no ontologyTableRow docs" in result["_meta"]["reason"] diff --git a/backend/tests/unit/test_tabular_query_service_cross_table.py b/backend/tests/unit/test_tabular_query_service_cross_table.py new file mode 100644 index 0000000..a0b3853 --- /dev/null +++ b/backend/tests/unit/test_tabular_query_service_cross_table.py @@ -0,0 +1,1165 @@ +"""Unit tests for TabularQueryService.cross_table_pairs (S5.3). + +Sister to test_tabular_query_service.py — same stub pattern (canned +ontology_tables / single_class payloads) so the math + edge cases +are exercised without touching the cloud. + +The cross-table-query backend was originally implemented in this arc, +lost to a git reset during agent collision, and re-implemented from the +spec preserved in +``apps/web/docs/reviews/2026-05-19b-post-handoff-execution.md``. +These tests pin the contract end-to-end. +""" +from __future__ import annotations + +from typing import Any + +import pytest + +from backend.services.tabular_query_service import ( + MAX_PAIRS, + TabularQueryService, + _build_subject_group_map, + _build_subject_value_map, + _columns_for_pair_group_by, + _find_matching_group, + _index_of_group, + _inner_join_pairs, + _inner_join_treatment_pairs, + _order_pairs_by_group, + _pick_treatment_label_for_needle, +) + +# --------------------------------------------------------------------------- +# Stub builders — mirror the real SummaryTableService.ontology_tables shape +# (one group per distinct variableNames CSV) and single_class shape +# (columns + rows + distinct_summary envelope). +# --------------------------------------------------------------------------- + + +def _make_ontology_group( + columns: list[dict[str, Any]], + rows: list[dict[str, Any]], + *, + doc_ids: list[str] | None = None, +) -> dict[str, Any]: + return { + "variableNames": [c["key"] for c in columns], + "names": [c.get("label", c["key"]) for c in columns], + "ontologyNodes": [c.get("ontologyTerm") for c in columns], + "table": {"columns": columns, "rows": rows}, + "docIds": doc_ids or [], + "rowCount": len(rows), + } + + +class _FakeSummaryService: + """Stub returning canned ontology_tables + single_class payloads.""" + + def __init__( + self, + *, + ontology_groups: list[dict[str, Any]] | None = None, + treatment_payloads: dict[str, dict[str, Any]] | None = None, + ) -> None: + self._ontology = {"groups": ontology_groups or []} + self._treatments = treatment_payloads or {} + self.single_class_calls: list[str] = [] + + async def ontology_tables( + self, + dataset_id: str, # noqa: ARG002 — stub mirrors real sig + *, + session: Any, # noqa: ARG002 + ) -> dict[str, Any]: + return self._ontology + + async def single_class( + self, + dataset_id: str, # noqa: ARG002 + class_name: str, + *, + session: Any, # noqa: ARG002 + page: int | None = None, # noqa: ARG002 + page_size: int | None = None, # noqa: ARG002 + subject_filter: str | None = None, # noqa: ARG002 + ) -> dict[str, Any]: + self.single_class_calls.append(class_name) + if class_name in self._treatments: + return self._treatments[class_name] + # Mirror the real service's "class absent → empty envelope" behavior. + return {"columns": [], "rows": [], "distinct_summary": []} + + +# --------------------------------------------------------------------------- +# Pure helper tests — no IO +# --------------------------------------------------------------------------- + + +class TestIndexOfGroup: + def test_target_found(self): + a = {"name": "a"} + b = {"name": "b"} + assert _index_of_group([a, b], b) == 1 + + def test_target_not_found(self): + a = {"name": "a"} + b = {"name": "b"} + other = {"name": "a"} # equal-by-value but not by identity + assert _index_of_group([a, b], other) == -1 + + def test_target_none(self): + assert _index_of_group([{"x": 1}], None) == -1 + + def test_empty_groups(self): + assert _index_of_group([], {"x": 1}) == -1 + + +class TestFindMatchingGroupExclude: + """Pin the exclude_group_idx kwarg added for S5.3.""" + + def test_exclude_skips_named_group(self): + rows_a = [{"valueA": 5.0}, {"valueA": 7.0}] + rows_b = [{"valueB": 10.0}] + groups = [ + _make_ontology_group( + [{"key": "valueA", "label": "Value A"}], rows_a, + ), + _make_ontology_group( + [{"key": "valueB", "label": "Value B"}], rows_b, + ), + ] + # Unrestricted: "value" matches both; first (A) wins on numeric_count tie. + unrestricted = _find_matching_group(groups, "value") + assert unrestricted is not None + assert unrestricted[1] == "valueA" + # With group A excluded, the matcher must pick group B. + restricted = _find_matching_group(groups, "value", exclude_group_idx=0) + assert restricted is not None + assert restricted[1] == "valueB" + + def test_exclude_index_out_of_range_noop(self): + rows = [{"x": 1.0}] + groups = [_make_ontology_group([{"key": "x", "label": "x"}], rows)] + # -1 should preserve the unconstrained search. + assert _find_matching_group(groups, "x", exclude_group_idx=-1) is not None + assert _find_matching_group(groups, "x", exclude_group_idx=99) is not None + + +class TestBuildSubjectValueMap: + def test_numeric_happy_path(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "v": 5.0}, + {"subjectDocumentIdentifier": "s2", "v": 7.0}, + ] + doc_ids = ["doc_s1", "doc_s2"] + m = _build_subject_value_map(rows, doc_ids, "v", numeric=True) + assert m == {"s1": (5.0, "doc_s1"), "s2": (7.0, "doc_s2")} + + def test_skips_non_finite(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "v": float("nan")}, + {"subjectDocumentIdentifier": "s2", "v": 7.0}, + {"subjectDocumentIdentifier": "s3", "v": float("inf")}, + ] + m = _build_subject_value_map(rows, [], "v", numeric=True) + assert set(m.keys()) == {"s2"} + + def test_skips_rows_missing_subject_id(self): + rows = [ + {"subjectDocumentIdentifier": "", "v": 1.0}, + {"v": 2.0}, # no key at all + {"subjectDocumentIdentifier": "s1", "v": 3.0}, + ] + m = _build_subject_value_map(rows, [], "v", numeric=True) + assert m == {"s1": (3.0, None)} + + def test_string_mode(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "v": "Saline"}, + {"subjectDocumentIdentifier": "s2", "v": " "}, # whitespace + {"subjectDocumentIdentifier": "s3", "v": "CNO"}, + ] + m = _build_subject_value_map(rows, [], "v", numeric=False) + assert m == {"s1": ("Saline", None), "s3": ("CNO", None)} + + def test_doc_id_misalignment_tolerated(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "v": 1.0}, + {"subjectDocumentIdentifier": "s2", "v": 2.0}, + ] + # Shorter doc_ids list — should not crash. + m = _build_subject_value_map(rows, ["only_one"], "v", numeric=True) + assert m["s1"] == (1.0, "only_one") + assert m["s2"] == (2.0, None) + + def test_last_write_wins_per_subject(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "v": 1.0}, + {"subjectDocumentIdentifier": "s1", "v": 99.0}, + ] + m = _build_subject_value_map(rows, [], "v", numeric=True) + assert m == {"s1": (99.0, None)} + + +class TestBuildSubjectGroupMap: + def test_happy_path(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "g": "A"}, + {"subjectDocumentIdentifier": "s2", "g": "B"}, + ] + assert _build_subject_group_map(rows, "g") == {"s1": "A", "s2": "B"} + + def test_skips_missing_or_empty_group_value(self): + rows = [ + {"subjectDocumentIdentifier": "s1", "g": ""}, + {"subjectDocumentIdentifier": "s2"}, # no g + {"subjectDocumentIdentifier": "s3", "g": "X"}, + ] + assert _build_subject_group_map(rows, "g") == {"s3": "X"} + + def test_stringifies_non_string_values(self): + rows = [{"subjectDocumentIdentifier": "s1", "g": 42}] + assert _build_subject_group_map(rows, "g") == {"s1": "42"} + + +class TestColumnsForPairGroupBy: + def test_excludes_value_cols_and_dedupes(self): + gx = _make_ontology_group( + [ + {"key": "valueX", "label": "X"}, + {"key": "treatment", "label": "treatment"}, + {"key": "shared", "label": "shared"}, + ], + [], + ) + gy = _make_ontology_group( + [ + {"key": "valueY", "label": "Y"}, + {"key": "shared", "label": "shared"}, # dedup with gx + {"key": "strain", "label": "strain"}, + ], + [], + ) + cols = _columns_for_pair_group_by(gx, gy, "valueX", "valueY") + # No valueX, no valueY, shared appears once. + assert "valueX" not in cols + assert "valueY" not in cols + assert cols.count("shared") == 1 + assert "treatment" in cols + assert "strain" in cols + + def test_handles_none_groups(self): + gx = _make_ontology_group([{"key": "g", "label": "g"}], []) + assert _columns_for_pair_group_by(gx, None, "g", "") == [] + + def test_caps_at_twenty(self): + many = [{"key": f"c{i}", "label": f"c{i}"} for i in range(50)] + gx = _make_ontology_group(many, []) + assert len(_columns_for_pair_group_by(gx, None, "", "")) == 20 + + +class TestInnerJoinPairs: + def test_happy_path_no_group(self): + x_map = {"s1": (1.0, "dx1"), "s2": (2.0, None)} + y_map = {"s1": (10.0, "dy1"), "s2": (20.0, "dy2")} + pairs, unjoined = _inner_join_pairs( + x_map, y_map, subject_to_group=None, + ) + assert unjoined == {"x_only": 0, "y_only": 0} + by_sid = {p["subjectId"]: p for p in pairs} + assert by_sid["s1"]["x"] == 1.0 + assert by_sid["s1"]["y"] == 10.0 + assert by_sid["s1"]["docIdX"] == "dx1" + assert by_sid["s1"]["docIdY"] == "dy1" + # s2 has no docIdX — omitted from the dict + assert "docIdX" not in by_sid["s2"] + # No groupBy → no `group` key on pairs + assert all("group" not in p for p in pairs) + + def test_with_group_assignments(self): + x_map = {"s1": (1.0, None), "s2": (2.0, None)} + y_map = {"s1": (10.0, None), "s2": (20.0, None)} + group_map = {"s1": "A", "s2": "B"} + pairs, _ = _inner_join_pairs( + x_map, y_map, subject_to_group=group_map, + ) + by_sid = {p["subjectId"]: p for p in pairs} + assert by_sid["s1"]["group"] == "A" + assert by_sid["s2"]["group"] == "B" + + def test_unjoined_counts(self): + x_map = {"s1": (1.0, None), "s2": (2.0, None), "s3": (3.0, None)} + y_map = {"s1": (10.0, None), "s4": (40.0, None)} + pairs, unjoined = _inner_join_pairs( + x_map, y_map, subject_to_group=None, + ) + assert len(pairs) == 1 + assert pairs[0]["subjectId"] == "s1" + assert unjoined == {"x_only": 2, "y_only": 1} + + def test_missing_group_assignment_keeps_pair(self): + x_map = {"s1": (1.0, None), "s2": (2.0, None)} + y_map = {"s1": (10.0, None), "s2": (20.0, None)} + # s2 absent from group_map → pair still emitted, no `group` key + group_map = {"s1": "A"} + pairs, _ = _inner_join_pairs( + x_map, y_map, subject_to_group=group_map, + ) + by_sid = {p["subjectId"]: p for p in pairs} + assert by_sid["s1"]["group"] == "A" + assert "group" not in by_sid["s2"] + + +class TestInnerJoinTreatmentPairs: + def test_auto_color_by_treatment_when_no_groupby(self): + x_map = {"s1": (1.0, "dx1"), "s2": (2.0, None)} + treatment_map = { + "s1": ("Saline", None), + "s2": ("CNO", None), + } + pairs, _ = _inner_join_treatment_pairs( + x_map, treatment_map, subject_to_group=None, + ) + by_sid = {p["subjectId"]: p for p in pairs} + assert by_sid["s1"]["y"] == "Saline" + # auto-color: group falls back to treatment label + assert by_sid["s1"]["group"] == "Saline" + assert by_sid["s2"]["group"] == "CNO" + + def test_explicit_group_overrides_treatment_color(self): + x_map = {"s1": (1.0, None)} + treatment_map = {"s1": ("Saline", None)} + group_map = {"s1": "strain_X"} + pairs, _ = _inner_join_treatment_pairs( + x_map, treatment_map, subject_to_group=group_map, + ) + assert pairs[0]["group"] == "strain_X" + # y still holds the treatment label + assert pairs[0]["y"] == "Saline" + + +class TestOrderPairsByGroup: + def test_explicit_order_applied(self): + pairs = [ + {"subjectId": "s1", "group": "A"}, + {"subjectId": "s2", "group": "B"}, + {"subjectId": "s3", "group": "C"}, + ] + ordered = _order_pairs_by_group(pairs, ["C", "A"]) + assert [p["group"] for p in ordered] == ["C", "A", "B"] + + def test_no_order_preserves_input(self): + pairs = [ + {"subjectId": "s1", "group": "B"}, + {"subjectId": "s2", "group": "A"}, + ] + assert _order_pairs_by_group(pairs, None) == pairs + + def test_empty_order_preserves_input(self): + pairs = [{"subjectId": "s1", "group": "B"}] + assert _order_pairs_by_group(pairs, []) == pairs + + def test_pairs_without_group_treated_as_sentinel(self): + pairs = [ + {"subjectId": "s1"}, + {"subjectId": "s2", "group": "A"}, + {"subjectId": "s3", "group": "B"}, + ] + ordered = _order_pairs_by_group(pairs, ["B", "A"]) + # B first, A second, s1 (no group) last + assert [p["subjectId"] for p in ordered] == ["s3", "s2", "s1"] + + +class TestPickTreatmentLabelForNeedle: + def test_needle_direct_match(self): + row = { + "treatmentName": "Eschericia coli OP50", + "treatmentReference": "REF-1", + } + assert ( + _pick_treatment_label_for_needle(row, "reference") == "REF-1" + ) + + def test_priority_fallback_when_no_needle_match(self): + row = { + "treatmentName": "OP50", + "otherKey": "OP50-other", + } + # needle="xxxx" → no direct match; falls to priority chain + # (`name` is first → row's "treatmentName" matches) + assert _pick_treatment_label_for_needle(row, "xxxx") == "OP50" + + def test_empty_needle_uses_priority(self): + row = {"treatmentName": "Saline"} + assert _pick_treatment_label_for_needle(row, "") == "Saline" + + def test_skips_empty_or_non_string(self): + row = { + "treatmentName": "", + "treatmentReference": " ", + "drugName": 42, # non-string + "drug": "Aldosterone", + } + # All fields except drug fail the non-empty-string guard + assert ( + _pick_treatment_label_for_needle(row, "drug") == "Aldosterone" + ) + + def test_returns_none_when_nothing_matches(self): + row = {"unrelated": "x"} + assert _pick_treatment_label_for_needle(row, "drug") is None + + def test_empty_row(self): + assert _pick_treatment_label_for_needle({}, "name") is None + + +# --------------------------------------------------------------------------- +# Service-level: orchestrator tests +# --------------------------------------------------------------------------- + + +# Tests below build their own group dicts inline so the fixture stays simple. + + +@pytest.mark.asyncio +async def test_subject_join_happy_path(): + """Two ontology groups; each has a numeric column; common subjects + join into pairs.""" + epm_group = _make_ontology_group( + [ + { + "key": "ElevatedPlusMaze_OpenArmTime", + "label": "EPM: Open Arm Time", + }, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "ElevatedPlusMaze_OpenArmTime": 12.5, + "subjectDocumentIdentifier": "subj_1", + }, + { + "ElevatedPlusMaze_OpenArmTime": 8.0, + "subjectDocumentIdentifier": "subj_2", + }, + { + "ElevatedPlusMaze_OpenArmTime": 15.0, + "subjectDocumentIdentifier": "subj_3", + }, + ], + doc_ids=["epm_1", "epm_2", "epm_3"], + ) + fps_group = _make_ontology_group( + [ + { + "key": "FearStartle_Amplitude", + "label": "FPS: Startle Amplitude", + }, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "FearStartle_Amplitude": 0.5, + "subjectDocumentIdentifier": "subj_1", + }, + { + "FearStartle_Amplitude": 1.2, + "subjectDocumentIdentifier": "subj_2", + }, + # subj_3 missing from FPS group + ], + doc_ids=["fps_1", "fps_2"], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArmTime", + "StartleAmplitude", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["joinKind"] == "subject" + # 2 subjects in both groups → 2 pairs + assert len(result["pairs"]) == 2 + by_sid = {p["subjectId"]: p for p in result["pairs"]} + assert by_sid["subj_1"]["x"] == 12.5 + assert by_sid["subj_1"]["y"] == 0.5 + assert by_sid["subj_1"]["docIdX"] == "epm_1" + assert by_sid["subj_1"]["docIdY"] == "fps_1" + # subj_3 only in X → x_only=1, y_only=0 + assert result["unjoined"] == {"x_only": 1, "y_only": 0} + assert "Open Arm Time" in result["xLabel"] + assert "Startle Amplitude" in result["yLabel"] + # Source citation points to the first X docId + assert result["source"]["document_id"] == "epm_1" + + +@pytest.mark.asyncio +async def test_subject_join_with_group_by_resolves_in_x_table(): + """groupBy resolves against the X group's columns; per-subject + group label propagates onto pairs.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM Open Arm"}, + {"key": "treatment_group", "label": "Treatment"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "EPM_OpenArm": 10.0, + "treatment_group": "Saline", + "subjectDocumentIdentifier": "s1", + }, + { + "EPM_OpenArm": 20.0, + "treatment_group": "CNO", + "subjectDocumentIdentifier": "s2", + }, + ], + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS Amplitude"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + {"FPS_Amp": 0.4, "subjectDocumentIdentifier": "s1"}, + {"FPS_Amp": 1.1, "subjectDocumentIdentifier": "s2"}, + ], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by="treatment_group", + group_order=None, + session=None, + ) + by_sid = {p["subjectId"]: p for p in result["pairs"]} + assert by_sid["s1"]["group"] == "Saline" + assert by_sid["s2"]["group"] == "CNO" + assert result["groupLabel"] == "treatment_group" + + +@pytest.mark.asyncio +async def test_subject_join_groupby_resolves_in_y_table_when_absent_in_x(): + """groupBy column exists only in the Y table → still resolves + correctly per the chat-tool contract ('searches group_x first, + then group_y').""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM Open Arm"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + {"EPM_OpenArm": 10.0, "subjectDocumentIdentifier": "s1"}, + {"EPM_OpenArm": 20.0, "subjectDocumentIdentifier": "s2"}, + ], + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS Amplitude"}, + {"key": "strain_group", "label": "Strain"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "FPS_Amp": 0.4, + "strain_group": "N2", + "subjectDocumentIdentifier": "s1", + }, + { + "FPS_Amp": 1.1, + "strain_group": "PR811", + "subjectDocumentIdentifier": "s2", + }, + ], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by="strain_group", + group_order=None, + session=None, + ) + by_sid = {p["subjectId"]: p for p in result["pairs"]} + # Y-table strain_group must propagate (X-then-Y resolution) + assert by_sid["s1"]["group"] == "N2" + assert by_sid["s2"]["group"] == "PR811" + + +@pytest.mark.asyncio +async def test_subject_join_with_group_order(): + """Explicit groupOrder applied to pair ordering.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM Open Arm"}, + {"key": "treatment_group", "label": "Treatment"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "EPM_OpenArm": 1.0, + "treatment_group": "Saline", + "subjectDocumentIdentifier": "s1", + }, + { + "EPM_OpenArm": 2.0, + "treatment_group": "CNO", + "subjectDocumentIdentifier": "s2", + }, + { + "EPM_OpenArm": 3.0, + "treatment_group": "Vehicle", + "subjectDocumentIdentifier": "s3", + }, + ], + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS Amplitude"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + {"FPS_Amp": 0.1, "subjectDocumentIdentifier": "s1"}, + {"FPS_Amp": 0.2, "subjectDocumentIdentifier": "s2"}, + {"FPS_Amp": 0.3, "subjectDocumentIdentifier": "s3"}, + ], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by="treatment_group", + group_order=["CNO", "Saline"], + session=None, + ) + groups_in_order = [p["group"] for p in result["pairs"]] + # CNO first, Saline second, Vehicle (unspecified) last + assert groups_in_order == ["CNO", "Saline", "Vehicle"] + + +@pytest.mark.asyncio +async def test_subject_join_no_overlap_returns_diagnostic(): + """Both groups match but no subjects overlap → empty pairs + _meta.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"EPM_OpenArm": 10.0, "subjectDocumentIdentifier": "s_A"}], + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"FPS_Amp": 0.4, "subjectDocumentIdentifier": "s_B"}], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no overlapping subjects" in result["_meta"]["reason"] + assert result["unjoined"] == {"x_only": 1, "y_only": 1} + + +@pytest.mark.asyncio +async def test_subject_join_empty_input_strings(): + svc = TabularQueryService(_FakeSummaryService(ontology_groups=[])) # type: ignore[arg-type] + result = await svc.cross_table_pairs( + "ds", "", "FPS", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "empty xVariableContains" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_subject_join_no_ontology_docs(): + svc = TabularQueryService(_FakeSummaryService(ontology_groups=[])) # type: ignore[arg-type] + result = await svc.cross_table_pairs( + "ds", "EPM", "FPS", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no ontologyTableRow docs" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_subject_join_no_x_match(): + """X column substring doesn't match any group.""" + epm_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"FPS_Amp": 0.1, "subjectDocumentIdentifier": "s1"}], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "NotInData", + "FPS", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no ontologyTableRow column matched X" in result["_meta"]["reason"] + # The "available variable_names" diagnostic surfaces the group's names + assert "variable_names" in result["_meta"] + + +@pytest.mark.asyncio +async def test_subject_join_no_y_match_in_different_group(): + """X matched group A; Y substring matches the SAME group only → no + different-group hit → diagnostic empty pairs.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM Open Arm"}, + {"key": "EPM_OpenTime", "label": "EPM Open Time"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "EPM_OpenArm": 5.0, + "EPM_OpenTime": 8.0, + "subjectDocumentIdentifier": "s1", + }, + ], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "EPM_OpenArm", + "EPM_OpenTime", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no DIFFERENT ontologyTableRow column matched Y" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_subject_join_unresolvable_groupby(): + """groupBy doesn't match any column in X or Y → diagnostic with + available column list.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"EPM_OpenArm": 5.0, "subjectDocumentIdentifier": "s1"}], + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS"}, + {"key": "strain_group", "label": "Strain"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "FPS_Amp": 0.4, + "strain_group": "N2", + "subjectDocumentIdentifier": "s1", + }, + ], + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by="DOES_NOT_EXIST", + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no column matched groupBy" in result["_meta"]["reason"] + # `strain_group` is present in Y; should appear in available + assert "strain_group" in result["_meta"]["columns"] + + +@pytest.mark.asyncio +async def test_subject_join_max_pairs_cap(): + """Pair count >= MAX_PAIRS+5 → response truncated + _meta diagnostic.""" + x_rows = [ + { + "EPM_OpenArm": float(i), + "subjectDocumentIdentifier": f"s{i}", + } + for i in range(MAX_PAIRS + 5) + ] + y_rows = [ + { + "FPS_Amp": float(i), + "subjectDocumentIdentifier": f"s{i}", + } + for i in range(MAX_PAIRS + 5) + ] + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + x_rows, + ) + fps_group = _make_ontology_group( + [ + {"key": "FPS_Amp", "label": "FPS"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + y_rows, + ) + svc = TabularQueryService( + _FakeSummaryService(ontology_groups=[epm_group, fps_group]), # type: ignore[arg-type] + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "FPS_Amp", + join_on="subject", + group_by=None, + group_order=None, + session=None, + ) + assert len(result["pairs"]) == MAX_PAIRS + assert "MAX_PAIRS" in result["_meta"]["reason"] + + +# --------------------------------------------------------------------------- +# Treatment-join orchestrator tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_treatment_join_happy_path(): + """X (ontologyTableRow numeric column) vs Y (treatment label) pairs + by subject, with group auto-falling back to the treatment label.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM Open Arm"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + {"EPM_OpenArm": 5.0, "subjectDocumentIdentifier": "s1"}, + {"EPM_OpenArm": 7.0, "subjectDocumentIdentifier": "s2"}, + {"EPM_OpenArm": 9.0, "subjectDocumentIdentifier": "s3"}, + ], + ) + treatment_payload = { + "columns": [], # not consumed by the service + "rows": [ + { + "treatmentName": "Saline", + "subjectDocumentIdentifier": "s1", + }, + { + "treatmentName": "CNO", + "subjectDocumentIdentifier": "s2", + }, + # s3 has no treatment doc + ], + "distinct_summary": [], + } + svc = TabularQueryService( + _FakeSummaryService( # type: ignore[arg-type] + ontology_groups=[epm_group], + treatment_payloads={"treatment": treatment_payload}, + ), + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "name", + join_on="treatment", + group_by=None, + group_order=None, + session=None, + ) + assert result["joinKind"] == "treatment" + by_sid = {p["subjectId"]: p for p in result["pairs"]} + assert by_sid["s1"]["y"] == "Saline" + assert by_sid["s1"]["group"] == "Saline" # auto-color + assert by_sid["s2"]["y"] == "CNO" + # s3 has no treatment doc → x_only=1 + assert result["unjoined"] == {"x_only": 1, "y_only": 0} + # Auto group label sentinel + assert result["groupLabel"] == "treatment" + + +@pytest.mark.asyncio +async def test_treatment_join_walks_chain_and_picks_subclass_label(): + """treatment_drug class supersedes treatment class via last-write-wins.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"EPM_OpenArm": 5.0, "subjectDocumentIdentifier": "s1"}], + ) + treatment_payload = { + "columns": [], + "rows": [ + { + "treatmentName": "BroadGenericLabel", + "subjectDocumentIdentifier": "s1", + }, + ], + "distinct_summary": [], + } + treatment_drug_payload = { + "columns": [], + "rows": [ + { + "treatmentName": "Eschericia coli OP50", + "subjectDocumentIdentifier": "s1", + }, + ], + "distinct_summary": [], + } + fake = _FakeSummaryService( # type: ignore[arg-type] + ontology_groups=[epm_group], + treatment_payloads={ + "treatment": treatment_payload, + "treatment_drug": treatment_drug_payload, + }, + ) + svc = TabularQueryService(fake) # type: ignore[arg-type] + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "name", + join_on="treatment", + group_by=None, + group_order=None, + session=None, + ) + # treatment_drug came last in the chain → its label wins. + by_sid = {p["subjectId"]: p for p in result["pairs"]} + assert by_sid["s1"]["y"] == "Eschericia coli OP50" + # Both classes were probed (treatment_transfer too; missing payload). + assert "treatment" in fake.single_class_calls + assert "treatment_drug" in fake.single_class_calls + assert "treatment_transfer" in fake.single_class_calls + + +@pytest.mark.asyncio +async def test_treatment_join_with_explicit_group_by_overrides_auto_color(): + """When group_by is set, it propagates from the X table — auto-color + by treatment label is overridden.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "strain_group", "label": "Strain"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [ + { + "EPM_OpenArm": 5.0, + "strain_group": "N2", + "subjectDocumentIdentifier": "s1", + }, + { + "EPM_OpenArm": 7.0, + "strain_group": "PR811", + "subjectDocumentIdentifier": "s2", + }, + ], + ) + treatment_payload = { + "columns": [], + "rows": [ + { + "treatmentName": "Saline", + "subjectDocumentIdentifier": "s1", + }, + { + "treatmentName": "CNO", + "subjectDocumentIdentifier": "s2", + }, + ], + "distinct_summary": [], + } + svc = TabularQueryService( + _FakeSummaryService( # type: ignore[arg-type] + ontology_groups=[epm_group], + treatment_payloads={"treatment": treatment_payload}, + ), + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "name", + join_on="treatment", + group_by="strain_group", + group_order=None, + session=None, + ) + by_sid = {p["subjectId"]: p for p in result["pairs"]} + # group falls to strain_group, NOT treatment name + assert by_sid["s1"]["group"] == "N2" + assert by_sid["s2"]["group"] == "PR811" + # y still holds the treatment label + assert by_sid["s1"]["y"] == "Saline" + assert result["groupLabel"] == "strain_group" + + +@pytest.mark.asyncio +async def test_treatment_join_no_treatment_docs_returns_empty(): + """Treatment chain returns nothing → diagnostic empty pairs.""" + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"EPM_OpenArm": 5.0, "subjectDocumentIdentifier": "s1"}], + ) + svc = TabularQueryService( + _FakeSummaryService( # type: ignore[arg-type] + ontology_groups=[epm_group], + treatment_payloads={}, # all three classes empty + ), + ) + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "name", + join_on="treatment", + group_by=None, + group_order=None, + session=None, + ) + assert result["pairs"] == [] + assert "no treatment docs matched Y" in result["_meta"]["reason"] + + +@pytest.mark.asyncio +async def test_treatment_join_treatment_class_fetch_failure_continues_chain(): + """A failing single_class on one chain step is logged + skipped; + the chain continues with the next class.""" + + class _FlakySummary: + def __init__(self, ontology_groups: list[dict[str, Any]]) -> None: + self._ontology = {"groups": ontology_groups} + self.single_class_calls: list[str] = [] + + async def ontology_tables( + self, dataset_id: str, *, session: Any, # noqa: ARG002 + ) -> dict[str, Any]: + return self._ontology + + async def single_class( + self, + dataset_id: str, # noqa: ARG002 + class_name: str, + *, + session: Any, # noqa: ARG002 + page: int | None = None, # noqa: ARG002 + page_size: int | None = None, # noqa: ARG002 + subject_filter: str | None = None, # noqa: ARG002 + ) -> dict[str, Any]: + self.single_class_calls.append(class_name) + if class_name == "treatment": + raise RuntimeError("upstream timeout") + if class_name == "treatment_drug": + return { + "columns": [], + "rows": [ + { + "treatmentName": "Recovered", + "subjectDocumentIdentifier": "s1", + }, + ], + "distinct_summary": [], + } + return {"columns": [], "rows": [], "distinct_summary": []} + + epm_group = _make_ontology_group( + [ + {"key": "EPM_OpenArm", "label": "EPM"}, + {"key": "subjectDocumentIdentifier", "label": "Subject"}, + ], + [{"EPM_OpenArm": 5.0, "subjectDocumentIdentifier": "s1"}], + ) + flaky = _FlakySummary([epm_group]) + svc = TabularQueryService(flaky) # type: ignore[arg-type] + result = await svc.cross_table_pairs( + "ds", + "OpenArm", + "name", + join_on="treatment", + group_by=None, + group_order=None, + session=None, + ) + # treatment class raised; treatment_drug recovered. + by_sid = {p["subjectId"]: p for p in result["pairs"]} + assert by_sid["s1"]["y"] == "Recovered" + + +# --------------------------------------------------------------------------- +# Shared/cross constants — keep test pinning consistent if the constant +# value ever changes. +# --------------------------------------------------------------------------- + + +def test_subject_key_constant(): + """The subject-id key on projected rows is locked to + ``subjectDocumentIdentifier`` — this is the contract every + cross-table helper assumes. Pinning to surface schema drift early.""" + from backend.services.tabular_query_service import _SUBJECT_KEY + assert _SUBJECT_KEY == "subjectDocumentIdentifier" + + +def test_treatment_class_chain_constant(): + from backend.services.tabular_query_service import _TREATMENT_CLASS_CHAIN + assert _TREATMENT_CLASS_CHAIN == ( + "treatment", + "treatment_drug", + "treatment_transfer", + ) diff --git a/backend/tests/unit/test_treatment_timeline_service.py b/backend/tests/unit/test_treatment_timeline_service.py new file mode 100644 index 0000000..3a4d34f --- /dev/null +++ b/backend/tests/unit/test_treatment_timeline_service.py @@ -0,0 +1,767 @@ +"""Unit tests for :class:`TreatmentTimelineService`. + +Both backing services are stubbed: the orchestrator under test +composes two existing services that have their own coverage. The +focus here is the projection math — ordinal slot timing, explicit +timing detection, mixed mode classification, subject cap, fallback +fan-out, and the ``empty_hint`` envelope. +""" +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from backend.services.treatment_timeline_service import ( + DEFAULT_MAX_SUBJECTS, + TreatmentTimelineService, + _classify_temporal_source, + _extract_explicit_timing, + _parse_iso_datetime, + _pick_subject_label, + _pick_treatment_label, +) + +# --------------------------------------------------------------------------- +# Pure-helper tests — no IO +# --------------------------------------------------------------------------- + + +class TestPickSubjectLabel: + def test_prefers_subject_document_identifier(self): + row = {"subjectDocumentIdentifier": "subj_001", "subject": "ignored"} + assert _pick_subject_label(row) == "subj_001" + + def test_falls_back_to_subject_field(self): + row = {"subject": "subj_alt"} + assert _pick_subject_label(row) == "subj_alt" + + def test_returns_none_when_empty(self): + assert _pick_subject_label({}) is None + assert _pick_subject_label({"subjectDocumentIdentifier": ""}) is None + + +class TestPickTreatmentLabel: + def test_prefers_treatment_name(self): + row = {"treatmentName": "Saline", "stringValue": "ignored"} + assert _pick_treatment_label(row) == "Saline" + + def test_falls_back_to_string_value(self): + row = {"stringValue": "CNO"} + assert _pick_treatment_label(row) == "CNO" + + def test_returns_none_when_empty(self): + assert _pick_treatment_label({}) is None + + def test_picks_treatment_drug_field(self): + # Stream 5.2 (2026-05-15): treatment_drug class emits its own + # label fields. Recognize them after the canonical treatmentName. + assert _pick_treatment_label({"treatment_drug": "isoamylol"}) == "isoamylol" + + def test_picks_drug_name_field(self): + assert _pick_treatment_label({"drugName": "CNO"}) == "CNO" + + def test_picks_compound_field(self): + assert _pick_treatment_label({"compound": "Saline"}) == "Saline" + + def test_treatment_name_wins_over_drug_field(self): + # Priority: treatmentName > treatment_drug > drugName > compound > stringValue. + assert ( + _pick_treatment_label( + {"treatmentName": "Saline", "treatment_drug": "ignored"}, + ) + == "Saline" + ) + + +class TestExtractExplicitTiming: + def test_numeric_value_pair(self): + row = {"numericValue": [10.0, 20.0]} + assert _extract_explicit_timing(row) == (10.0, 20.0) + + def test_numeric_value_scalar(self): + row = {"numericValue": 5.0} + assert _extract_explicit_timing(row) == (5.0, 6.0) + + def test_numeric_value_singleton_array(self): + row = {"numericValue": [7.5]} + assert _extract_explicit_timing(row) == (7.5, 8.5) + + def test_start_date_pair(self): + row = {"startDate": "2026-01-01", "endDate": "2026-01-05"} + assert _extract_explicit_timing(row) == ("2026-01-01", "2026-01-05") + + def test_start_time_end_time_pair(self): + row = {"startTime": 100.0, "endTime": 200.0} + assert _extract_explicit_timing(row) == (100.0, 200.0) + + def test_administration_start_end_time_pair(self): + # Stream 5.2 (2026-05-15): treatment_drug class emits + # administration_start_time / administration_end_time as the + # canonical timing pair. Without this, Bhar's TreatmentTimeline + # panel rendered an empty chart despite 24466 treatment_drug + # docs with explicit timing. + row = { + "administration_start_time": "2026-01-01T08:00:00Z", + "administration_end_time": "2026-01-01T08:02:00Z", + } + out = _extract_explicit_timing(row) + assert out == ( + "2026-01-01T08:00:00Z", + "2026-01-01T08:02:00Z", + ) + + def test_startDate_wins_over_administration_pair(self): + # Priority: startDate > startTime > administration_*_time. + row = { + "startDate": "2026-01-01", + "endDate": "2026-01-05", + "administration_start_time": "2026-02-01", + "administration_end_time": "2026-02-05", + } + assert _extract_explicit_timing(row) == ("2026-01-01", "2026-01-05") + + def test_numeric_value_empty_array_returns_none(self): + assert _extract_explicit_timing({"numericValue": []}) is None + + def test_nan_inf_rejected(self): + assert _extract_explicit_timing({"numericValue": float("nan")}) is None + assert _extract_explicit_timing({"numericValue": float("inf")}) is None + + def test_no_timing_returns_none(self): + assert _extract_explicit_timing({"treatmentName": "Saline"}) is None + + def test_iso_date_string_value_emits_day_window(self): + out = _extract_explicit_timing({"stringValue": "2026-05-14"}) + assert out is not None + start, end = out + assert start == "2026-05-14" + # End is the +1 day ISO string — bare date interpreted as UTC. + assert isinstance(end, str) and end.startswith("2026-05-15") + + +class TestParseIsoDatetime: + def test_bare_date(self): + out = _parse_iso_datetime("2026-05-14") + assert out is not None + assert out.year == 2026 and out.month == 5 and out.day == 14 + + def test_z_suffix(self): + out = _parse_iso_datetime("2026-05-14T12:00:00Z") + assert out is not None + assert out.hour == 12 + + def test_garbage_returns_none(self): + assert _parse_iso_datetime("not a date") is None + + +class TestClassifyTemporalSource: + def test_all_explicit(self): + assert _classify_temporal_source(5, 0) == "explicit" + + def test_all_ordinal(self): + assert _classify_temporal_source(0, 5) == "ordinal" + + def test_mixed(self): + assert _classify_temporal_source(3, 2) == "mixed" + + def test_neither_defaults_ordinal(self): + assert _classify_temporal_source(0, 0) == "ordinal" + + +# --------------------------------------------------------------------------- +# Service-level: stub both backing services +# --------------------------------------------------------------------------- + + +def _make_service( + *, + primary_response: dict[str, Any] | None = None, + primary_raises: Exception | None = None, + fallback_response: dict[str, Any] | None = None, + fallback_raises: Exception | None = None, +) -> TreatmentTimelineService: + """Compose a service whose backing dependencies return canned + payloads. Either response or raises wins — use raises to simulate + cloud failures. + """ + summary = AsyncMock() + if primary_raises is not None: + summary.single_class.side_effect = primary_raises + else: + # 2026-05-19 (F-1e merge semantics) — the service now queries + # ALL classes in `_TREATMENT_CLASS_CHAIN` and merges their + # rows. For these tests the legacy intent is "primary returns + # this response under the `treatment` class; the others + # contribute nothing." Honor that by returning the canned + # response ONLY when the class argument is `treatment`, + # otherwise empty. Tests that exercise multi-class merging + # explicitly opt-in via `summary.single_class.side_effect = …` + # after the fixture returns. + canned = primary_response or {"columns": [], "rows": []} + empty = {"columns": [], "rows": []} + + async def _single_class_dispatch( + _dataset_id: str, class_name: str, **_kwargs: Any, + ) -> dict[str, Any]: + return canned if class_name == "treatment" else empty + + summary.single_class.side_effect = _single_class_dispatch + + tabular = AsyncMock() + if fallback_raises is not None: + tabular.violin_groups.side_effect = fallback_raises + else: + tabular.violin_groups.return_value = fallback_response or { + "groups": [], + "yLabel": "", + "xLabel": "", + } + + return TreatmentTimelineService(summary=summary, tabular=tabular) + + +@pytest.mark.asyncio +async def test_primary_happy_path_explicit_timing(): + """5 treatments across 3 subjects with explicit numericValue — + items returned in first-seen subject order, all timing + explicit so temporal_source='explicit'. + """ + rows = [ + { + "subjectDocumentIdentifier": "subj_A", + "treatmentName": "Saline", + "numericValue": [0.0, 10.0], + }, + { + "subjectDocumentIdentifier": "subj_A", + "treatmentName": "CNO", + "numericValue": [10.0, 20.0], + }, + { + "subjectDocumentIdentifier": "subj_B", + "treatmentName": "Saline", + "numericValue": [0.0, 15.0], + }, + { + "subjectDocumentIdentifier": "subj_C", + "treatmentName": "Saline", + "numericValue": [0.0, 12.0], + }, + { + "subjectDocumentIdentifier": "subj_C", + "treatmentName": "CNO", + "numericValue": [12.0, 24.0], + }, + ] + svc = _make_service( + primary_response={ + "columns": [{"key": "treatmentName"}, {"key": "subjectDocumentIdentifier"}], + "rows": rows, + }, + ) + result = await svc.compute_timeline( + "ds_xyz", title="My Timeline", max_subjects=30, session=None, + ) + assert result["total_subjects"] == 3 + assert result["total_treatments"] == 5 + assert result["temporal_source"] == "explicit" + assert result["title"] == "My Timeline" + assert result["datasetId"] == "ds_xyz" + # First-seen ordering of subjects: A, B, C. + subjects = [item["subject"] for item in result["items"]] + assert subjects == ["subj_A", "subj_A", "subj_B", "subj_C", "subj_C"] + # Timing is the literal numericValue pair. + assert result["items"][0]["start"] == 0.0 + assert result["items"][0]["end"] == 10.0 + # No empty_hint when items are produced. + assert "empty_hint" not in result + + +@pytest.mark.asyncio +async def test_ordinal_timing_when_numeric_value_missing(): + """Rows without any explicit timing get per-subject ordinal slots + [0,1], [1,2], etc. temporal_source='ordinal'. + """ + rows = [ + {"subjectDocumentIdentifier": "S1", "treatmentName": "T1"}, + {"subjectDocumentIdentifier": "S1", "treatmentName": "T2"}, + {"subjectDocumentIdentifier": "S1", "treatmentName": "T3"}, + {"subjectDocumentIdentifier": "S2", "treatmentName": "T1"}, + {"subjectDocumentIdentifier": "S2", "treatmentName": "T2"}, + ] + svc = _make_service(primary_response={"columns": [], "rows": rows}) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["temporal_source"] == "ordinal" + assert result["total_subjects"] == 2 + assert result["total_treatments"] == 5 + items = result["items"] + # S1's three treatments: [0,1], [1,2], [2,3]. + s1 = [it for it in items if it["subject"] == "S1"] + assert [it["start"] for it in s1] == [0, 1, 2] + assert [it["end"] for it in s1] == [1, 2, 3] + # S2's two treatments: [0,1], [1,2]. Per-subject counter resets. + s2 = [it for it in items if it["subject"] == "S2"] + assert [it["start"] for it in s2] == [0, 1] + assert [it["end"] for it in s2] == [1, 2] + + +@pytest.mark.asyncio +async def test_mixed_timing_classification(): + """Some rows explicit, some ordinal → temporal_source='mixed'.""" + rows = [ + # Explicit timing. + { + "subjectDocumentIdentifier": "S1", + "treatmentName": "T1", + "numericValue": [0.0, 5.0], + }, + # Same subject, ordinal — counter is independent of the + # explicit row's range. + {"subjectDocumentIdentifier": "S1", "treatmentName": "T2"}, + # Different subject, also explicit. + { + "subjectDocumentIdentifier": "S2", + "treatmentName": "T1", + "numericValue": [0.0, 10.0], + }, + ] + svc = _make_service(primary_response={"columns": [], "rows": rows}) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["temporal_source"] == "mixed" + # 2 explicit + 1 ordinal = 3 total items. + assert result["total_treatments"] == 3 + + +@pytest.mark.asyncio +async def test_max_subjects_cap_drops_excess(): + """50 distinct subjects with maxSubjects=30 → only 30 surface in + ``items``. total_subjects reflects the in-chart count (30), not the + underlying count — that's the TS handler's contract: the chart + truncates and the caller surfaces the truncation count via the + ``cited`` vs ``total_subjects`` ratio at the chat-prompt layer. + """ + rows = [ + {"subjectDocumentIdentifier": f"subj_{i}", "treatmentName": "Saline"} + for i in range(50) + ] + svc = _make_service(primary_response={"columns": [], "rows": rows}) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["total_subjects"] == 30 + assert result["total_treatments"] == 30 + # Distinct subjects in items. + distinct = {it["subject"] for it in result["items"]} + assert len(distinct) == 30 + + +@pytest.mark.asyncio +async def test_primary_empty_fallback_hits_synthesizes_group_rows(): + """Zero treatment rows; tabular_query has 2 groups. Synthesize one + bar per group with subject='group:' and ordinal timing. + """ + fallback = { + "groups": [ + {"name": "Saline", "count": 12}, + {"name": "CNO", "count": 9}, + ], + "yLabel": "Treatment: CNO or Saline Administration", + "xLabel": "", + } + svc = _make_service( + primary_response={"columns": [], "rows": []}, + fallback_response=fallback, + ) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["total_subjects"] == 2 + assert result["total_treatments"] == 2 + # Subject labels prefixed with group: so callers can disambiguate + # synthesized vs real subject rows. + subjects = sorted(it["subject"] for it in result["items"]) + assert subjects == ["group:CNO", "group:Saline"] + treatments = sorted(it["treatment"] for it in result["items"]) + assert treatments == ["CNO", "Saline"] + # All synthesized → temporal_source='ordinal'. + assert result["temporal_source"] == "ordinal" + + +@pytest.mark.asyncio +async def test_primary_empty_fallback_empty_surfaces_empty_hint(): + """Both backends return nothing — empty_hint surfaced with the + 'no temporal info' reason and available_columns from whatever + column list the primary did expose. + """ + svc = _make_service( + primary_response={ + "columns": [ + {"key": "treatmentName"}, + {"key": "subjectDocumentIdentifier"}, + ], + "rows": [], + }, + fallback_response={"groups": [], "yLabel": "", "xLabel": ""}, + ) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["items"] == [] + assert result["total_subjects"] == 0 + assert result["total_treatments"] == 0 + assert "empty_hint" in result + hint = result["empty_hint"] + assert "no temporal info" in hint["reason"] + # available_columns echoes the column keys the primary table + # exposed even though the row list was empty — gives the caller + # something to mention. + assert "treatmentName" in hint["available_columns"] + assert "subjectDocumentIdentifier" in hint["available_columns"] + + +# --------------------------------------------------------------------------- +# Defensive edge cases — behavior not strictly required by the brief but +# locked here to prevent regressions when the orchestrator evolves. +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rows_without_subject_or_treatment_dropped(): + """Rows missing subject OR treatment are silently skipped — they + can't be plotted. No empty_hint when at least one row plots. + """ + rows = [ + # Plottable. + {"subjectDocumentIdentifier": "S1", "treatmentName": "T1"}, + # Missing subject. + {"treatmentName": "T2"}, + # Missing treatment. + {"subjectDocumentIdentifier": "S2"}, + ] + svc = _make_service(primary_response={"columns": [], "rows": rows}) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["total_treatments"] == 1 + assert "empty_hint" not in result + + +@pytest.mark.asyncio +async def test_rows_returned_but_unplottable_surfaces_hint(): + """When rows come back but NONE have a usable subject+treatment + pair, the hint reason distinguishes that from the empty-rows case. + """ + rows = [ + {"treatmentName": "T1"}, # No subject. + {"subjectDocumentIdentifier": "S1"}, # No treatment. + ] + svc = _make_service(primary_response={"columns": [], "rows": rows}) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["items"] == [] + assert "empty_hint" in result + assert "none had a usable" in result["empty_hint"]["reason"] + + +@pytest.mark.asyncio +async def test_primary_failure_falls_through_to_fallback(): + """If the primary call raises (cloud unreachable, etc.), the + service catches and tries the fallback — does NOT propagate the + error out of compute_timeline. + """ + svc = _make_service( + primary_raises=RuntimeError("cloud unreachable"), + fallback_response={ + "groups": [{"name": "Saline"}], + "yLabel": "Treatment", + "xLabel": "", + }, + ) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + # Fallback produced a row; no error surfaced. + assert result["total_treatments"] == 1 + assert result["items"][0]["subject"] == "group:Saline" + + +@pytest.mark.asyncio +async def test_both_failures_surface_empty_hint_not_exception(): + """Catastrophic — both backends raise. The endpoint still returns + a well-typed response with empty_hint set. + """ + svc = _make_service( + primary_raises=RuntimeError("primary down"), + fallback_raises=RuntimeError("fallback down"), + ) + result = await svc.compute_timeline( + "ds_a", title=None, max_subjects=30, session=None, + ) + assert result["items"] == [] + assert "empty_hint" in result + + +@pytest.mark.asyncio +async def test_default_max_subjects_constant_used_by_router(): + """The router uses DEFAULT_MAX_SUBJECTS as the model default; lock + the constant value so a silent bump in the service doesn't change + the public contract. + """ + assert DEFAULT_MAX_SUBJECTS == 30 + + +# --------------------------------------------------------------------------- +# B3 (2026-05-18) — MATLAB ``datestr`` parsing for Haley-style literal +# treatment docs whose ``string_value`` is "DD-MMM-YYYY HH:MM:SS". +# +# Before this fix every Haley row fell through to ordinal timing +# (parse failure → None → ordinal slot). The chart still rendered, +# but the x-axis was synthetic ordinals 0..N instead of the real +# food-restriction onset/offset wall times. +# --------------------------------------------------------------------------- + + +class TestParseMatlabDatestr: + def test_dd_mmm_yyyy_hh_mm_ss(self): + # Haley's literal treatment shape — onset/offset times serialized + # by MATLAB's default ``datestr``. + out = _parse_iso_datetime("03-Nov-2023 07:53:00") + assert out is not None + assert out.year == 2023 + assert out.month == 11 + assert out.day == 3 + assert out.hour == 7 + assert out.minute == 53 + + def test_dd_mmm_yyyy_date_only(self): + # Date-only courtesy variant. + out = _parse_iso_datetime("14-Nov-2023") + assert out is not None + assert out.year == 2023 + assert out.month == 11 + assert out.day == 14 + + def test_iso_still_parses_after_matlab_branch_added(self): + # Regression: adding the MATLAB branch must not break the ISO + # path. `fromisoformat` still wins for ISO-shaped inputs. + out = _parse_iso_datetime("2026-05-14T12:00:00Z") + assert out is not None + assert out.hour == 12 + assert out.tzinfo is not None # Z normalized to +00:00 + + def test_garbage_still_returns_none(self): + # Neither ISO nor MATLAB matches → None, caller falls back to ordinal. + assert _parse_iso_datetime("not a date") is None + assert _parse_iso_datetime("32-Jan-2023") is None # impossible day + + def test_matlab_string_via_string_value_path(self): + # End-to-end: a row with a MATLAB datestr `stringValue` now + # produces explicit timing rather than None (which previously + # forced the caller to fall back to ordinal). + out = _extract_explicit_timing( + {"stringValue": "03-Nov-2023 07:53:00"}, + ) + assert out is not None + start, end = out + # Contract: start returned verbatim (matches the ISO behaviour). + assert start == "03-Nov-2023 07:53:00" + # End is the +1 day ISO string. + assert isinstance(end, str) and end.startswith("2023-11-04T07:53:00") + + +# --------------------------------------------------------------------------- +# B3 — `_fetch_primary_rows` covers literal-only datasets. +# +# Repro: Haley (`682e7772…`) publishes 56 LITERAL `treatment` docs +# and ZERO `treatment_drug` / `treatment_transfer` docs. The chain +# walker MUST surface the 56 literal rows in the merged result rather +# than silently masking them when the two subclass probes come back +# empty. This complements the existing fixture which exercises the +# inverse case (subclasses-only on Bhar). +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_primary_rows_literal_only_haley_shape(): + """Haley shape: 56 literal `treatment` rows + 0 subclass rows in + the chain → merged result is the 56 literal rows, contributing + classes is exactly ['treatment'].""" + haley_treatment_rows = [ + { + "treatmentName": "treatment: food restriction onset time", + "treatmentOntology": "EMPTY:0000202", + "numericValue": [], + "stringValue": "03-Nov-2023 07:53:00", + "subjectDocumentIdentifier": f"subj-{i:03d}", + } + for i in range(56) + ] + summary = AsyncMock() + + async def _dispatch(_dataset_id: str, class_name: str, **_kwargs: Any) -> dict[str, Any]: + if class_name == "treatment": + return { + "columns": [ + {"key": "treatmentName"}, + {"key": "stringValue"}, + {"key": "subjectDocumentIdentifier"}, + ], + "rows": haley_treatment_rows, + } + # treatment_drug + treatment_transfer come back EMPTY for Haley. + return {"columns": [], "rows": []} + + summary.single_class.side_effect = _dispatch + tabular = AsyncMock() + svc = TreatmentTimelineService(summary=summary, tabular=tabular) + + result = await svc.compute_timeline( + "682e7772cdf3f24938176fac", + title=None, + max_subjects=30, + session=None, + ) + + # All 56 rows surface; capped at 30 distinct subjects (each row has + # a unique subject in this fixture). + assert result["total_subjects"] == 30 + assert result["total_treatments"] == 30 # one item per subject up to cap + # Now-explicit timing thanks to MATLAB datestr parsing — temporal_source + # is `"explicit"` rather than `"ordinal"` for Haley going forward. + assert result["temporal_source"] == "explicit" + assert "empty_hint" not in result + + +@pytest.mark.asyncio +async def test_primary_rows_merges_literal_and_subclass_rows(): + """If a dataset emits BOTH literal `treatment` AND `treatment_drug` + rows, the merged result must include both — F-1e's merge-all + semantics (no first-non-empty-wins short-circuit).""" + summary = AsyncMock() + + async def _dispatch(_dataset_id: str, class_name: str, **_kwargs: Any) -> dict[str, Any]: + if class_name == "treatment": + return { + "columns": [{"key": "treatmentName"}], + "rows": [ + { + "treatmentName": "Saline", + "subjectDocumentIdentifier": "subj-001", + }, + ], + } + if class_name == "treatment_drug": + return { + "columns": [ + {"key": "treatmentName"}, + {"key": "numericValue"}, + ], + "rows": [ + { + "treatmentName": "CNO", + "numericValue": [10.0, 20.0], + "subjectDocumentIdentifier": "subj-002", + }, + ], + } + return {"columns": [], "rows": []} + + summary.single_class.side_effect = _dispatch + tabular = AsyncMock() + svc = TreatmentTimelineService(summary=summary, tabular=tabular) + + result = await svc.compute_timeline( + "merged-dataset", + title=None, + max_subjects=30, + session=None, + ) + + assert result["total_subjects"] == 2 + assert result["total_treatments"] == 2 + # Mixed: ordinal subj-001 (literal, no timing) + explicit subj-002 (numeric pair). + assert result["temporal_source"] == "mixed" + + +# --------------------------------------------------------------------------- +# B3 — `_row_treatment` with literal-class doc shape (sanity-check the +# canonical legacy branch isn't accidentally regressed by future +# subclass-handling edits in summary_table_service). +# --------------------------------------------------------------------------- + + +def test_row_treatment_literal_haley_shape(): + """Direct projection check against the exact Haley doc shape (curl'd + from the experimental backend 2026-05-18). Critically, when + ``treatment.numeric_value`` is an empty list the projection MUST + surface that as ``numericValue: []`` (preserved, not dropped) so the + downstream Gantt projection's empty-list branch handles it; and + when ``treatment.string_value`` is a MATLAB datestr the row MUST + preserve the verbatim string so `_extract_explicit_timing` can + parse it.""" + from backend.services.summary_table_service import _row_treatment + + haley_doc: dict[str, Any] = { + "id": "68c0558ef81ed200dc9a1c14", + "ndiId": "41269430c7715eb4_c0dc0fdbb2efc9a6", + "name": "", + "className": "treatment", + "datasetId": "682e7772cdf3f24938176fac", + "data": { + "base": { + "id": "41269430c7715eb4_c0dc0fdbb2efc9a6", + "name": "", + }, + "depends_on": [ + { + "name": "subject_id", + "value": "41269430c748e734_c0de47463bc5a7a4", + }, + ], + "treatment": { + "ontologyName": "EMPTY:0000202", + "name": "treatment: food restriction onset time", + "numeric_value": [], + "string_value": "14-Nov-2023 12:50:00", + }, + }, + } + + out = _row_treatment(haley_doc) + + assert out["treatmentName"] == "treatment: food restriction onset time" + assert out["treatmentOntology"] == "EMPTY:0000202" + assert out["numericValue"] == [] + assert out["stringValue"] == "14-Nov-2023 12:50:00" + assert out["subjectDocumentIdentifier"] == "41269430c748e734_c0de47463bc5a7a4" + + +def test_row_treatment_literal_with_iso_string_value(): + """Cover the ISO-string variant of the same literal `treatment` + branch — ensures the projection isn't accidentally MATLAB-only.""" + from backend.services.summary_table_service import _row_treatment + + doc: dict[str, Any] = { + "data": { + "depends_on": [{"name": "subject_id", "value": "subj-iso"}], + "treatment": { + "name": "ISO-flavored treatment", + "ontologyName": "FAKE:001", + "numeric_value": 42.0, + "string_value": "2024-03-15T12:00:00Z", + }, + }, + } + + out = _row_treatment(doc) + assert out["treatmentName"] == "ISO-flavored treatment" + assert out["numericValue"] == 42.0 + assert out["stringValue"] == "2024-03-15T12:00:00Z" + assert out["subjectDocumentIdentifier"] == "subj-iso" diff --git a/docs/plans/2026-05-13-ndi-python-integration.md b/docs/plans/2026-05-13-ndi-python-integration.md new file mode 100644 index 0000000..cc2fe48 --- /dev/null +++ b/docs/plans/2026-05-13-ndi-python-integration.md @@ -0,0 +1,340 @@ +# NDI-python integration plan — backend signal/edit layer + +**Status:** Draft for user review. **No backend code has been written yet.** +**Audience:** Audri. Companion to `ndi-cloud-app/apps/web/docs/specs/2026-05-13-ask-checkpoint-pre-compact.md`. +**Author:** Claude Sonnet 4.5 (1M context), 2026-05-13. + +## TL;DR + +Three phases, escalating risk. All work happens on `ndi-data-browser-v2` (the Railway FastAPI). The `feat/experimental-ask-chat` branch in `ndi-cloud-app` is **NOT** touched — it stays draft + DO NOT MERGE. The benefit chain is bottom-up: every phase that lands on ndb-v2 main flows automatically to (a) the live Document Explorer / QuickPlot, (b) the experimental Ask chat preview, and (c) the upcoming Data Browser product. + +1. **Phase A — vlt-only install (~1 day, low risk).** Add `vhlab-toolbox-python` to the Railway image (which pulls `vlt`), `apt-get install -y git` so pip can fetch the git-sourced source, and ~10 LOC in `binary_service.py` to call `vlt.file` instead of returning the `"vlt library not available"` soft error. **Unlocks Haley VHSB position-trace plotting — immediately benefits the live Document Explorer's QuickPlot for every VHSB dataset, and unblocks the Ask chat's chart prompt for Haley.** No architectural change; existing routes and tests untouched. +2. **Phase B — replace inline parsers with `database_openbinarydoc` (~1 week, medium-high risk).** Install full NDI-python (which pulls `did`, `ndr`, `vhlab-toolbox-python`, `ndi-compress`). Add a startup/cron job to call `ndi.cloud.orchestration.downloadDataset(dataset_id, /data/ndi/{id})` against a Railway persistent volume. Refactor `BinaryService.get_timeseries` to call `dataset.database_openbinarydoc(doc, filename)`. Feature-flag the swap for an A/B week; rollback = flag flip. +3. **Phase C — new rich endpoints (~1-2 weeks, low risk because additive).** New routes that *only* exist because we have NDI-python: `POST /api/datasets/:id/ndiquery` (Mongo-style structured queries via `ndi.query.Query`), `POST /api/datasets/:id/documents/:docId/edit` (auth-gated, foundation for the Data Browser product), and `GET /api/datasets/:id/elements/:elementId/native` (`ndi.element`-backed). Existing routes unchanged. + +**Phase A is the only phase that should ship before Audri reviews this spec.** Phases B & C need design buy-in on the cache + volume strategy (and Phase C scope). + +## Pre-flight state (2026-05-13, ~21:20 UTC) + +- PR #111 merged to ndb-v2 main (commit `c5b02884`) — Railway auto-deploying the `?file=` param fix +- Ask RAG index re-baked with `binarySignalExample` sidecar (staging v3 → production, atomic) +- ndi-cloud-app `feat/experimental-ask-chat` branch (PR #160, draft) is "demo-ready" for the NBF chart path; Haley VHSB still soft-errors +- 8 published NDI Commons datasets: 3 of them are tutorial-having (Bhar, Haley, Dabrowska); Haley is VHSB-formatted + +## Day 0 spike findings (research-only, no code) + +### F1. The chatbots are general-purpose by design — NDI-python integration is exactly the new capability the Ask chat (and the data browser) needs + +`vh-lab-chatbot` and `shrek-lab-chatbot` **don't import NDI-python.** They're general-purpose lab-document RAG systems (PDF / HTML / xlsx → pgvector + Voyage + Claude). That's deliberate. The Ask chat in `ndi-cloud-app` is the first product that *will* use NDI-python — both for richer chatbot answers (plotting, provenance walks, structured queries) AND to power richer data-browser interactions (public QuickPlot expansion, private dataset editing). The canonical "how to use NDI-python with cloud datasets" reference therefore comes from NDI-python's own surface: **`src/ndi/cloud/` package + `tests/test_cloud_*.py`** suite + the published tutorials. + +### F2. Cloud connectivity = `ndi.cloud.orchestration.downloadDataset` + +There is **no `ndi.cloud.Dataset(dataset_id)` lazy constructor**. The entry point is: + +```python +from ndi.cloud.orchestration import downloadDataset +from ndi.cloud.client import CloudClient + +dataset = downloadDataset( + cloud_dataset_id="682e7772cdf3f24938176fac", # Haley + target_folder="/data/ndi/682e7772cdf3f24938176fac", + sync_files=False, # binaries lazy + client=CloudClient.from_env(), +) +``` + +It performs (per `NDI-python/src/ndi/cloud/orchestration.py:23-186`): + +1. Eagerly downloads ALL JSON documents from Mongo (chunked bulk ZIPs) — **multi-minute for real datasets** (Haley = 78K docs, ~16 GB; Carbon-fiber test = 743 docs, ~9.7 GB) +2. Rewrites each binary-file `location` in document properties to an `ndic://{dataset_id}/{file_uid}` URI +3. Materializes a local `ndi_dataset_dir` under `target_folder/{cloud_dataset_id}` +4. Stashes the authenticated `CloudClient` on the returned object as `dataset.cloud_client` + +After that, binary files materialize **lazily** on the first `database_openbinarydoc(doc, filename)` call via presigned S3 URLs (`session/session_base.py:553-628` → `cloud/filehandler.py:121-177`). No `boto3` — direct `requests.get(url, stream=True)`. + +**Implication:** `downloadDataset` cannot be a per-request operation. It has to run at startup or on a cron, against a persistent volume. + +### F3. `vlt` is provided by `vhlab-toolbox-python` + +Verified at `NDI-python/src/ndi/check.py:72-74`: + +```python +# vhlab-toolbox-python +ok, detail = _try_import("vlt") +check("vhlab-toolbox-python (vlt)", ok, detail) +``` + +The git-pin is `vhlab-toolbox-python @ git+https://github.com/VH-Lab/vhlab-toolbox-python.git@main` (from `NDI-python/pyproject.toml:39`). Installing it alone gives us `vlt` without pulling all of NDI-python. + +### F4. NDI-python git-sourced deps need `git` in the Docker image + +NDI-python's four git-sourced pip deps (`did`, `ndr`, `vhlab-toolbox-python`, `ndi-compress`) need `git` available at install time. Current `infra/Dockerfile` does NOT install git (only `libjpeg62-turbo libtiff6 curl`). One `apt-get install` line fixes it. + +### F5. Current `binary_service.py` has an early-return for text-VHSB + +`backend/services/binary_service.py:164-184` (post-PR-#111): + +```python +head = payload[:5] if len(payload) >= 5 else b"" +if head.startswith(b"This "): + return _timeseries_error( + "vlt_library", + "vlt library is not available on this server — full VHSB " + "decoding requires the DID-python `vlt` extension. ...", + ) +try: + if name.endswith(".vhsb") or (payload[:4] == b"VHSB"): + return _parse_vhsb(payload) + return _parse_nbf(payload) +except Exception as e: + ... +``` + +So the current code: +- **Handles binary-magic VHSB** (`b"VHSB"` prefix, 24-byte header, float32 body) via the inline `_parse_vhsb` +- **Bails on text-header VHSB** (`This is a VHSB file, http://github.com/VH-Lab\n...`) with a soft error — this is the variant vlt handles + +**Critical:** Audri's "Phase A is just `pip install vlt` with zero code changes" assumption is **off by one short edit**. The current code never tries to import vlt — the soft error is an early return on payload prefix. We'd need ~10 LOC to actually call `vlt.file` after installing vhlab-toolbox-python. + +### F6. ndb-v2 already has numpy + scipy + +`backend/requirements.txt` already pins `numpy>=2.0.0` and `scipy>=1.14.0`. These are the heavy NDI-python deps. Image growth from adding vhlab-toolbox-python alone is modest (~10-20 MB). Full NDI-python adds ~80-150 MB (did, ndr, ndi-compress + their numpy/networkx/jsonschema/openMINDS overlap with what's already there). + +### F7. ndb-v2's ADR-009 bans `httpx`/`requests`/`aiohttp` in `services/` + +NDI-python uses `requests` internally. The ADR-009 ban (per `backend/pyproject.toml:90-94`) is **path-scoped** — it forbids importing these libs *inside* `backend/services/`. NDI-python's own use of `requests` is fine (it's a sub-package import, not a direct service import). But if we wrap NDI-python in a `backend/services/ndi_python_service.py`, that wrapper can't directly `import requests` — only NDI-python can. Per-file carve-outs are possible if needed; this is a containable lint problem. + +--- + +## Phase A — vlt-only install (the "free win" with one small caveat) + +**Goal:** Unblock text-header VHSB decoding so Haley's position traces become plottable in the Document Explorer and the Ask chat. Everything else stays exactly as it is today. + +**Scope:** the smallest possible change. + +### A.1 Files to modify + +| File | Change | LOC | Why | +|---|---|---|---| +| `infra/Dockerfile` | `apt-get install -y git` in the Stage 2 system-deps line | +1 | Required for pip to fetch the git-sourced `vhlab-toolbox-python` | +| `backend/requirements.txt` | Add `vhlab-toolbox-python @ git+https://github.com/VH-Lab/vhlab-toolbox-python.git@main` | +1 | Brings in `vlt` | +| `backend/pyproject.toml` | Same addition to `dependencies` | +1 | Mirror — pyproject is the source of truth for dev installs | +| `backend/services/binary_service.py` | Replace lines 164-171 (the soft-error early return) with a vlt call | ~10-15 | Actually use vlt to decode text-VHSB | +| `backend/tests/unit/test_binary_shape.py` | Add a text-VHSB fixture + decode test | +30-50 | Regression coverage | + +### A.2 Concrete `binary_service.py` change + +Current: +```python +head = payload[:5] if len(payload) >= 5 else b"" +if head.startswith(b"This "): + return _timeseries_error( + "vlt_library", + "vlt library is not available on this server — ..." + ) +try: + if name.endswith(".vhsb") or (payload[:4] == b"VHSB"): + return _parse_vhsb(payload) + return _parse_nbf(payload) +``` + +Proposed: +```python +head = payload[:5] if len(payload) >= 5 else b"" +if head.startswith(b"This "): + # Text-header VHSB ("This is a VHSB file, http://github.com/VH-Lab\n…") + # — DID-python's vlt extension parses the typed binary slots that + # follow the text header. Lazy import so a missing vlt downgrades + # cleanly rather than blowing up the worker. + try: + from vlt.file import vhsb_read # type: ignore + except ImportError: + return _timeseries_error( + "vlt_library", + "vlt library is not available — install vhlab-toolbox-python.", + ) + try: + return _from_vlt_vhsb(vhsb_read(io.BytesIO(payload))) + except Exception as e: + log.warning("binary.vlt_decode_failed", error=str(e)) + return _timeseries_error("decode", f"vlt VHSB decode failed: {e}") +try: + if name.endswith(".vhsb") or (payload[:4] == b"VHSB"): + return _parse_vhsb(payload) + return _parse_nbf(payload) +``` + +Plus a small private helper `_from_vlt_vhsb()` that converts vlt's output (likely a numpy array + sample-rate + channel-name list) into the existing `{channels, timestamps, sample_count, format, error}` envelope. Exact shape needs verification against vlt's actual API — Phase A *first action* is to read `vlt/file.py` upstream. + +### A.3 Test plan + +- **Unit**: synthesize a minimal text-header VHSB payload (or pull one from the Haley dataset by hand) and feed it through `BinaryService.get_timeseries` against a mocked cloud-download. Assert non-empty `channels`, correct `format == "vhsb"`, sane `sample_count`. +- **Integration**: extend `backend/tests/integration/test_routes.py` with a route test for `/api/datasets/.../documents/.../signal` against a Haley doc — requires either a recorded fixture or live cloud creds in CI. Recorded fixture is preferred (faster + no creds in CI). +- **Smoke (manual, post-deploy)**: hit `GET /api/datasets/682e7772cdf3f24938176fac/documents//signal` against the deployed Railway URL and confirm a JSON response with non-empty channels. +- **Backward-compat**: the NBF + binary-VHSB paths are untouched. The existing 56 binary-service tests must still pass. + +### A.4 Risk + rollback + +| Concern | Mitigation | +|---|---| +| Image grows by ~10-20 MB | Acceptable. Heavy deps (numpy, scipy) already in. | +| `git` in image adds ~30 MB | Acceptable. One-time cost. | +| vlt's API may not match our envelope | Phase A's *first* concrete action is to read `vlt/file.py` upstream and write the helper. If the API doesn't fit, we adapt or abort Phase A; no production impact. | +| New Dockerfile layer cache miss | First Railway build will be slow (~3-5 min). Subsequent builds re-use the apt layer. | +| Text-header VHSB variant has multiple sub-formats | The vlt library handles them all (that's its job). If we discover a sub-format vlt doesn't handle, we fall through to the existing `_parse_vhsb` (binary magic path). | + +**Rollback**: `git revert` the merge commit. The change is isolated to one branch / one merge; nothing else depends on it. + +### A.5 Pre-flight verification needed + +Before writing Phase A code: + +1. **Read `vhlab-toolbox-python/src/vlt/file.py` upstream** (GitHub) and confirm the public API surface — exact function names, return shapes. +2. **Confirm Railway's `pip install` step has internet access to github.com**. (It almost certainly does, since redis/etc come from PyPI which is GitHub-backed by some mirrors, but `git+https://github.com/...` is a different code path.) +3. **Pick one Haley VHSB doc** as the smoke-test target and note its docId + filename. + +### A.6 Estimated wall-clock + +- 2-3 hours: read vlt's API, write the binary_service change, write the unit test +- 1 hour: smoke-test locally against a saved Haley payload (or via `httpx` against the live cloud) +- 30 min: open PR, wait for CI +- 30 min: merge, wait for Railway deploy, smoke against live URL + +**Total: ~half a day.** Lower bound assumes vlt's API is well-documented and matches the shape we need. + +--- + +## Phase B — full NDI-python (`database_openbinarydoc` swap) + +**Goal:** Replace the two inline binary parsers (`_parse_nbf`, `_parse_vhsb` + the new vlt path) with a single canonical call: `dataset.database_openbinarydoc(doc, filename) → file_handle`. One source of truth for binary parsing, native multi-file selection (eliminates the `?file=` workaround entirely), and forward compatibility with any new binary formats NDI adds upstream. + +### B.1 The cache + volume design question (THE main thing to decide) + +`downloadDataset` is **not per-request**. It eagerly fetches Mongo metadata for a whole dataset (minutes for big ones). Three workable patterns: + +**Option B-1: Persistent volume + warm cache on startup** +- Mount a Railway persistent volume at `/data/ndi/` +- On worker startup, for each of the 8 published datasets, run `downloadDataset(id, /data/ndi/{id})` +- Cache survives across deploys (volume is persistent) +- First boot is slow (potentially 30-60 min for big datasets); subsequent boots are fast (already-cached metadata) +- **Multi-replica caveat**: if Railway scales to N workers, each one re-downloads to its own volume slice. Shared-volume solutions need RWX (NFS-class). Otherwise: download once via a separate one-shot job / cron, share via S3-backed `mount`. + +**Option B-2: Lazy + LRU** +- No startup work. First request for dataset X triggers `downloadDataset(X)` and the response waits. +- Sub-Pattern: a separate background job warms the top-K most-queried datasets while the worker is otherwise idle. +- Eviction: LRU on disk usage; when over budget, delete oldest dataset's `/data/ndi/{id}` dir. +- **Failure mode**: cold first-request latency is intolerable for chat UX (10-30 min). Mitigated by warming. + +**Option B-3: Hybrid — startup-warm the demo datasets only** +- Audri has 8 published datasets. Pre-warm the 3 tutorial-having ones (Bhar, Haley, Dabrowska) on startup. +- For the other 5, fall through to Option B-2 (lazy + LRU). +- **Best risk/reward** for the demo era — known-good warm path for the demo prompts, fallback for everything else. + +**My recommendation: Option B-3 with a `NDI_PREWARM_IDS` env var listing the dataset IDs to fetch on startup.** Cheap to implement; doesn't paint us into a corner. + +### B.2 Files to modify (rough) + +- `infra/Dockerfile`: install full NDI-python + add `/data` volume directive (the volume itself is configured in Railway, the Dockerfile just creates the mount-point dir) +- `infra/railway.toml`: declare the persistent volume +- `backend/requirements.txt` + `pyproject.toml`: add `ndi @ git+...` +- `backend/services/ndi_python_service.py` (NEW): wraps `ndi.cloud.orchestration.downloadDataset` + manages the in-memory `{dataset_id: ndi_dataset_dir}` cache +- `backend/services/binary_service.py`: refactor `get_timeseries` to call `ndi_python_service.open_binary(dataset_id, doc, filename)` behind a feature flag +- `backend/app.py`: startup hook that pre-warms `NDI_PREWARM_IDS` datasets in a background task +- `backend/auth/ndi_cloud.py` (NEW or extension of existing): manage the NDI Cloud JWT lifecycle (currently the FastAPI is using its own session auth; NDI-python needs `NDI_CLOUD_USERNAME` + `NDI_CLOUD_PASSWORD` env vars) +- Tests: characterization test that compares old-vs-new outputs for a known set of NBF + VHSB docs + +### B.3 Feature flag + rollback plan + +- Add `NDI_PYTHON_BINARY=on|off` env var (default `off` initially) +- Branch `get_timeseries`: + - `off`: keep the existing inline parser path (today's code) + - `on`: route to `ndi_python_service.open_binary` +- A/B for one week. Track: + - Latency P50/P95 for `/data/timeseries` + - Response-shape diff rate (should be 0) + - Error rate +- Rollback: flip flag back to `off`. Worst case `git revert` the merge. + +### B.4 Open questions + +1. **Multi-replica strategy**: Railway's persistent volume model — is RWX supported? How does it interact with `WEB_CONCURRENCY=4`? Currently each uvicorn worker is in the same container, so they share the volume trivially. If Railway autoscales to N containers, that breaks. +2. **NDI Cloud auth lifetime**: JWT exp is ~1h per `NDI-python/src/ndi/cloud/auth.py`. We need a refresh strategy (probably refresh-on-401 via the username/password fallback path). +3. **Image build time**: full NDI-python install with 4 git-sourced deps will lengthen CI build time. Cacheable via Docker layer ordering but worth measuring. +4. **Test creds in CI**: `NDI-python/tests/test_cloud_*.py` skip when `NDI_CLOUD_USERNAME` / `PASSWORD` aren't set. Should our own integration tests require live creds, or use a recorded fixture? +5. **AWS Lambda gateway flakiness**: `test_cloud_live.py:42-68` notes the cloud API returns frequent 504s. Need retry + backoff in `ndi_python_service`. + +### B.5 What this unlocks + +- Eliminates the `?file=` param workaround entirely (`database_openbinarydoc` takes the filename natively) +- Supports any future binary format NDI adds (we inherit decoders for free) +- The QuickPlot in the Document Explorer now reads the same upgraded outputs (same `{channels, timestamps, ...}` envelope) → public data-browser users see VHSB decoded too, without any frontend change +- The same `ndi.dataset.Dataset` handle becomes available to the upcoming **private data browser** — same Python API researchers use locally is the cloud read/edit surface +- Lays the groundwork for Phase C + +--- + +## Phase C — new rich endpoints + +**Goal:** Add capabilities the existing REST passthrough can't provide. Purely additive; existing routes stay byte-identical. + +### C.1 Proposed endpoints + +- **`POST /api/datasets/:id/ndiquery`** — accepts an `ndi.query.Query`-style JSON filter. Powers the killer cross-dataset chatbot question in Ask ("compare patch-clamp in V1 across mouse and rat datasets") AND surfaces in the public data browser as a richer query builder than today's class-table filter. Backed by `dataset.database_search(q)`. +- **`POST /api/datasets/:id/documents/:docId/edit`** (auth-gated) — uses `Dataset.database_add` / `_remove` for validated document edits. Foundation for the upcoming **private Data Browser** product where logged-in users can edit their own datasets through a UI. Reuses NDI's schema validation + provenance machinery — we don't reimplement either. +- **`GET /api/datasets/:id/elements/:elementId/native`** — wraps `ndi.element` for richer single-element responses (epoch lists with native typing, probe definitions, etc.). Used by Ask chat + public data browser's element detail view. + +### C.2 Risk + +Low. Each is a new route. If buggy, only Ask chat (which is opt-in feature-flagged on the frontend) and the upcoming Data Browser product (which isn't shipped yet) are affected. Public Document Explorer + catalog APIs untouched. + +### C.3 Out of scope for this spec + +The detailed contracts (request/response shapes, error mapping, rate limits, auth gating) deserve their own spec when we get to them. Phase A + B groundwork has to land first. + +--- + +## Concerns + mitigations (matrix) + +| Concern | Phase | Mitigation | +|---|---|---| +| Docker image size grows ~150-200 MB | B | Worth it. Phase A is just ~10-20 MB. | +| Cold-start adds ~500ms for the ndi import | B | Lazy import (existing pattern in `binary_service.py`). | +| NDI-python version drift | B/C | Pin `ndi==X.Y.Z` once stable; track upstream PRs. | +| Cloud-dataset volume strategy unknown | B | THIS spec's main open design decision. My recommendation: Option B-3 (warm 3 demo IDs at startup, lazy for the rest). Audri to confirm. | +| Multi-replica scaling on Railway | B | Need to research Railway's RWX volume support. If unavailable, use a separate one-shot warmer + S3-mounted shared dir. | +| Performance regression on public Document Explorer | B | Feature flag for week-long A/B; rollback is one flag flip. | +| AWS Lambda gateway 504s | B | Retry-with-backoff wrapper around every cloud call. NDI-python's tests already document this pattern. | +| Existing ADR-009 service-layer HTTP ban | B/C | Per-file carve-out for `services/ndi_python_service.py` (precedent: `services/ontology_service.py` already has one). | +| Test creds in CI | B/C | Use recorded fixtures; reserve live-cred tests for nightly or manual. | + +## Recommended sequence + +1. **Now**: Audri reads this spec, signs off on the Phase A code change + the Option B-3 cache strategy (or proposes an alternative). +2. **Phase A (today / tomorrow)**: ~half-day work. New branch on ndb-v2, ~10-15 LOC change in `binary_service.py`, +1 dep, +1 apt line, +1 test. PR → CI → merge → Railway deploys → smoke against Haley. **Done.** +3. **Demo**: re-run the chart prompt against the Vercel preview. With Phase A landed, *both* Dabrowska NBF and Haley VHSB voltage traces render in the chat. **This is the moment the demo gets the second "wow" datapoint.** +4. **Phase B research week**: write a Phase B detailed spec (separate doc) with the volume + auth + multi-replica answers nailed. Audri reviews before any Phase B code. +5. **Phase B implementation week**: feature-flagged refactor, week-long A/B, then flip. +6. **Phase C**: scope each endpoint individually as a separate PR. No rush. + +## Critical file pointers (so the next session can continue) + +- **This spec**: `ndi-data-browser-v2/docs/plans/2026-05-13-ndi-python-integration.md` (you're reading it) +- **Companion checkpoint**: `ndi-cloud-app/apps/web/docs/specs/2026-05-13-ask-checkpoint-pre-compact.md` +- **NDI-python cloud module**: `/Users/audribhowmick/Documents/ndi-projects/NDI-python/src/ndi/cloud/` +- **NDI-python cloud tests** (the canonical "how to use it" examples): `/Users/audribhowmick/Documents/ndi-projects/NDI-python/tests/test_cloud_*.py` +- **vhlab-toolbox-python (vlt)**: `https://github.com/VH-Lab/vhlab-toolbox-python` (not cloned locally yet — need to fetch for Phase A code) +- **Current binary parser**: `ndi-data-browser-v2/backend/services/binary_service.py` (lines 164-184 are the edit target for Phase A) +- **NDI-python tutorials** (real usage patterns): `/Users/audribhowmick/Documents/ndi-projects/NDI-python/tutorials/tutorial_*.py` + +## Open questions for Audri + +1. **Phase A approval?** ~10 LOC + 1 dep + 1 apt line + 1 test. Risk is low. Land it before any further architectural moves? +2. **Volume strategy for Phase B?** Option B-3 (warm 3 demo IDs on startup, lazy for the rest) — agreed, or different preference? +3. **Phase B feature flag** — fine to default `off` for a week of A/B, then flip? +4. **Phase C scope** — same set of three endpoints (`ndiquery`, `edit`, `element/native`), or different priorities? +5. **NDI Cloud test creds** — should our integration tests require live creds, or recorded fixtures only? +6. **Timing** — Phase A this week, Phase B research next week, Phase B implementation week after? Or different cadence? + +--- + +*No production code has been written for any of these phases. This document is a planning artifact only. The Phase A change is small and well-scoped; the Phase B refactor needs more design work; Phase C waits on Phase B.* diff --git a/docs/plans/2026-05-13-railway-experimental-env-runbook.md b/docs/plans/2026-05-13-railway-experimental-env-runbook.md new file mode 100644 index 0000000..91be501 --- /dev/null +++ b/docs/plans/2026-05-13-railway-experimental-env-runbook.md @@ -0,0 +1,141 @@ +# Railway experimental environment — setup runbook + +Companion to `2026-05-13-ndi-python-integration.md`. The Phase A backend changes live on `feat/ndi-python-phase-a`; this doc walks through the **dashboard-only steps** that are required to spin up an "experimental" Railway environment pointing at that branch, so the audit can compare it against production byte-for-byte. + +**Why manual?** Railway's MCP / API doesn't expose environment-creation. Environments are a project-level construct that has to be set up via the dashboard. Once it exists, all subsequent deploys + redeploys CAN be triggered programmatically. + +## Pre-flight + +- [ ] You're logged into Railway on the audrib's-Projects workspace +- [ ] The `feat/ndi-python-phase-a` branch is pushed to GitHub (Claude will commit + push as the last step of the implementation pass) +- [ ] You have ~10 minutes for the dashboard walk-through plus ~5 min for the first build to run + +## Step-by-step + +### 1. Open the project + +Navigate to: + +``` +https://railway.com/project/81a57456-ae9a-47d0-98ef-2b5463f4815b +``` + +You should see the `ndi-data-browser-v2` project with three services: **ndb-v2**, **Postgres**, **Redis**, and the environment dropdown showing **"production"** in the top-left. + +### 2. Create the new environment + +1. Click the environment dropdown (top-left, currently "production") +2. Click **"+ New Environment"** +3. Name it **`experimental`** (lowercase, no spaces) +4. Choose **"Fork from production"** when prompted — this copies the existing services and env vars as starting points (saves us from re-entering NDI_CLOUD_USERNAME etc.). DO NOT pick "Create empty" — that's much more work. +5. Click **Create** + +You should now be inside the new `experimental` environment, with copies of all three services. + +### 3. Point `ndb-v2` at the feature branch + +1. Click into the **`ndb-v2`** service (still in the `experimental` environment) +2. Go to **Settings** → **Service** → **Source** +3. Change the **Branch** from `main` → **`feat/ndi-python-phase-a`** +4. Save / confirm + +Railway will trigger a deploy. **Wait ~3-5 minutes** for the new image (with NDI-python + git deps) to build. The Dockerfile's added `RUN python -c "from vlt.file..."` sanity check will fail the build if anything is missing, so a successful deploy = the import chain works end-to-end. + +### 4. Verify Postgres + Redis are shared (or not) + +The forked environment SHOULD have its own logical instances of Postgres + Redis under the same project umbrella. **Open each service in the experimental env and confirm**: + +- The Postgres service inside `experimental` is a separate instance from production's. If it's NOT (i.e., it's the same `DATABASE_URL`), you have two options: + - **(a) Share — accept the risk**: experimental writes to production's Postgres. Acceptable IF the experimental backend is read-only on Postgres (which the NDI-python paths are — they don't write). + - **(b) Isolate — recommended**: in experimental's Postgres service settings, click **"Create new database"**. This adds a fresh empty Postgres instance for experimental only. +- Same checkbox for Redis. Redis is the cache layer; sharing it is mostly fine (cache poisoning is the only risk; experimental writes the same shape of data as production). + +**My recommendation:** isolate Postgres, share Redis. Cheapest cost, lowest risk. + +### 5. Get the public URL + +1. Inside the `experimental` env's **ndb-v2** service, go to **Settings** → **Networking** +2. Under **Public Networking**, click **"Generate Domain"** (or similar — Railway sometimes auto-assigns) +3. Copy the resulting URL — should look like `ndb-v2-experimental-production.up.railway.app` or `ndb-v2-experimental.up.railway.app` +4. Verify it responds: `curl https:///api/health` should return `{"status":"ok"}` (or similar) + +### 6. Set the cloud-app preview to point at this URL + +This step is on the Vercel side, NOT Railway. Two ways to do it: + +**Option A — Branch-scoped env vars (recommended):** + +1. Go to https://vercel.com/your-team/ndi-cloud-app/settings/environment-variables +2. For each of these vars, **add a new entry** scoped to the **Preview** environment for the **`feat/experimental-ask-chat`** branch: + +``` +UPSTREAM_API_URL=https:// +INTERNAL_API_URL=https:// +``` + +3. Hit **Save** for each +4. Trigger a fresh build of `feat/experimental-ask-chat` (push any commit, or click "Redeploy" in Vercel's Deployments tab) + +**Option B — Just override at deploy time:** + +If you don't want persistent env-var entries, you can pass them inline when triggering a redeploy from Vercel CLI: + +``` +vercel --prod=false env add UPSTREAM_API_URL preview feat/experimental-ask-chat +``` + +Either way, the Vercel preview that comes out the other side should now serve the experimental backend's responses to anonymous public page requests. + +### 7. Smoke-check before running the audit + +Open the Vercel preview URL in an incognito browser: + +- `/datasets` should load with the catalog (8 datasets) +- `/datasets/682e7772cdf3f24938176fac/documents` (Haley) should load +- Pick a Haley binary doc → expand QuickPlot → **should now render the position trace** (previously soft-errored with the vlt_library message — this is the Phase A unblock) + +If any of those fail, check the Railway logs for the `ndb-v2` service in the `experimental` env via: + +``` +gh api /repos/Waltham-Data-Science/ndi-data-browser-v2/actions # (or the railway-agent MCP) +``` + +Or pull logs directly from the dashboard. + +### 8. Tell Claude the audit is ready + +Reply with the two URLs and Claude will run the audit: + +``` +LIVE_URL=https://ndi-cloud.com +EXPERIMENTAL_URL= +``` + +Claude will also need the experimental Railway URL (for the Layer 1 backend-API diff): + +``` +EXPERIMENTAL_API_URL=https:// +``` + +## Expected cost + +For the `experimental` environment with 2 replicas of ndb-v2 + isolated Postgres + shared Redis, while the audit is running: + +- **ndb-v2**: ~$1-3/mo while actively serving traffic, much less idle +- **Postgres (new instance)**: ~$3-5/mo for the smallest tier +- **Redis (shared)**: $0 — already in production + +**Total marginal: ~$5-10/mo while the env exists.** Pro plan's $20 monthly credit absorbs this if you're not already near the ceiling. + +**Tear down after the audit:** once Phase A is decided (either merged to main or rejected), you can delete the `experimental` environment to stop the meter: + +1. Project page → environment dropdown → "experimental" → "Delete Environment" + +The Postgres data + Redis content go with it (the production env is untouched). + +## Rollback / abort + +If at any step you decide not to proceed: + +- **Easiest**: delete the `experimental` environment per above. Zero impact on production. +- **More cautious**: pause the deploy on the experimental ndb-v2 service via Settings → Service → Pause. This stops the meter but preserves the setup for resumption. diff --git a/infra/Dockerfile b/infra/Dockerfile index 9c0d3ad..716482c 100644 --- a/infra/Dockerfile +++ b/infra/Dockerfile @@ -16,6 +16,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ RUN apt-get update && apt-get install -y --no-install-recommends \ libjpeg62-turbo libtiff6 curl \ + git \ && rm -rf /var/lib/apt/lists/* WORKDIR /app @@ -25,8 +26,40 @@ RUN useradd -u 10001 -m ndb RUN mkdir -p /tmp/ndb && chown -R ndb:ndb /tmp/ndb /app COPY backend/requirements.txt ./backend/requirements.txt +# Upgrade pip to >=26.1 before resolving deps — the python:3.12-slim base +# image's bundled pip (26.0.x as of 2026-05-15) carries CVE-2026-6357. The +# upgrade closes the CVE during image build without changing the runtime +# Python interpreter. Surfaced 2026-05-15 by pip-audit against the dev venv. +RUN pip install --upgrade 'pip>=26.1' RUN pip install -r backend/requirements.txt +# --- NDI-python integration (Phase A) --------------------------------------- +# NDI-python and its git-sourced kin (did, ndr, vhlab-toolbox-python, +# ndi-compress) are installed with `--no-deps` to skip matplotlib (~50-70 MB) +# and opencv-python-headless (~80 MB) which they declare but our backend never +# imports. The runtime deps we DO need are listed in backend/requirements.txt +# above (pandas, networkx, jsonschema, openMINDS, portalocker, h5py, requests). +# +# CRITICAL: pinned to specific git SHAs (NOT @main) to prevent silent upstream +# drift. If upstream pushes a breaking change to e.g. NDI-python's +# `ontology.lookup` return shape, our experimental Railway redeploys silently +# with the new code — and chats start returning wrong data with no signal. +# Pinning forces every drift to be an explicit Dockerfile edit + redeploy. +# Bump these SHAs by re-running `git ls-remote HEAD` and pasting the +# new hashes. See docs/plans/2026-05-13-ndi-python-integration.md for the +# overall integration plan. +# +# Pins captured 2026-05-13: +RUN pip install --no-deps "vhlab-toolbox-python @ git+https://github.com/VH-Lab/vhlab-toolbox-python.git@b073185565ea5b47bb0307cddeae923fa9b86268" +RUN pip install --no-deps "did @ git+https://github.com/VH-Lab/DID-python.git@1b1491fb98f37a61a74b86cccdfabae8b6bbce9e" +RUN pip install --no-deps "ndr @ git+https://github.com/VH-lab/NDR-python.git@896ed637c35cd8ba118e1512a8c65bdd634a7622" +RUN pip install --no-deps "ndi-compress @ git+https://github.com/Waltham-Data-Science/NDI-compress-python.git@0c05d9dbd63ed5d15866eb1bf0a096568ef0c192" +RUN pip install --no-deps "ndi @ git+https://github.com/Waltham-Data-Science/NDI-python.git@9c64acb13bfbe0baf7f6b40bee18925fd7d9117c" +# Sanity: import every entry point we use in production so a missing +# transitive dep fails the build, not the first request. +RUN python -c "from vlt.file.custom_file_formats import vhsb_read, vhsb_readheader; \ +import ndicompress; from ndi.ontology import lookup; print('ndi-python stack importable')" + COPY backend/ ./backend/ COPY --from=frontend-build /app/frontend/dist ./frontend_dist @@ -35,7 +68,8 @@ USER ndb ENV ONTOLOGY_CACHE_DB_PATH=/tmp/ndb/ontology.db \ LOG_FORMAT=json \ WEB_CONCURRENCY=4 \ - PORT=8000 + PORT=8000 \ + NDI_PYTHON_REQUIRED=1 HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \ CMD curl -fsS http://127.0.0.1:${PORT}/api/health || exit 1