From 2033e0d73ee89f523e663ceb1814e9857b9dd52d Mon Sep 17 00:00:00 2001 From: frapercan Date: Wed, 18 Mar 2026 13:34:49 +0100 Subject: [PATCH 1/7] feat(infra): connection pool, DLQ, structured logging, health probes, stale reaper - DB connection pool (pool_size=20, max_overflow=40, pool_recycle=3600) - Publisher: thread-local connection reuse, exponential backoff (5 attempts) - Consumer: dead letter queue (protea.dlx -> protea.dead-letter) on all queues - Consumer: OperationConsumer emit writes JobEvent to parent job - Health endpoints: /health (liveness) and /health/ready (DB + RabbitMQ) - Structured JSON logging with --log-format flag in worker - StaleJobReaper: marks RUNNING jobs as FAILED after 1h timeout - BaseWorker: adaptive backoff for RetryLaterError (capped at 600s) - Cancel endpoint: cancels both QUEUED and RUNNING children - Composite indexes on (annotation_set_id, accession) and (prediction_set_id, accession) - Taxonomy DB warmup at worker startup for prediction queues - Multi-stage Dockerfile with healthcheck - docker-compose: all 11 services with memory limits --- Dockerfile | 26 ++++- ...2eb0f986d_add_composite_indexes_for_knn.py | 37 ++++++ docker-compose.prod.yml | 18 ++- docker-compose.yml | 92 ++++++++++++++- protea/api/app.py | 28 ++++- protea/api/routers/jobs.py | 17 ++- protea/core/feature_engineering.py | 16 +++ protea/infrastructure/database/engine.py | 9 +- protea/infrastructure/logging.py | 104 +++++++++++++++++ .../annotation/protein_go_annotation.py | 3 +- .../orm/models/embedding/go_prediction.py | 3 +- protea/infrastructure/queue/consumer.py | 83 +++++++++++++- protea/infrastructure/queue/publisher.py | 70 ++++++++---- protea/workers/base_worker.py | 14 ++- protea/workers/stale_job_reaper.py | 106 ++++++++++++++++++ scripts/manage.sh | 6 +- scripts/worker.py | 33 +++++- 17 files changed, 619 insertions(+), 46 deletions(-) create mode 100644 alembic/versions/5fc2eb0f986d_add_composite_indexes_for_knn.py create mode 100644 protea/infrastructure/logging.py create mode 100644 protea/workers/stale_job_reaper.py diff --git a/Dockerfile b/Dockerfile index 0df0128..17c1ecc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,5 @@ -FROM python:3.12-slim +# ── Stage 1: build dependencies ────────────────────────────────────────────── +FROM python:3.12-slim AS builder RUN apt-get update && apt-get install -y \ build-essential \ @@ -8,16 +9,30 @@ RUN apt-get update && apt-get install -y \ WORKDIR /app -# Install Poetry and dependencies first (layer cache) RUN pip install --no-cache-dir poetry==2.1.0 COPY pyproject.toml poetry.lock ./ RUN poetry config virtualenvs.create false \ && poetry install --without dev --no-root --no-interaction --no-ansi -# Copy source COPY protea/ ./protea/ RUN poetry install --without dev --no-interaction --no-ansi + +# ── Stage 2: runtime ──────────────────────────────────────────────────────── +FROM python:3.12-slim + +RUN apt-get update && apt-get install -y \ + libpq5 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy installed packages from builder +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin + +# Copy application code +COPY protea/ ./protea/ COPY scripts/ ./scripts/ COPY alembic/ ./alembic/ COPY alembic.ini ./ @@ -25,7 +40,10 @@ COPY alembic.ini ./ ENV PYTHONUNBUFFERED=1 EXPOSE 8000 +HEALTHCHECK --interval=30s --timeout=5s --retries=3 \ + CMD python -c "import urllib.request; 
urllib.request.urlopen('http://localhost:8000/health')" || exit 1 + # Default: API server # Override CMD to run a worker: # docker run protea python scripts/worker.py --queue protea.jobs -CMD ["uvicorn", "protea.api.app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "protea.api.app:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000"] diff --git a/alembic/versions/5fc2eb0f986d_add_composite_indexes_for_knn.py b/alembic/versions/5fc2eb0f986d_add_composite_indexes_for_knn.py new file mode 100644 index 0000000..1fcab83 --- /dev/null +++ b/alembic/versions/5fc2eb0f986d_add_composite_indexes_for_knn.py @@ -0,0 +1,37 @@ +"""add composite indexes for KNN performance + +Revision ID: 5fc2eb0f986d +Revises: 54e758c210c8 +Create Date: 2026-03-18 12:00:00.000000 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "5fc2eb0f986d" +down_revision: str = "54e758c210c8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Composite index for KNN GO transfer: queries are always scoped to + # a single annotation_set_id and filtered by protein_accession. + op.create_index( + "ix_pga_set_accession", + "protein_go_annotation", + ["annotation_set_id", "protein_accession"], + ) + + # Composite index for prediction export and evaluation: queries filter + # by prediction_set_id then protein_accession. + op.create_index( + "ix_go_prediction_set_accession", + "go_prediction", + ["prediction_set_id", "protein_accession"], + ) + + +def downgrade() -> None: + op.drop_index("ix_go_prediction_set_accession", table_name="go_prediction") + op.drop_index("ix_pga_set_accession", table_name="protein_go_annotation") diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index e5f9db4..49d97c0 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -1,7 +1,7 @@ # Production overrides: pull pre-built images from ghcr.io instead of building locally. # Use with: docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d # -# The worker-embeddings service gets GPU access via the NVIDIA container runtime. +# The worker-embeddings-batch service gets GPU access via the NVIDIA container runtime. services: migrate: @@ -15,6 +15,9 @@ services: worker-embeddings: image: ghcr.io/frapercan/protea:latest + + worker-embeddings-batch: + image: ghcr.io/frapercan/protea:latest deploy: resources: reservations: @@ -22,8 +25,19 @@ services: - driver: nvidia count: all capabilities: [gpu] + limits: + memory: 8G + + worker-embeddings-write: + image: ghcr.io/frapercan/protea:latest + + worker-predictions-batch: + image: ghcr.io/frapercan/protea:latest + + worker-predictions-write: + image: ghcr.io/frapercan/protea:latest - worker-predictions: + worker-reaper: image: ghcr.io/frapercan/protea:latest frontend: diff --git a/docker-compose.yml b/docker-compose.yml index cb925b8..1e74b6c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,10 @@ services: interval: 10s timeout: 5s retries: 5 + deploy: + resources: + limits: + memory: 2G rabbitmq: image: rabbitmq:3-management @@ -30,6 +34,10 @@ services: interval: 10s timeout: 5s retries: 5 + deploy: + resources: + limits: + memory: 512M migrate: build: . @@ -55,6 +63,10 @@ services: condition: service_healthy migrate: condition: service_completed_successfully + deploy: + resources: + limits: + memory: 1G worker-jobs: build: . 
@@ -67,6 +79,10 @@ services: condition: service_completed_successfully rabbitmq: condition: service_healthy + deploy: + resources: + limits: + memory: 2G worker-embeddings: build: . @@ -79,8 +95,44 @@ services: condition: service_completed_successfully rabbitmq: condition: service_healthy + deploy: + resources: + limits: + memory: 2G - worker-predictions: + worker-embeddings-batch: + build: . + environment: + PROTEA_DB_URL: postgresql+psycopg://protea:protea@postgres/protea + PROTEA_AMQP_URL: amqp://guest:guest@rabbitmq/ + command: python scripts/worker.py --queue protea.embeddings.batch + depends_on: + migrate: + condition: service_completed_successfully + rabbitmq: + condition: service_healthy + deploy: + resources: + limits: + memory: 4G + + worker-embeddings-write: + build: . + environment: + PROTEA_DB_URL: postgresql+psycopg://protea:protea@postgres/protea + PROTEA_AMQP_URL: amqp://guest:guest@rabbitmq/ + command: python scripts/worker.py --queue protea.embeddings.write + depends_on: + migrate: + condition: service_completed_successfully + rabbitmq: + condition: service_healthy + deploy: + resources: + limits: + memory: 1G + + worker-predictions-batch: build: . environment: PROTEA_DB_URL: postgresql+psycopg://protea:protea@postgres/protea @@ -91,6 +143,40 @@ services: condition: service_completed_successfully rabbitmq: condition: service_healthy + deploy: + resources: + limits: + memory: 4G + + worker-predictions-write: + build: . + environment: + PROTEA_DB_URL: postgresql+psycopg://protea:protea@postgres/protea + PROTEA_AMQP_URL: amqp://guest:guest@rabbitmq/ + command: python scripts/worker.py --queue protea.predictions.write + depends_on: + migrate: + condition: service_completed_successfully + rabbitmq: + condition: service_healthy + deploy: + resources: + limits: + memory: 1G + + worker-reaper: + build: . 
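+    # Unlike the other workers, the reaper declares no rabbitmq dependency:
+    # it consumes no queue and only polls the jobs table on a timer
+    # (see StaleJobReaper).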
+ environment: + PROTEA_DB_URL: postgresql+psycopg://protea:protea@postgres/protea + PROTEA_AMQP_URL: amqp://guest:guest@rabbitmq/ + command: python scripts/worker.py --queue reaper + depends_on: + migrate: + condition: service_completed_successfully + deploy: + resources: + limits: + memory: 256M frontend: build: @@ -102,6 +188,10 @@ services: - "3000:3000" depends_on: - api + deploy: + resources: + limits: + memory: 512M volumes: postgres_data: diff --git a/protea/api/app.py b/protea/api/app.py index 2d075bb..4d94602 100644 --- a/protea/api/app.py +++ b/protea/api/app.py @@ -3,7 +3,7 @@ from pathlib import Path -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles @@ -83,6 +83,32 @@ def create_app(project_root: Path | None = None) -> FastAPI: allow_headers=["*"], ) + @app.get("/health", tags=["health"]) + def health_check() -> dict[str, str]: + """Liveness probe — returns 200 if the API process is up.""" + return {"status": "ok"} + + @app.get("/health/ready", tags=["health"]) + def readiness_check() -> dict[str, str]: + """Readiness probe — verifies database and RabbitMQ connections.""" + from sqlalchemy import text + + from protea.infrastructure.session import session_scope + + with session_scope(factory) as session: + session.execute(text("SELECT 1")) + + # Check RabbitMQ connectivity + import pika + + try: + conn = pika.BlockingConnection(pika.URLParameters(settings.amqp_url)) + conn.close() + except Exception as exc: + raise HTTPException(status_code=503, detail=f"RabbitMQ unreachable: {exc}") + + return {"status": "ready"} + app.include_router(jobs_router.router) app.include_router(proteins_router.router) app.include_router(annotations_router.router) diff --git a/protea/api/routers/jobs.py b/protea/api/routers/jobs.py index c3d554c..3645460 100644 --- a/protea/api/routers/jobs.py +++ b/protea/api/routers/jobs.py @@ -7,6 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session, sessionmaker + +from protea.core.utils import utcnow from starlette.requests import Request from protea.infrastructure.orm.models.job import Job, JobEvent, JobStatus @@ -228,10 +230,12 @@ def cancel_job( job_id: UUID, factory: sessionmaker[Session] = Depends(get_session_factory), ) -> dict[str, Any]: - """Mark a job (and any queued child jobs) as CANCELLED. + """Mark a job (and any non-terminal child jobs) as CANCELLED. Already-finished jobs (SUCCEEDED/FAILED) are returned as-is with no state change. - Note: workers processing a batch mid-flight will complete their current message before stopping. + Children in QUEUED are cancelled immediately. Children in RUNNING are also + marked CANCELLED — the worker's parent-check in BaseWorker.handle_job() will + detect the cancelled parent on the next iteration and stop gracefully. """ with session_scope(factory) as session: j = session.get(Job, job_id) @@ -242,16 +246,21 @@ def cancel_job( return {"id": str(j.id), "status": j.status.value} j.status = JobStatus.CANCELLED + j.finished_at = utcnow() session.add(JobEvent(job_id=job_id, event="job.cancelled", fields={})) - # Cancel any queued children so they are not picked up by a worker. + # Cancel all non-terminal children (QUEUED and RUNNING). 
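+        # Children already in a terminal state (SUCCEEDED/FAILED) are left
+        # untouched; each child cancelled here gets its own JobEvent below.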
children = ( session.query(Job) - .filter(Job.parent_job_id == job_id, Job.status == JobStatus.QUEUED) + .filter( + Job.parent_job_id == job_id, + Job.status.in_((JobStatus.QUEUED, JobStatus.RUNNING)), + ) .all() ) for child in children: child.status = JobStatus.CANCELLED + child.finished_at = utcnow() session.add( JobEvent( job_id=child.id, diff --git a/protea/core/feature_engineering.py b/protea/core/feature_engineering.py index 056a563..12760ba 100644 --- a/protea/core/feature_engineering.py +++ b/protea/core/feature_engineering.py @@ -129,6 +129,22 @@ def _get_ncbi() -> NCBITaxa: return _ncbi +def warmup_taxonomy_db() -> None: + """Pre-initialize the NCBITaxa database. + + Call at worker startup so the download (~100 MB on first run) + happens before any batch is processed, not mid-flight. + """ + if not _ETE3_AVAILABLE: + return + import logging + + log = logging.getLogger(__name__) + log.info("Warming up NCBI taxonomy database...") + _get_ncbi() + log.info("NCBI taxonomy database ready.") + + @lru_cache(maxsize=100_000) def _cached_lineage(tid: int) -> list[int]: return _get_ncbi().get_lineage(tid) # type: ignore[return-value] diff --git a/protea/infrastructure/database/engine.py b/protea/infrastructure/database/engine.py index dd02e91..7df6692 100644 --- a/protea/infrastructure/database/engine.py +++ b/protea/infrastructure/database/engine.py @@ -5,4 +5,11 @@ def build_engine(db_url: str) -> Engine: - return create_engine(db_url, future=True, pool_pre_ping=True) + return create_engine( + db_url, + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=40, + pool_recycle=3600, + ) diff --git a/protea/infrastructure/logging.py b/protea/infrastructure/logging.py new file mode 100644 index 0000000..3d27c60 --- /dev/null +++ b/protea/infrastructure/logging.py @@ -0,0 +1,104 @@ +# protea/infrastructure/logging.py +"""Structured logging configuration for PROTEA. + +Provides a JSON formatter using only the Python standard library and a +``configure_logging()`` helper that workers and the API can call at startup. +""" +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from typing import Any + + +class JSONFormatter(logging.Formatter): + """Formats log records as single-line JSON objects. + + Each line contains at least ``timestamp``, ``level``, ``logger``, and + ``message``. Any *extra* fields attached to the record are merged into + the top-level JSON object, making it easy to add structured context + (e.g. ``logger.info("started", extra={"queue": "protea.jobs"})``). + """ + + # Keys that belong to the standard LogRecord and should not be forwarded + # as extra fields. + _BUILTIN_ATTRS: frozenset[str] = frozenset( + { + "args", + "created", + "exc_info", + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "taskName", + "thread", + "threadName", + } + ) + + def format(self, record: logging.LogRecord) -> str: # noqa: D401 + log_entry: dict[str, Any] = { + "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + # Merge any extra fields the caller attached to the record. + for key, value in record.__dict__.items(): + if key not in self._BUILTIN_ATTRS: + log_entry[key] = value + + # Append exception info when present. 
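+        # record.exc_info is a (type, value, traceback) tuple; the value slot
+        # is None when exc_info=True was passed with no active exception.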
+ if record.exc_info and record.exc_info[1] is not None: + log_entry["exception"] = self.formatException(record.exc_info) + + if record.stack_info: + log_entry["stack_info"] = record.stack_info + + return json.dumps(log_entry, default=str) + + +_TEXT_FORMAT = "%(asctime)s %(levelname)s %(message)s" + + +def configure_logging(*, json: bool = True, level: str = "INFO") -> None: + """Configure the root logger for the process. + + Parameters + ---------- + json: + When *True* (the default), use :class:`JSONFormatter` so that every + log line is a valid JSON object. When *False*, fall back to the + plain-text format used during local development. + level: + Root log level name (e.g. ``"INFO"``, ``"DEBUG"``). + """ + root = logging.getLogger() + root.setLevel(getattr(logging, level.upper(), logging.INFO)) + + # Remove any handlers that may have been added by earlier basicConfig + # calls or library imports so we start fresh. + root.handlers.clear() + + handler = logging.StreamHandler() + if json: + handler.setFormatter(JSONFormatter()) + else: + handler.setFormatter(logging.Formatter(_TEXT_FORMAT)) + + root.addHandler(handler) diff --git a/protea/infrastructure/orm/models/annotation/protein_go_annotation.py b/protea/infrastructure/orm/models/annotation/protein_go_annotation.py index 6804860..934917d 100644 --- a/protea/infrastructure/orm/models/annotation/protein_go_annotation.py +++ b/protea/infrastructure/orm/models/annotation/protein_go_annotation.py @@ -3,7 +3,7 @@ import uuid from typing import TYPE_CHECKING -from sqlalchemy import BigInteger, ForeignKey, String, UniqueConstraint +from sqlalchemy import BigInteger, ForeignKey, Index, String, UniqueConstraint from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -37,6 +37,7 @@ class ProteinGOAnnotation(Base): "evidence_code", name="uq_pga_set_protein_term_evidence", ), + Index("ix_pga_set_accession", "annotation_set_id", "protein_accession"), ) id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) diff --git a/protea/infrastructure/orm/models/embedding/go_prediction.py b/protea/infrastructure/orm/models/embedding/go_prediction.py index ef58a93..4243c33 100644 --- a/protea/infrastructure/orm/models/embedding/go_prediction.py +++ b/protea/infrastructure/orm/models/embedding/go_prediction.py @@ -3,7 +3,7 @@ import uuid from typing import TYPE_CHECKING -from sqlalchemy import BigInteger, Float, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy import BigInteger, Float, ForeignKey, Index, Integer, String, UniqueConstraint from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -31,6 +31,7 @@ class GOPrediction(Base): "go_term_id", name="uq_go_prediction_set_protein_term", ), + Index("ix_go_prediction_set_accession", "prediction_set_id", "protein_accession"), ) id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) diff --git a/protea/infrastructure/queue/consumer.py b/protea/infrastructure/queue/consumer.py index 8c3310a..876b1cf 100644 --- a/protea/infrastructure/queue/consumer.py +++ b/protea/infrastructure/queue/consumer.py @@ -13,11 +13,22 @@ from protea.core.contracts.operation import RetryLaterError from protea.core.contracts.registry import OperationRegistry +from protea.infrastructure.orm.models.job import JobEvent from protea.infrastructure.queue.publisher import publish_operation from protea.workers.base_worker import BaseWorker logger = 
logging.getLogger(__name__) +_DLX_NAME = "protea.dlx" +_DLQ_NAME = "protea.dead-letter" + + +def _setup_dead_letter(channel: BlockingChannel) -> None: + """Declare the dead-letter exchange and queue (idempotent).""" + channel.exchange_declare(exchange=_DLX_NAME, exchange_type="fanout", durable=True) + channel.queue_declare(queue=_DLQ_NAME, durable=True) + channel.queue_bind(queue=_DLQ_NAME, exchange=_DLX_NAME) + class QueueConsumer: """ @@ -63,7 +74,12 @@ def run(self) -> None: connection = pika.BlockingConnection(params) channel = connection.channel() - channel.queue_declare(queue=self._queue_name, durable=True) + _setup_dead_letter(channel) + channel.queue_declare( + queue=self._queue_name, + durable=True, + arguments={"x-dead-letter-exchange": _DLX_NAME}, + ) channel.basic_qos(prefetch_count=self._prefetch_count) channel.basic_consume( queue=self._queue_name, @@ -182,7 +198,12 @@ def run(self) -> None: connection = pika.BlockingConnection(params) channel = connection.channel() - channel.queue_declare(queue=self._queue_name, durable=True) + _setup_dead_letter(channel) + channel.queue_declare( + queue=self._queue_name, + durable=True, + arguments={"x-dead-letter-exchange": _DLX_NAME}, + ) channel.basic_qos(prefetch_count=self._prefetch_count) channel.basic_consume( queue=self._queue_name, @@ -234,6 +255,14 @@ def _on_message( "Dispatching operation. operation=%s queue=%s", operation_name, self._queue_name ) + parent_job_id: UUID | None = None + raw_job_id = data.get("job_id") + if raw_job_id: + try: + parent_job_id = UUID(raw_job_id) + except (ValueError, TypeError): + pass + op = self._registry.get(operation_name) session = self._factory() try: @@ -245,6 +274,32 @@ def emit( level: str = "info", ) -> None: logger.info("operation.%s fields=%s", event, fields or {}) + if parent_job_id is not None: + event_session = self._factory() + try: + event_session.add( + JobEvent( + job_id=parent_job_id, + event=f"child.{event}", + message=message, + fields=fields or {}, + level=level, + ) + ) + event_session.commit() + except Exception as emit_exc: + logger.warning( + "Failed to write child event to parent job. " + "parent_job_id=%s error=%s", + parent_job_id, + emit_exc, + ) + try: + event_session.rollback() + except Exception: + pass + finally: + event_session.close() result = op.execute(session, payload, emit=emit) session.commit() @@ -270,6 +325,30 @@ def emit( ) else: logger.error("Operation failed. operation=%s error=%s", operation_name, exc) + # Record failure event on parent job so it's visible in the UI. 
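+            # A dedicated session is used on purpose: the operation's own
+            # session is rolled back just below and must not take this
+            # event row down with it.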
+ if parent_job_id is not None: + err_session = self._factory() + try: + err_session.add( + JobEvent( + job_id=parent_job_id, + event="child.failed", + message=str(exc)[:2000], + fields={ + "operation": operation_name, + "error_code": exc.__class__.__name__, + }, + level="error", + ) + ) + err_session.commit() + except Exception: + try: + err_session.rollback() + except Exception: + pass + finally: + err_session.close() try: session.rollback() except Exception: diff --git a/protea/infrastructure/queue/publisher.py b/protea/infrastructure/queue/publisher.py index 29a3985..f5115a4 100644 --- a/protea/infrastructure/queue/publisher.py +++ b/protea/infrastructure/queue/publisher.py @@ -2,6 +2,7 @@ import json import logging +import threading import time from typing import Any from uuid import UUID @@ -10,37 +11,60 @@ logger = logging.getLogger(__name__) -_RETRY_DELAYS = (1, 3, 10) # seconds between attempts (3 attempts total) +_MAX_ATTEMPTS = 5 +_BASE_DELAY = 1 # seconds; exponential backoff: 1, 2, 4, 8, 16 (capped at 30) + +# Thread-local persistent connection to avoid opening/closing per publish. +_local = threading.local() + + +def _get_connection(amqp_url: str) -> pika.BlockingConnection: + """Return a reusable connection, creating one if needed.""" + conn: pika.BlockingConnection | None = getattr(_local, "connection", None) + if conn is not None and conn.is_open: + return conn + _local.connection = pika.BlockingConnection(pika.URLParameters(amqp_url)) + return _local.connection + + +def _close_cached_connection() -> None: + conn: pika.BlockingConnection | None = getattr(_local, "connection", None) + if conn is not None and conn.is_open: + try: + conn.close() + except Exception: + pass + _local.connection = None def _publish(amqp_url: str, queue_name: str, body: bytes) -> None: - """Core publish logic with retries. Used by both publish_job and publish_operation.""" + """Core publish logic with retries and connection reuse.""" last_exc: Exception | None = None - for attempt, delay in enumerate((*_RETRY_DELAYS, None), start=1): + for attempt in range(1, _MAX_ATTEMPTS + 1): try: - connection = pika.BlockingConnection(pika.URLParameters(amqp_url)) - try: - channel = connection.channel() - channel.queue_declare(queue=queue_name, durable=True) - channel.basic_publish( - exchange="", - routing_key=queue_name, - body=body, - properties=pika.BasicProperties( - delivery_mode=pika.DeliveryMode.Persistent, - ), - ) - return - finally: - if connection.is_open: - connection.close() + connection = _get_connection(amqp_url) + channel = connection.channel() + channel.queue_declare(queue=queue_name, durable=True) + channel.basic_publish( + exchange="", + routing_key=queue_name, + body=body, + properties=pika.BasicProperties( + delivery_mode=pika.DeliveryMode.Persistent, + ), + ) + return except Exception as exc: last_exc = exc - if delay is not None: + # Connection is stale — discard it so next attempt creates a fresh one. + _close_cached_connection() + if attempt < _MAX_ATTEMPTS: + delay = min(_BASE_DELAY * (2 ** (attempt - 1)), 30) logger.warning( - "publish failed (attempt %d), retrying in %ds. queue=%s error=%s", + "publish failed (attempt %d/%d), retrying in %ds. queue=%s error=%s", attempt, + _MAX_ATTEMPTS, delay, queue_name, exc, @@ -49,13 +73,13 @@ def _publish(amqp_url: str, queue_name: str, body: bytes) -> None: else: logger.error( "publish failed after %d attempts. 
queue=%s error=%s", - attempt, + _MAX_ATTEMPTS, queue_name, exc, ) raise RuntimeError( - f"Failed to publish to queue {queue_name!r} after {len(_RETRY_DELAYS) + 1} attempts" + f"Failed to publish to queue {queue_name!r} after {_MAX_ATTEMPTS} attempts" ) from last_exc diff --git a/protea/workers/base_worker.py b/protea/workers/base_worker.py index b5d7323..7381b00 100644 --- a/protea/workers/base_worker.py +++ b/protea/workers/base_worker.py @@ -177,6 +177,16 @@ def emit( except RetryLaterError as e: # Resource busy — reset to QUEUED so the consumer can re-publish. + # Adaptive backoff: count previous retries and increase delay. + retry_count = ( + session.query(func.count(JobEvent.id)) + .filter(JobEvent.job_id == job_id, JobEvent.event == "job.retry_later") + .scalar() + or 0 + ) + base_delay = e.delay_seconds + delay = min(base_delay * (2 ** retry_count), 600) # cap at 10 min + job.status = JobStatus.QUEUED job.started_at = None self._emit( @@ -184,10 +194,12 @@ def emit( job_id, "job.retry_later", str(e), - {"delay_seconds": e.delay_seconds}, + {"delay_seconds": delay, "retry_count": retry_count + 1}, level="info", ) session.commit() + # Propagate adaptive delay to the consumer. + e.delay_seconds = delay raise # consumer handles re-publish except Exception as e: diff --git a/protea/workers/stale_job_reaper.py b/protea/workers/stale_job_reaper.py new file mode 100644 index 0000000..8aa1d3f --- /dev/null +++ b/protea/workers/stale_job_reaper.py @@ -0,0 +1,106 @@ +# protea/workers/stale_job_reaper.py +"""Periodic reaper that marks long-running jobs as FAILED. + +Workers are single-threaded and cannot be interrupted mid-operation without +risking data corruption. Instead, this lightweight reaper runs on a timer +and transitions any job that has been in RUNNING status for longer than +``timeout_seconds`` to FAILED with error_code ``JobTimeout``. + +Usage:: + + reaper = StaleJobReaper(session_factory, timeout_seconds=3600) + reaper.run(interval_seconds=60) # checks every minute +""" +from __future__ import annotations + +import logging +import signal +import time +from datetime import timedelta + +from sqlalchemy import update as sa_update +from sqlalchemy.orm import Session, sessionmaker + +from protea.core.utils import utcnow +from protea.infrastructure.orm.models.job import Job, JobEvent, JobStatus + +logger = logging.getLogger(__name__) + + +class StaleJobReaper: + def __init__( + self, + session_factory: sessionmaker[Session], + timeout_seconds: int = 3600, + ) -> None: + self._factory = session_factory + self._timeout = timedelta(seconds=timeout_seconds) + self._stop = False + + def run(self, interval_seconds: int = 60) -> None: + signal.signal(signal.SIGINT, self._handle_stop) + signal.signal(signal.SIGTERM, self._handle_stop) + + logger.info( + "StaleJobReaper started. 
timeout=%ss interval=%ss", + self._timeout.total_seconds(), + interval_seconds, + ) + while not self._stop: + try: + reaped = self._reap() + if reaped: + logger.info("Reaped %d stale job(s).", reaped) + except Exception as exc: + logger.error("Reaper cycle failed: %s", exc) + time.sleep(interval_seconds) + + logger.info("StaleJobReaper stopped.") + + def _handle_stop(self, *_: object) -> None: + self._stop = True + + def _reap(self) -> int: + cutoff = utcnow() - self._timeout + session = self._factory() + try: + stale_jobs = ( + session.query(Job) + .filter( + Job.status == JobStatus.RUNNING, + Job.started_at < cutoff, + ) + .all() + ) + for job in stale_jobs: + job.status = JobStatus.FAILED + job.finished_at = utcnow() + job.error_code = "JobTimeout" + job.error_message = ( + f"Job exceeded timeout of {self._timeout.total_seconds():.0f}s" + ) + session.add( + JobEvent( + job_id=job.id, + event="job.timeout", + message=job.error_message, + fields={"timeout_seconds": self._timeout.total_seconds()}, + level="error", + ) + ) + logger.warning( + "Marking stale job FAILED. job_id=%s operation=%s started_at=%s", + job.id, + job.operation, + job.started_at, + ) + session.commit() + return len(stale_jobs) + except Exception: + try: + session.rollback() + except Exception: + pass + raise + finally: + session.close() diff --git a/scripts/manage.sh b/scripts/manage.sh index 020a8e3..fe49d1d 100755 --- a/scripts/manage.sh +++ b/scripts/manage.sh @@ -103,8 +103,12 @@ cmd_start() { done _start_bg worker-predictions-write poetry run python scripts/worker.py --queue protea.predictions.write + # Stale job reaper + printf "\n${BOLD}[6] Stale job reaper${RESET}\n" + _start_bg worker-reaper poetry run python scripts/worker.py --queue reaper + # Frontend - printf "\n${BOLD}[6] Frontend${RESET}\n" + printf "\n${BOLD}[7] Frontend${RESET}\n" cd "$ROOT/apps/web" _start_bg frontend npm run dev sleep 6 diff --git a/scripts/worker.py b/scripts/worker.py index de129da..dbbd9b5 100644 --- a/scripts/worker.py +++ b/scripts/worker.py @@ -36,18 +36,27 @@ from protea.infrastructure.session import build_session_factory from protea.infrastructure.settings import load_settings from protea.workers.base_worker import BaseWorker, WorkerConfig - -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") -# Suppress pika's verbose connection lifecycle messages -logging.getLogger("pika").setLevel(logging.WARNING) +from protea.workers.stale_job_reaper import StaleJobReaper def main() -> None: parser = argparse.ArgumentParser(description="PROTEA queue worker") parser.add_argument("--queue", default="protea.jobs", help="Queue name to consume") parser.add_argument("--requeue-on-failure", action="store_true") + parser.add_argument( + "--log-format", + choices=["json", "text"], + default="json", + help="Log output format (default: json)", + ) args = parser.parse_args() + from protea.infrastructure.logging import configure_logging + + configure_logging(json=(args.log_format == "json")) + # Suppress pika's verbose connection lifecycle messages + logging.getLogger("pika").setLevel(logging.WARNING) + project_root = Path(__file__).resolve().parents[1] settings = load_settings(project_root) @@ -78,6 +87,13 @@ def main() -> None: "protea.predictions.write", } + # Special mode: stale job reaper (no queue, just periodic DB check). + if args.queue == "reaper": + reaper = StaleJobReaper(factory, timeout_seconds=3600) + logging.info("Stale job reaper started. 
timeout=3600s interval=60s") + reaper.run(interval_seconds=60) + return + if args.queue in _OPERATION_QUEUES: consumer: QueueConsumer | OperationConsumer = OperationConsumer( amqp_url=settings.amqp_url, @@ -95,6 +111,15 @@ def main() -> None: requeue_on_failure=args.requeue_on_failure, ) + # Pre-warm taxonomy DB for prediction workers that may need it. + if args.queue in ("protea.predictions.batch", "protea.jobs"): + try: + from protea.core.feature_engineering import warmup_taxonomy_db + + warmup_taxonomy_db() + except Exception as exc: + logging.warning("Taxonomy DB warmup skipped: %s", exc) + logging.info("Worker started. queue=%s", args.queue) while True: try: From 7ee749fc9e4b6b465e230300b050e2adde704ca0 Mon Sep 17 00:00:00 2001 From: frapercan Date: Wed, 18 Mar 2026 13:36:39 +0100 Subject: [PATCH 2/7] test: expand coverage from 65% to 88% (283 -> 831 tests) New test files: - test_logging.py (15 tests): JSONFormatter, configure_logging - test_evaluation.py (+35 tests): load_children_map, build_negative_keys, compute_evaluation_data - test_run_cafa_evaluation.py (50 tests): full operation coverage - test_load_goa_annotations.py (+45 tests): GAF parsing, store buffer, execute - test_load_quickgo_annotations.py (+33 tests): TSV parsing, pagination, ECO mapping - test_annotations_router.py (70 tests): all 23 endpoints - test_embeddings_router.py (+44 tests): configs, predict, prediction sets, CAFA TSV - test_proteins_router.py (17 tests): stats, list, detail, annotations - test_admin_router.py (4 tests): reset-db - test_scoring_router.py (+7 tests): scored TSV, metrics Extended test files: - test_queue.py (+20 tests): OperationConsumer on_message, QueueConsumer retry, DLQ - test_base_worker.py (+24 tests): parent cancel, two-session, publish, reaper, warmup - test_core.py (+11 tests): fetch_uniprot_metadata paths - test_infrastructure.py (+7 tests): health endpoints, app factory, pool config - test_insert_proteins.py (+22 tests): FASTA parsing, store_records, pagination - test_load_ontology_snapshot.py (+17 tests): OBO parsing, relationships, backfill --- tests/test_admin_router.py | 122 +++ tests/test_annotations_router.py | 1249 ++++++++++++++++++++++++ tests/test_base_worker.py | 662 ++++++++++++- tests/test_core.py | 338 +++++++ tests/test_embeddings_router.py | 674 +++++++++++++ tests/test_evaluation.py | 428 +++++++- tests/test_infrastructure.py | 156 ++- tests/test_insert_proteins.py | 453 +++++++++ tests/test_load_goa_annotations.py | 750 +++++++++++++- tests/test_load_ontology_snapshot.py | 375 +++++++ tests/test_load_quickgo_annotations.py | 670 ++++++++++++- tests/test_logging.py | 157 +++ tests/test_proteins_router.py | 353 +++++++ tests/test_queue.py | 573 ++++++++++- tests/test_run_cafa_evaluation.py | 1167 ++++++++++++++++++++++ tests/test_scoring_router.py | 263 ++++- 16 files changed, 8318 insertions(+), 72 deletions(-) create mode 100644 tests/test_admin_router.py create mode 100644 tests/test_annotations_router.py create mode 100644 tests/test_logging.py create mode 100644 tests/test_proteins_router.py create mode 100644 tests/test_run_cafa_evaluation.py diff --git a/tests/test_admin_router.py b/tests/test_admin_router.py new file mode 100644 index 0000000..2da4bc0 --- /dev/null +++ b/tests/test_admin_router.py @@ -0,0 +1,122 @@ +"""Unit tests for the /admin router. + +Database and subprocess calls are fully mocked -- no real infrastructure required. 
+""" +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from protea.api.routers.admin import router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_app(): + app = FastAPI() + app.state.session_factory = MagicMock() + app.include_router(router) + return app + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def mock_psycopg(): + """Ensure psycopg is available as a mock in sys.modules for the local import.""" + mock_mod = MagicMock() + conn_ctx = MagicMock() + mock_mod.connect.return_value.__enter__ = MagicMock(return_value=conn_ctx) + mock_mod.connect.return_value.__exit__ = MagicMock(return_value=False) + with patch.dict(sys.modules, {"psycopg": mock_mod}): + yield mock_mod, conn_ctx + + +@pytest.fixture() +def client(mock_psycopg): + app = _make_app() + with TestClient(app) as c: + yield c, app, mock_psycopg + + +# --------------------------------------------------------------------------- +# POST /admin/reset-db +# --------------------------------------------------------------------------- + +class TestResetDB: + @patch("protea.api.routers.admin.build_session_factory") + @patch("protea.api.routers.admin.subprocess.run") + @patch("protea.api.routers.admin.load_settings") + def test_reset_db_success(self, mock_settings, mock_run, mock_build, client): + c, app, (mock_psycopg_mod, conn_ctx) = client + settings = MagicMock() + settings.db_url = "postgresql+psycopg://u:p@localhost/db" + mock_settings.return_value = settings + + mock_run.return_value = MagicMock(returncode=0) + mock_build.return_value = MagicMock() + + resp = c.post("/admin/reset-db") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + mock_build.assert_called_once() + + @patch("protea.api.routers.admin.build_session_factory") + @patch("protea.api.routers.admin.subprocess.run") + @patch("protea.api.routers.admin.load_settings") + def test_reset_db_migration_failure(self, mock_settings, mock_run, mock_build, client): + c, app, (mock_psycopg_mod, conn_ctx) = client + settings = MagicMock() + settings.db_url = "postgresql+psycopg://u:p@localhost/db" + mock_settings.return_value = settings + + mock_run.return_value = MagicMock(returncode=1, stderr="migration error") + + resp = c.post("/admin/reset-db") + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is False + assert "migration error" in data["error"] + mock_build.assert_not_called() + + @patch("protea.api.routers.admin.build_session_factory") + @patch("protea.api.routers.admin.subprocess.run") + @patch("protea.api.routers.admin.load_settings") + def test_reset_db_drops_and_recreates_schema(self, mock_settings, mock_run, mock_build, client): + c, app, (mock_psycopg_mod, conn_ctx) = client + settings = MagicMock() + settings.db_url = "postgresql+psycopg://u:p@localhost/db" + mock_settings.return_value = settings + + mock_run.return_value = MagicMock(returncode=0) + + resp = c.post("/admin/reset-db") + assert resp.status_code == 200 + conn_ctx.execute.assert_any_call("DROP SCHEMA public CASCADE") + conn_ctx.execute.assert_any_call("CREATE SCHEMA public") + + @patch("protea.api.routers.admin.build_session_factory") + 
@patch("protea.api.routers.admin.subprocess.run") + @patch("protea.api.routers.admin.load_settings") + def test_reset_db_replaces_psycopg_in_url(self, mock_settings, mock_run, mock_build, client): + c, app, (mock_psycopg_mod, conn_ctx) = client + settings = MagicMock() + settings.db_url = "postgresql+psycopg://u:p@localhost/db" + mock_settings.return_value = settings + + mock_run.return_value = MagicMock(returncode=0) + + resp = c.post("/admin/reset-db") + assert resp.status_code == 200 + # Verify psycopg.connect was called with the URL without +psycopg + mock_psycopg_mod.connect.assert_called_once_with( + "postgresql://u:p@localhost/db", autocommit=True + ) diff --git a/tests/test_annotations_router.py b/tests/test_annotations_router.py new file mode 100644 index 0000000..a5da5a7 --- /dev/null +++ b/tests/test_annotations_router.py @@ -0,0 +1,1249 @@ +"""Unit tests for the /annotations router. + +Database and queue are fully mocked -- no real infrastructure required. +""" +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy.exc import IntegrityError + +from protea.api.routers.annotations import router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(session_factory, amqp_url="amqp://guest:guest@localhost", artifacts_dir=None): + app = FastAPI() + app.state.session_factory = session_factory + app.state.amqp_url = amqp_url + app.state.artifacts_dir = artifacts_dir or Path("/tmp/protea-test-artifacts") + app.include_router(router) + return app + + +@contextmanager +def _mock_scope(session): + yield session + + +def _make_snapshot(snap_id=None, obo_url="http://obo", obo_version="2024-01-01", ia_url=None): + s = MagicMock() + s.id = snap_id or uuid4() + s.obo_url = obo_url + s.obo_version = obo_version + s.ia_url = ia_url + s.loaded_at = MagicMock() + s.loaded_at.isoformat.return_value = "2024-01-01T00:00:00" + return s + + +def _make_annotation_set(set_id=None, source="goa", source_version="2024-01", snap_id=None, job_id=None): + a = MagicMock() + a.id = set_id or uuid4() + a.source = source + a.source_version = source_version + a.ontology_snapshot_id = snap_id or uuid4() + a.job_id = job_id + a.created_at = MagicMock() + a.created_at.isoformat.return_value = "2024-01-01T00:00:00" + a.meta = {"key": "value"} + return a + + +def _make_evaluation_set(eval_id=None, old_id=None, new_id=None, job_id=None, stats=None): + e = MagicMock() + e.id = eval_id or uuid4() + e.old_annotation_set_id = old_id or uuid4() + e.new_annotation_set_id = new_id or uuid4() + e.job_id = job_id + e.created_at = MagicMock() + e.created_at.isoformat.return_value = "2024-06-01T00:00:00" + e.stats = stats or {"nk": 10, "lk": 5} + return e + + +def _make_evaluation_result(result_id=None, eval_set_id=None, pred_set_id=None, scoring_id=None, job_id=None, results=None): + r = MagicMock() + r.id = result_id or uuid4() + r.evaluation_set_id = eval_set_id or uuid4() + r.prediction_set_id = pred_set_id or uuid4() + r.scoring_config_id = scoring_id + r.job_id = job_id + r.created_at = MagicMock() + r.created_at.isoformat.return_value = "2024-07-01T00:00:00" + r.results = results or {} + return r + + +@pytest.fixture() +def session(): + return MagicMock() + + 
+@pytest.fixture() +def factory(session): + return MagicMock() + + +@pytest.fixture() +def client(session, factory): + app = _make_app(factory) + with patch("protea.api.routers.annotations.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app) as c: + yield c, session + + +@pytest.fixture() +def client_with_artifacts(session, factory, tmp_path): + app = _make_app(factory, artifacts_dir=tmp_path) + with patch("protea.api.routers.annotations.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app) as c: + yield c, session, tmp_path + + +# --------------------------------------------------------------------------- +# GET /annotations/snapshots (lines 71-86) +# --------------------------------------------------------------------------- + + +class TestListSnapshots: + def test_returns_list(self, client): + c, session = client + snap = _make_snapshot() + # Simulate the subquery join: session.query(...).outerjoin(...).order_by(...).all() + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + session.query.return_value.outerjoin.return_value.order_by.return_value.all.return_value = [ + (snap, 42) + ] + + resp = c.get("/annotations/snapshots") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["obo_version"] == "2024-01-01" + assert data[0]["go_term_count"] == 42 + + def test_empty_list(self, client): + c, session = client + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + session.query.return_value.outerjoin.return_value.order_by.return_value.all.return_value = [] + + resp = c.get("/annotations/snapshots") + assert resp.status_code == 200 + assert resp.json() == [] + + def test_null_count_defaults_to_zero(self, client): + c, session = client + snap = _make_snapshot() + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + session.query.return_value.outerjoin.return_value.order_by.return_value.all.return_value = [ + (snap, None) + ] + + resp = c.get("/annotations/snapshots") + assert resp.status_code == 200 + assert resp.json()[0]["go_term_count"] == 0 + + +# --------------------------------------------------------------------------- +# GET /annotations/snapshots/{snapshot_id} (lines 105-116) +# --------------------------------------------------------------------------- + + +class TestGetSnapshot: + def test_returns_snapshot(self, client): + c, session = client + snap = _make_snapshot() + session.get.return_value = snap + session.query.return_value.filter.return_value.scalar.return_value = 99 + + resp = c.get(f"/annotations/snapshots/{snap.id}") + assert resp.status_code == 200 + data = resp.json() + assert data["obo_version"] == "2024-01-01" + assert data["go_term_count"] == 99 + + def test_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/snapshots/{uuid4()}") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# POST /annotations/snapshots/load (lines 176-195) +# --------------------------------------------------------------------------- + + +class TestLoadOntologySnapshot: + def test_success(self, client): + c, session = client + + def add_side(obj): + from protea.infrastructure.orm.models.job import Job + if isinstance(obj, Job): + obj.id = uuid4() + session.add.side_effect = add_side + + with patch("protea.api.routers.annotations.publish_job"): + resp = c.post( + 
"/annotations/snapshots/load", + json={"obo_url": "http://example.com/go.obo"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + def test_invalid_payload(self, client): + c, session = client + resp = c.post("/annotations/snapshots/load", json={}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /annotations/sets (lines 207-222) +# --------------------------------------------------------------------------- + + +class TestListAnnotationSets: + def test_returns_list(self, client): + c, session = client + aset = _make_annotation_set() + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + q_mock = session.query.return_value.outerjoin.return_value + q_mock.filter.return_value.order_by.return_value.all.return_value = [(aset, 10)] + q_mock.order_by.return_value.all.return_value = [(aset, 10)] + + resp = c.get("/annotations/sets") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["source"] == "goa" + + def test_filter_by_source(self, client): + c, session = client + aset = _make_annotation_set(source="quickgo") + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + q_mock = session.query.return_value.outerjoin.return_value + q_mock.filter.return_value.order_by.return_value.all.return_value = [(aset, 5)] + + resp = c.get("/annotations/sets?source=quickgo") + assert resp.status_code == 200 + + def test_empty(self, client): + c, session = client + session.query.return_value.group_by.return_value.subquery.return_value = MagicMock() + q_mock = session.query.return_value.outerjoin.return_value + q_mock.order_by.return_value.all.return_value = [] + + resp = c.get("/annotations/sets") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# GET /annotations/sets/{set_id} (lines 243-254) +# --------------------------------------------------------------------------- + + +class TestGetAnnotationSet: + def test_returns_set(self, client): + c, session = client + aset = _make_annotation_set(job_id=uuid4()) + session.get.return_value = aset + session.query.return_value.filter.return_value.scalar.return_value = 100 + + resp = c.get(f"/annotations/sets/{aset.id}") + assert resp.status_code == 200 + data = resp.json() + assert data["annotation_count"] == 100 + assert data["job_id"] is not None + + def test_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/sets/{uuid4()}") + assert resp.status_code == 404 + + def test_no_job_id(self, client): + c, session = client + aset = _make_annotation_set(job_id=None) + session.get.return_value = aset + session.query.return_value.filter.return_value.scalar.return_value = 0 + + resp = c.get(f"/annotations/sets/{aset.id}") + assert resp.status_code == 200 + assert resp.json()["job_id"] is None + + +# --------------------------------------------------------------------------- +# POST /annotations/sets/load-goa (lines 300-319) +# --------------------------------------------------------------------------- + + +class TestLoadGOAAnnotations: + def test_success(self, client): + c, session = client + + def add_side(obj): + from protea.infrastructure.orm.models.job import Job + if isinstance(obj, Job): + obj.id = uuid4() + session.add.side_effect = add_side + + with patch("protea.api.routers.annotations.publish_job"): + 
resp = c.post( + "/annotations/sets/load-goa", + json={ + "ontology_snapshot_id": str(uuid4()), + "gaf_url": "http://example.com/goa.gaf.gz", + "source_version": "2024-01", + }, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + def test_invalid_payload(self, client): + c, session = client + resp = c.post("/annotations/sets/load-goa", json={}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# POST /annotations/sets/load-quickgo (lines 330-349) +# --------------------------------------------------------------------------- + + +class TestLoadQuickGOAnnotations: + def test_success(self, client): + c, session = client + + def add_side(obj): + from protea.infrastructure.orm.models.job import Job + if isinstance(obj, Job): + obj.id = uuid4() + session.add.side_effect = add_side + + with patch("protea.api.routers.annotations.publish_job"): + resp = c.post( + "/annotations/sets/load-quickgo", + json={ + "ontology_snapshot_id": str(uuid4()), + "source_version": "2024-01", + }, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + def test_invalid_payload(self, client): + c, session = client + resp = c.post("/annotations/sets/load-quickgo", json={}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# Dependency edge cases (lines 45, 52, 57-60) +# --------------------------------------------------------------------------- + + +class TestDependencyGuards: + def test_missing_session_factory_raises(self): + app = FastAPI() + app.include_router(router) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/annotations/snapshots") + assert resp.status_code == 500 + + def test_missing_amqp_url_raises(self, session): + app = FastAPI() + app.state.session_factory = MagicMock() + # no amqp_url set + app.include_router(router) + with patch("protea.api.routers.annotations.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/annotations/snapshots/load", json={"obo_url": "http://example.com/go.obo"}) + assert resp.status_code == 500 + + def test_missing_artifacts_dir_raises(self, session): + app = FastAPI() + app.state.session_factory = MagicMock() + # no artifacts_dir set + app.include_router(router) + eval_id = uuid4() + with patch("protea.api.routers.annotations.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.delete(f"/annotations/evaluation-sets/{eval_id}") + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# PATCH /annotations/snapshots/{snapshot_id}/ia-url (lines 146-158) +# --------------------------------------------------------------------------- + + +class TestSetSnapshotIaUrl: + def test_set_ia_url_success(self, client): + c, session = client + snap = _make_snapshot() + session.get.return_value = snap + + resp = c.patch( + f"/annotations/snapshots/{snap.id}/ia-url", + json={"ia_url": "http://example.com/ia.tsv"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == str(snap.id) + assert data["obo_version"] == snap.obo_version + + def test_set_ia_url_null_clears(self, client): + c, session = client + snap = _make_snapshot(ia_url="http://old.com/ia.tsv") + session.get.return_value = snap + + resp = c.patch( + 
f"/annotations/snapshots/{snap.id}/ia-url", + json={"ia_url": None}, + ) + assert resp.status_code == 200 + + def test_missing_ia_url_key_returns_422(self, client): + c, session = client + snap = _make_snapshot() + + resp = c.patch( + f"/annotations/snapshots/{snap.id}/ia-url", + json={"wrong_key": "value"}, + ) + assert resp.status_code == 422 + + def test_snapshot_not_found_returns_404(self, client): + c, session = client + session.get.return_value = None + + resp = c.patch( + f"/annotations/snapshots/{uuid4()}/ia-url", + json={"ia_url": "http://example.com/ia.tsv"}, + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /annotations/sets/{set_id} (lines 272-289) +# --------------------------------------------------------------------------- + + +class TestDeleteAnnotationSet: + def test_delete_success(self, client): + c, session = client + aset = _make_annotation_set() + session.get.return_value = aset + session.query.return_value.filter.return_value.scalar.return_value = 42 + + resp = c.delete(f"/annotations/sets/{aset.id}") + assert resp.status_code == 200 + data = resp.json() + assert data["deleted"] == str(aset.id) + assert data["annotations_deleted"] == 42 + session.delete.assert_called_once_with(aset) + + def test_delete_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.delete(f"/annotations/sets/{uuid4()}") + assert resp.status_code == 404 + + def test_delete_integrity_error_returns_409(self, client): + c, session = client + aset = _make_annotation_set() + session.get.return_value = aset + session.query.return_value.filter.return_value.scalar.return_value = 10 + session.flush.side_effect = IntegrityError("stmt", "params", Exception("fk")) + + resp = c.delete(f"/annotations/sets/{aset.id}") + assert resp.status_code == 409 + assert "referenced" in resp.json()["detail"].lower() + + +# --------------------------------------------------------------------------- +# POST /annotations/evaluation-sets/generate (lines 367-386) +# --------------------------------------------------------------------------- + + +class TestGenerateEvaluationSet: + def test_success(self, client): + c, session = client + old_id, new_id = str(uuid4()), str(uuid4()) + + # Mock Job creation + def add_side(obj): + from protea.infrastructure.orm.models.job import Job + if isinstance(obj, Job): + obj.id = uuid4() + session.add.side_effect = add_side + + with patch("protea.api.routers.annotations.publish_job"): + resp = c.post( + "/annotations/evaluation-sets/generate", + json={"old_annotation_set_id": old_id, "new_annotation_set_id": new_id}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + def test_invalid_payload_returns_422(self, client): + c, session = client + resp = c.post("/annotations/evaluation-sets/generate", json={}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /annotations/evaluation-sets (lines 394-396) +# --------------------------------------------------------------------------- + + +class TestListEvaluationSets: + def test_returns_list(self, client): + c, session = client + ev = _make_evaluation_set() + session.query.return_value.order_by.return_value.all.return_value = [ev] + + resp = c.get("/annotations/evaluation-sets") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["id"] == str(ev.id) + assert data[0]["stats"] == ev.stats + 
+ def test_empty_list(self, client): + c, session = client + session.query.return_value.order_by.return_value.all.return_value = [] + + resp = c.get("/annotations/evaluation-sets") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# DELETE /annotations/evaluation-sets/{eval_id} (lines 416-434) +# --------------------------------------------------------------------------- + + +class TestDeleteEvaluationSet: + def test_delete_success(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + ev = _make_evaluation_set() + session.get.side_effect = lambda model, id_: ev if id_ == ev.id else None + + # Create a fake result with an artifact directory + result_mock = MagicMock() + result_mock.id = uuid4() + result_dir = tmp_path / str(result_mock.id) + result_dir.mkdir() + (result_dir / "output.tsv").write_text("test") + + session.query.return_value.filter.return_value.all.return_value = [result_mock] + + resp = c.delete(f"/annotations/evaluation-sets/{ev.id}") + assert resp.status_code == 204 + session.delete.assert_called_once_with(ev) + # Artifact directory should be removed + assert not result_dir.exists() + + def test_delete_not_found(self, client_with_artifacts): + c, session, _ = client_with_artifacts + session.get.return_value = None + + resp = c.delete(f"/annotations/evaluation-sets/{uuid4()}") + assert resp.status_code == 404 + + def test_delete_no_artifact_dir(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + ev = _make_evaluation_set() + session.get.side_effect = lambda model, id_: ev if id_ == ev.id else None + session.query.return_value.filter.return_value.all.return_value = [] + + resp = c.delete(f"/annotations/evaluation-sets/{ev.id}") + assert resp.status_code == 204 + + +# --------------------------------------------------------------------------- +# GET /annotations/evaluation-sets/{eval_id} (lines 442-446) +# --------------------------------------------------------------------------- + + +class TestGetEvaluationSet: + def test_success(self, client): + c, session = client + ev = _make_evaluation_set(job_id=uuid4()) + session.get.return_value = ev + + resp = c.get(f"/annotations/evaluation-sets/{ev.id}") + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == str(ev.id) + assert data["job_id"] == str(ev.job_id) + + def test_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/evaluation-sets/{uuid4()}") + assert resp.status_code == 404 + + def test_no_job_id(self, client): + c, session = client + ev = _make_evaluation_set(job_id=None) + session.get.return_value = ev + + resp = c.get(f"/annotations/evaluation-sets/{ev.id}") + assert resp.status_code == 200 + assert resp.json()["job_id"] is None + + +# --------------------------------------------------------------------------- +# _eval_set_or_404 helper (lines 457-460) -- tested indirectly via GT endpoints +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Ground-truth TSV downloads (lines 475-591) +# --------------------------------------------------------------------------- + + +class _EvalData: + """Fake result of compute_evaluation_data.""" + def __init__(self, nk=None, lk=None, pk=None, known=None): + self.nk = nk or {} + self.lk = lk or {} + self.pk = pk or {} + self.known = known or {} + + 
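+# The nk/lk/pk/known attributes mirror what the ground-truth TSV endpoints
+# read off the real compute_evaluation_data result.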
+class TestDownloadGroundTruthNK: + def test_success(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + def get_side(model, id_): + from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet + from protea.infrastructure.orm.models.annotation.annotation_set import AnnotationSet + if model is EvaluationSet: + return ev + if model is AnnotationSet: + return ann_old + return None + session.get.side_effect = get_side + + fake_data = _EvalData(nk={"P12345": {"GO:0003674", "GO:0008150"}}) + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/ground-truth-NK.tsv") + assert resp.status_code == 200 + assert "text/tab-separated-values" in resp.headers["content-type"] + lines = resp.text.strip().split("\n") + assert len(lines) == 2 + assert "P12345" in lines[0] + + def test_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/evaluation-sets/{uuid4()}/ground-truth-NK.tsv") + assert resp.status_code == 404 + + +class TestDownloadGroundTruthLK: + def test_success(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + def get_side(model, id_): + from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet + from protea.infrastructure.orm.models.annotation.annotation_set import AnnotationSet + if model is EvaluationSet: + return ev + if model is AnnotationSet: + return ann_old + return None + session.get.side_effect = get_side + + fake_data = _EvalData(lk={"Q99999": {"GO:0005575"}}) + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/ground-truth-LK.tsv") + assert resp.status_code == 200 + lines = resp.text.strip().split("\n") + assert len(lines) == 1 + assert "Q99999\tGO:0005575" in lines[0] + + +class TestDownloadGroundTruthPK: + def test_success(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + def get_side(model, id_): + from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet + from protea.infrastructure.orm.models.annotation.annotation_set import AnnotationSet + if model is EvaluationSet: + return ev + if model is AnnotationSet: + return ann_old + return None + session.get.side_effect = get_side + + fake_data = _EvalData(pk={"A00001": {"GO:0003674"}}) + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/ground-truth-PK.tsv") + assert resp.status_code == 200 + assert "A00001\tGO:0003674" in resp.text + + +class TestDownloadKnownTerms: + def test_success(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + def get_side(model, id_): + from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet + from protea.infrastructure.orm.models.annotation.annotation_set import AnnotationSet + if model is EvaluationSet: + return ev + if model is AnnotationSet: + return ann_old + return None + session.get.side_effect = get_side + + fake_data = _EvalData(known={"P12345": {"GO:0003674"}, "Q99999": {"GO:0005575"}}) + with patch("protea.api.routers.annotations.compute_evaluation_data", 
return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/known-terms.tsv") + assert resp.status_code == 200 + lines = resp.text.strip().split("\n") + assert len(lines) == 2 + + +# --------------------------------------------------------------------------- +# GET /annotations/evaluation-sets/{eval_id}/delta-proteins.fasta (lines 615-672) +# --------------------------------------------------------------------------- + + +class TestDownloadDeltaFasta: + def _setup_session(self, session, ev, ann_old, fake_data, protein_rows=None): + def get_side(model, id_): + from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet + from protea.infrastructure.orm.models.annotation.annotation_set import AnnotationSet + if model is EvaluationSet: + return ev + if model is AnnotationSet: + return ann_old + return None + session.get.side_effect = get_side + + if protein_rows is not None: + session.query.return_value.join.return_value.filter.return_value.order_by.return_value.all.return_value = protein_rows + + def test_all_category(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + protein = MagicMock() + protein.accession = "P12345" + protein.entry_name = "P12345_HUMAN" + protein.organism = "Homo sapiens" + protein.taxonomy_id = 9606 + seq = MagicMock() + seq.sequence = "ACDEFGHIKLMNPQRST" + + fake_data = _EvalData(nk={"P12345": {"GO:0003674"}}, lk={}) + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[(protein, seq)]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta") + assert resp.status_code == 200 + assert ">P12345" in resp.text + assert "ACDEFGHIKLMNPQRST" in resp.text + assert "(NK)" in resp.text + + def test_nk_category_filter(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + protein = MagicMock() + protein.accession = "P12345" + protein.entry_name = None + protein.organism = None + protein.taxonomy_id = None + seq = MagicMock() + seq.sequence = "ACDEF" + + fake_data = _EvalData(nk={"P12345": {"GO:0003674"}}, lk={"Q99999": {"GO:0005575"}}) + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[(protein, seq)]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta?category=nk") + assert resp.status_code == 200 + assert ">P12345" in resp.text + + def test_empty_delta_returns_empty_fasta(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + fake_data = _EvalData() + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta") + assert resp.status_code == 200 + assert resp.text == "" + + def test_long_sequence_wraps_at_60(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + protein = MagicMock() + protein.accession = "P12345" + protein.entry_name = None + protein.organism = None + protein.taxonomy_id = None + seq = MagicMock() + seq.sequence = "A" * 120 # should wrap to two lines of 60 + + fake_data = _EvalData(nk={"P12345": 
{"GO:0003674"}}) + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[(protein, seq)]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta") + lines = resp.text.strip().split("\n") + # header + 2 sequence lines + assert len(lines) == 3 + assert len(lines[1]) == 60 + assert len(lines[2]) == 60 + + def test_pk_category(self, client): + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + protein = MagicMock() + protein.accession = "X00001" + protein.entry_name = "X_MOUSE" + protein.organism = "Mus musculus" + protein.taxonomy_id = 10090 + seq = MagicMock() + seq.sequence = "MMLLL" + + fake_data = _EvalData(pk={"X00001": {"GO:0005575"}}) + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[(protein, seq)]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta?category=pk") + assert resp.status_code == 200 + assert "(PK)" in resp.text + + def test_all_category_includes_lk(self, client): + """Ensure LK proteins are included when category=all (covers line 632).""" + c, session = client + ev = _make_evaluation_set() + ann_old = _make_annotation_set(snap_id=uuid4()) + + protein = MagicMock() + protein.accession = "Q99999" + protein.entry_name = None + protein.organism = None + protein.taxonomy_id = None + seq = MagicMock() + seq.sequence = "MMMM" + + fake_data = _EvalData(nk={}, lk={"Q99999": {"GO:0005575"}}) + self._setup_session(session, ev, ann_old, fake_data, protein_rows=[(protein, seq)]) + + with patch("protea.api.routers.annotations.compute_evaluation_data", return_value=fake_data): + resp = c.get(f"/annotations/evaluation-sets/{ev.id}/delta-proteins.fasta?category=all") + assert resp.status_code == 200 + assert "(LK)" in resp.text + + +# --------------------------------------------------------------------------- +# POST /annotations/evaluation-sets/{eval_id}/run (lines 698-720) +# --------------------------------------------------------------------------- + + +class TestRunCafaEvaluation: + def test_success(self, client): + c, session = client + eval_id = uuid4() + pred_set_id = str(uuid4()) + ev = _make_evaluation_set(eval_id=eval_id) + session.get.return_value = ev + + def add_side(obj): + from protea.infrastructure.orm.models.job import Job + if isinstance(obj, Job): + obj.id = uuid4() + session.add.side_effect = add_side + + with patch("protea.api.routers.annotations.publish_job"): + resp = c.post( + f"/annotations/evaluation-sets/{eval_id}/run", + json={"prediction_set_id": pred_set_id}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + def test_invalid_payload_returns_422(self, client): + c, session = client + eval_id = uuid4() + + resp = c.post(f"/annotations/evaluation-sets/{eval_id}/run", json={}) + assert resp.status_code == 422 + + def test_evaluation_set_not_found(self, client): + c, session = client + eval_id = uuid4() + pred_set_id = str(uuid4()) + session.get.return_value = None + + with patch("protea.api.routers.annotations.publish_job"): + resp = c.post( + f"/annotations/evaluation-sets/{eval_id}/run", + json={"prediction_set_id": pred_set_id}, + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# GET .../results/{result_id}/metrics.tsv (lines 732-751) +# 
--------------------------------------------------------------------------- + + +class TestDownloadEvaluationMetrics: + def test_success_with_results(self, client): + c, session = client + eval_id = uuid4() + result = _make_evaluation_result( + eval_set_id=eval_id, + results={ + "NK": { + "BPO": {"fmax": 0.42, "precision": 0.5, "recall": 0.35, "tau": 0.3, "coverage": 0.8, "n_proteins": 100}, + "MFO": {"fmax": 0.55, "precision": 0.6, "recall": 0.5, "tau": 0.4, "coverage": 0.9, "n_proteins": 80}, + }, + "LK": {}, + }, + ) + session.get.return_value = result + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}/metrics.tsv") + assert resp.status_code == 200 + assert "text/tab-separated-values" in resp.headers["content-type"] + lines = resp.text.strip().split("\n") + # header + 2 data lines (NK/BPO and NK/MFO) + assert len(lines) == 3 + assert lines[0].startswith("setting") + assert "NK\tBPO" in lines[1] + + def test_result_not_found(self, client): + c, session = client + eval_id = uuid4() + session.get.return_value = None + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{uuid4()}/metrics.tsv") + assert resp.status_code == 404 + + def test_result_wrong_eval_set(self, client): + c, session = client + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=uuid4()) # different eval set + session.get.return_value = result + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}/metrics.tsv") + assert resp.status_code == 404 + + def test_empty_results(self, client): + c, session = client + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=eval_id, results={}) + session.get.return_value = result + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}/metrics.tsv") + assert resp.status_code == 200 + lines = resp.text.strip().split("\n") + assert len(lines) == 1 # header only + + +# --------------------------------------------------------------------------- +# GET .../results/{result_id}/artifacts.zip (lines 768-785) +# --------------------------------------------------------------------------- + + +class TestDownloadEvaluationArtifacts: + def test_success(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=eval_id) + session.get.return_value = result + + # Create artifact directory with files + result_dir = tmp_path / str(result.id) + result_dir.mkdir() + (result_dir / "pr_curve.tsv").write_text("threshold\tprecision\trecall\n0.5\t0.8\t0.6") + (result_dir / "metrics.json").write_text('{"fmax": 0.42}') + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}/artifacts.zip") + assert resp.status_code == 200 + assert "application/zip" in resp.headers["content-type"] + assert len(resp.content) > 0 + + # Verify it's a valid zip + import io, zipfile + with zipfile.ZipFile(io.BytesIO(resp.content)) as zf: + names = zf.namelist() + assert "pr_curve.tsv" in names + assert "metrics.json" in names + + def test_result_not_found(self, client_with_artifacts): + c, session, _ = client_with_artifacts + eval_id = uuid4() + session.get.return_value = None + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{uuid4()}/artifacts.zip") + assert resp.status_code == 404 + + def test_no_artifacts_directory(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=eval_id) + session.get.return_value = result + 
# No directory created for this result + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}/artifacts.zip") + assert resp.status_code == 404 + assert "No artifacts found" in resp.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET .../results (lines 800-809) +# --------------------------------------------------------------------------- + + +class TestListEvaluationResults: + def test_success(self, client): + c, session = client + eval_id = uuid4() + ev = _make_evaluation_set(eval_id=eval_id) + result = _make_evaluation_result(eval_set_id=eval_id, scoring_id=uuid4(), job_id=uuid4()) + + # First call: session.get(EvaluationSet, eval_id) returns ev + session.get.return_value = ev + session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [result] + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["id"] == str(result.id) + + def test_eval_set_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/evaluation-sets/{uuid4()}/results") + assert resp.status_code == 404 + + def test_empty_results(self, client): + c, session = client + eval_id = uuid4() + ev = _make_evaluation_set(eval_id=eval_id) + session.get.return_value = ev + session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + + resp = c.get(f"/annotations/evaluation-sets/{eval_id}/results") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# DELETE .../results/{result_id} (lines 834-845) +# --------------------------------------------------------------------------- + + +class TestDeleteEvaluationResult: + def test_success(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=eval_id) + session.get.return_value = result + + # Create artifact dir + result_dir = tmp_path / str(result.id) + result_dir.mkdir() + (result_dir / "output.tsv").write_text("data") + + resp = c.delete(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}") + assert resp.status_code == 204 + session.delete.assert_called_once_with(result) + assert not result_dir.exists() + + def test_not_found(self, client_with_artifacts): + c, session, _ = client_with_artifacts + eval_id = uuid4() + session.get.return_value = None + + resp = c.delete(f"/annotations/evaluation-sets/{eval_id}/results/{uuid4()}") + assert resp.status_code == 404 + + def test_wrong_eval_set(self, client_with_artifacts): + c, session, _ = client_with_artifacts + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=uuid4()) + session.get.return_value = result + + resp = c.delete(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}") + assert resp.status_code == 404 + + def test_no_artifact_dir(self, client_with_artifacts): + c, session, tmp_path = client_with_artifacts + eval_id = uuid4() + result = _make_evaluation_result(eval_set_id=eval_id) + session.get.return_value = result + + resp = c.delete(f"/annotations/evaluation-sets/{eval_id}/results/{result.id}") + assert resp.status_code == 204 + + +# --------------------------------------------------------------------------- +# GET /annotations/snapshots/{snapshot_id}/subgraph (lines 859-927) +# 
--------------------------------------------------------------------------- + + +class TestGetGoSubgraph: + def _make_go_term(self, db_id, go_id, name="term", aspect="F"): + t = MagicMock() + t.id = db_id + t.go_id = go_id + t.name = name + t.aspect = aspect + t.ontology_snapshot_id = None + return t + + def _make_rel(self, child_id, parent_id, relation_type="is_a"): + r = MagicMock() + r.child_go_term_id = child_id + r.parent_go_term_id = parent_id + r.relation_type = relation_type + r.ontology_snapshot_id = None + return r + + def test_basic_subgraph(self, client): + c, session = client + snap_id = uuid4() + snap = _make_snapshot(snap_id=snap_id) + + seed = self._make_go_term(1, "GO:0003674", "molecular_function") + parent = self._make_go_term(2, "GO:0005488", "binding") + rel = self._make_rel(1, 2, "is_a") + + # session.get for snapshot + session.get.return_value = snap + # session.query(GOTerm).filter(...).all() for seed terms + # session.query(GOTermRelationship).filter(...).all() for rels + # session.query(GOTerm).filter(...).all() for parents + query_mock = session.query.return_value + filter_mock = query_mock.filter.return_value + filter_mock.all.side_effect = [ + [seed], # seed terms query + [rel], # first BFS level relationships + [parent], # parent terms fetch + [], # second BFS level relationships (no more) + ] + + resp = c.get(f"/annotations/snapshots/{snap_id}/subgraph?go_ids=GO:0003674") + assert resp.status_code == 200 + data = resp.json() + assert len(data["nodes"]) == 2 + assert len(data["edges"]) == 1 + # Check that the seed term is marked as is_query + seed_node = [n for n in data["nodes"] if n["go_id"] == "GO:0003674"][0] + assert seed_node["is_query"] is True + parent_node = [n for n in data["nodes"] if n["go_id"] == "GO:0005488"][0] + assert parent_node["is_query"] is False + + def test_snapshot_not_found(self, client): + c, session = client + session.get.return_value = None + + resp = c.get(f"/annotations/snapshots/{uuid4()}/subgraph?go_ids=GO:0003674") + assert resp.status_code == 404 + + def test_no_matching_terms_returns_empty(self, client): + c, session = client + snap = _make_snapshot() + session.get.return_value = snap + session.query.return_value.filter.return_value.all.return_value = [] + + resp = c.get(f"/annotations/snapshots/{snap.id}/subgraph?go_ids=GO:9999999") + assert resp.status_code == 200 + data = resp.json() + assert data == {"nodes": [], "edges": []} + + def test_multiple_go_ids(self, client): + c, session = client + snap = _make_snapshot() + session.get.return_value = snap + + t1 = self._make_go_term(1, "GO:0003674") + t2 = self._make_go_term(2, "GO:0008150") + + query_mock = session.query.return_value + filter_mock = query_mock.filter.return_value + filter_mock.all.side_effect = [ + [t1, t2], # seed terms + [], # no relationships + ] + + resp = c.get(f"/annotations/snapshots/{snap.id}/subgraph?go_ids=GO:0003674,GO:0008150") + assert resp.status_code == 200 + data = resp.json() + assert len(data["nodes"]) == 2 + assert data["edges"] == [] + + def test_bfs_stops_when_frontier_empty(self, client): + """After one BFS level with parents, next level has rels but no new parents -> frontier empty -> break (line 887).""" + c, session = client + snap = _make_snapshot() + session.get.return_value = snap + + seed = self._make_go_term(1, "GO:0003674") + parent = self._make_go_term(2, "GO:0005488") + rel1 = self._make_rel(1, 2, "is_a") + + query_mock = session.query.return_value + filter_mock = query_mock.filter.return_value + filter_mock.all.side_effect = [ 
+ [seed], # seed terms + [rel1], # first BFS: rel from 1->2 + [parent], # fetch parent 2 + [], # second BFS: no rels from frontier {2} + ] + + resp = c.get(f"/annotations/snapshots/{snap.id}/subgraph?go_ids=GO:0003674&depth=5") + assert resp.status_code == 200 + data = resp.json() + assert len(data["nodes"]) == 2 diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index 15925aa..b72d56d 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -1,18 +1,21 @@ """ -Unit tests for BaseWorker. +Unit tests for BaseWorker and StaleJobReaper. Uses a mocked session factory and a fake Operation — no real DB needed. """ from __future__ import annotations -from unittest.mock import MagicMock +import signal +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from protea.core.contracts.operation import OperationResult +from protea.core.contracts.operation import OperationResult, RetryLaterError from protea.core.contracts.registry import OperationRegistry -from protea.infrastructure.orm.models.job import Job, JobStatus +from protea.infrastructure.orm.models.job import Job, JobEvent, JobStatus from protea.workers.base_worker import BaseWorker, WorkerConfig +from protea.workers.stale_job_reaper import StaleJobReaper # --------------------------------------------------------------------------- # Helpers @@ -145,3 +148,654 @@ def test_progress_fields_are_set(self): assert job.progress_current == 5 assert job.progress_total == 10 + + def test_retry_later_uses_adaptive_backoff(self): + """RetryLaterError delay should increase based on previous retry count.""" + job = _make_job() + session = MagicMock() + session.get.return_value = job + # Simulate 2 previous retries + session.query.return_value.filter.return_value.scalar.return_value = 2 + factory = MagicMock(return_value=session) + + registry, _ = _make_registry(raises=RetryLaterError("GPU busy", delay_seconds=30)) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + with pytest.raises(RetryLaterError) as exc_info: + worker.handle_job(job.id) + + # 30 * 2^2 = 120 seconds + assert exc_info.value.delay_seconds == 120 + assert job.status == JobStatus.QUEUED + + def test_retry_backoff_capped_at_600(self): + """Adaptive backoff should be capped at 600 seconds.""" + job = _make_job() + session = MagicMock() + session.get.return_value = job + # Simulate 10 previous retries → 60 * 2^10 = 61440, capped to 600 + session.query.return_value.filter.return_value.scalar.return_value = 10 + factory = MagicMock(return_value=session) + + registry, _ = _make_registry(raises=RetryLaterError("GPU busy", delay_seconds=60)) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + with pytest.raises(RetryLaterError) as exc_info: + worker.handle_job(job.id) + + assert exc_info.value.delay_seconds == 600 + + +# --------------------------------------------------------------------------- +# StaleJobReaper +# --------------------------------------------------------------------------- + +class TestStaleJobReaper: + def test_reaps_stale_running_jobs(self): + """Jobs in RUNNING for longer than timeout should be marked FAILED.""" + stale_job = MagicMock(spec=Job) + stale_job.id = uuid4() + stale_job.status = JobStatus.RUNNING + stale_job.operation = "compute_embeddings" + stale_job.started_at = datetime.now(timezone.utc) - timedelta(hours=2) + + session = MagicMock() + session.query.return_value.filter.return_value.all.return_value 
= [stale_job] + factory = MagicMock(return_value=session) + + reaper = StaleJobReaper(factory, timeout_seconds=3600) + count = reaper._reap() + + assert count == 1 + assert stale_job.status == JobStatus.FAILED + assert stale_job.error_code == "JobTimeout" + session.add.assert_called_once() # JobEvent + session.commit.assert_called_once() + + def test_no_stale_jobs_returns_zero(self): + """When no jobs are stale, reaper does nothing.""" + session = MagicMock() + session.query.return_value.filter.return_value.all.return_value = [] + factory = MagicMock(return_value=session) + + reaper = StaleJobReaper(factory, timeout_seconds=3600) + count = reaper._reap() + + assert count == 0 + session.commit.assert_called_once() + + def test_reaper_handles_db_error_gracefully(self): + """If the DB query fails, reaper raises but does not crash permanently.""" + session = MagicMock() + session.query.side_effect = RuntimeError("DB connection lost") + factory = MagicMock(return_value=session) + + reaper = StaleJobReaper(factory, timeout_seconds=3600) + with pytest.raises(RuntimeError, match="DB connection lost"): + reaper._reap() + session.rollback.assert_called_once() + + def test_reaper_rollback_also_fails(self): + """If rollback itself raises, the exception from _reap still propagates.""" + session = MagicMock() + session.query.side_effect = RuntimeError("DB gone") + session.rollback.side_effect = RuntimeError("rollback failed too") + factory = MagicMock(return_value=session) + + reaper = StaleJobReaper(factory, timeout_seconds=3600) + with pytest.raises(RuntimeError, match="DB gone"): + reaper._reap() + session.rollback.assert_called_once() + session.close.assert_called_once() + + def test_run_registers_signal_handlers(self): + """run() should register SIGINT and SIGTERM handlers.""" + factory = MagicMock() + reaper = StaleJobReaper(factory, timeout_seconds=3600) + # Make _reap set _stop=True so the loop exits after one iteration + reaper._stop = False + call_count = [0] + def fake_reap(): + call_count[0] += 1 + reaper._stop = True + return 0 + reaper._reap = fake_reap + + with patch("protea.workers.stale_job_reaper.signal.signal") as mock_signal, \ + patch("protea.workers.stale_job_reaper.time.sleep"): + reaper.run(interval_seconds=1) + + # Should register both SIGINT and SIGTERM + calls = [c[0] for c in mock_signal.call_args_list] + assert (signal.SIGINT, reaper._handle_stop) in calls + assert (signal.SIGTERM, reaper._handle_stop) in calls + + def test_run_loops_and_stops_on_flag(self): + """run() calls _reap repeatedly until _stop is set.""" + factory = MagicMock() + reaper = StaleJobReaper(factory, timeout_seconds=3600) + reap_count = [0] + + def fake_reap(): + reap_count[0] += 1 + if reap_count[0] >= 3: + reaper._stop = True + return 0 + + reaper._reap = fake_reap + + with patch("protea.workers.stale_job_reaper.signal.signal"), \ + patch("protea.workers.stale_job_reaper.time.sleep"): + reaper.run(interval_seconds=1) + + assert reap_count[0] == 3 + + def test_run_logs_reaped_count(self): + """When _reap returns non-zero, run() logs it.""" + factory = MagicMock() + reaper = StaleJobReaper(factory, timeout_seconds=3600) + + def fake_reap(): + reaper._stop = True + return 5 + + reaper._reap = fake_reap + + with patch("protea.workers.stale_job_reaper.signal.signal"), \ + patch("protea.workers.stale_job_reaper.time.sleep"), \ + patch("protea.workers.stale_job_reaper.logger") as mock_logger: + reaper.run(interval_seconds=1) + + # Should have logged the reaped count + info_calls = [str(c) for c in 
mock_logger.info.call_args_list] + assert any("5" in c for c in info_calls) + + def test_run_catches_reap_exception(self): + """If _reap raises, run() logs the error and continues.""" + factory = MagicMock() + reaper = StaleJobReaper(factory, timeout_seconds=3600) + call_count = [0] + + def failing_reap(): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("transient DB error") + reaper._stop = True + return 0 + + reaper._reap = failing_reap + + with patch("protea.workers.stale_job_reaper.signal.signal"), \ + patch("protea.workers.stale_job_reaper.time.sleep"), \ + patch("protea.workers.stale_job_reaper.logger") as mock_logger: + reaper.run(interval_seconds=1) + + # Should have logged the error but continued + mock_logger.error.assert_called_once() + assert call_count[0] == 2 + + def test_handle_stop_sets_flag(self): + """_handle_stop sets the _stop flag.""" + factory = MagicMock() + reaper = StaleJobReaper(factory, timeout_seconds=3600) + assert reaper._stop is False + reaper._handle_stop(signal.SIGINT, None) + assert reaper._stop is True + + +# --------------------------------------------------------------------------- +# Feature engineering warmup +# --------------------------------------------------------------------------- + +class TestTaxonomyWarmup: + def test_warmup_calls_get_ncbi(self): + from protea.core.feature_engineering import warmup_taxonomy_db + + with patch("protea.core.feature_engineering._get_ncbi") as mock_get, \ + patch("protea.core.feature_engineering._ETE3_AVAILABLE", True): + warmup_taxonomy_db() + mock_get.assert_called_once() + + def test_warmup_skips_when_ete3_unavailable(self): + from protea.core.feature_engineering import warmup_taxonomy_db + + with patch("protea.core.feature_engineering._ETE3_AVAILABLE", False), \ + patch("protea.core.feature_engineering._get_ncbi") as mock_get: + warmup_taxonomy_db() # should not raise + mock_get.assert_not_called() + + +# --------------------------------------------------------------------------- +# BaseWorker — extended coverage +# --------------------------------------------------------------------------- + +class TestBaseWorkerParentCancelled: + """Cover parent_job_id cancellation detection (lines 93-106).""" + + def test_cancelled_parent_cancels_child(self): + """If parent is CANCELLED during claim, child should be CANCELLED too.""" + parent_id = uuid4() + child_job = _make_job(parent_job_id=parent_id) + parent_job = MagicMock(spec=Job) + parent_job.id = parent_id + parent_job.status = JobStatus.CANCELLED + + session = MagicMock() + # session.get returns child_job by default, parent_job when queried by parent_id + def get_side_effect(model, id_val): + if id_val == parent_id: + return parent_job + return child_job + session.get.side_effect = get_side_effect + + factory = MagicMock(return_value=session) + registry, op = _make_registry() + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + worker.handle_job(child_job.id) + + assert child_job.status == JobStatus.CANCELLED + assert child_job.finished_at is not None + op.execute.assert_not_called() + + def test_active_parent_does_not_cancel_child(self): + """If parent is still RUNNING, child should execute normally.""" + parent_id = uuid4() + child_job = _make_job(parent_job_id=parent_id) + parent_job = MagicMock(spec=Job) + parent_job.id = parent_id + parent_job.status = JobStatus.RUNNING + + session = MagicMock() + def get_side_effect(model, id_val): + if id_val == parent_id: + return parent_job + return child_job + 
session.get.side_effect = get_side_effect
+
+        factory = MagicMock(return_value=session)
+        registry, op = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker.handle_job(child_job.id)
+
+        assert child_job.status == JobStatus.SUCCEEDED
+        op.execute.assert_called_once()
+
+
+class TestBaseWorkerUnknownOperation:
+    """Cover unknown operation name — registry.get raises KeyError."""
+
+    def test_unknown_operation_raises_key_error(self):
+        """KeyError from registry.get propagates without being caught by inner handler."""
+        job = _make_job(operation="nonexistent_op")
+        session = MagicMock()
+        session.get.return_value = job
+
+        factory = MagicMock(return_value=session)
+        # Real registry with no operations registered
+        registry = OperationRegistry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        with pytest.raises(KeyError, match="nonexistent_op"):
+            worker.handle_job(job.id)
+
+        # Session should still be closed (finally block)
+        session.close.assert_called()
+
+
+class TestBaseWorkerTwoSessionPattern:
+    """Verify the two-session pattern: claim session commits before execute session."""
+
+    def test_two_sessions_are_created(self):
+        job = _make_job()
+        sessions = []
+
+        def make_session():
+            s = MagicMock()
+            s.get.return_value = job
+            sessions.append(s)
+            return s
+
+        factory = MagicMock(side_effect=make_session)
+        registry, op = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker.handle_job(job.id)
+
+        # Two sessions: claim + execute
+        assert len(sessions) >= 2
+        # Both should have been committed
+        sessions[0].commit.assert_called()
+        sessions[1].commit.assert_called()
+        # Both should have been closed
+        sessions[0].close.assert_called_once()
+        sessions[1].close.assert_called_once()
+
+    def test_claim_session_sets_running_before_execute(self):
+        """First session transitions to RUNNING; second session runs the operation."""
+        job = _make_job()
+        status_log = []
+        call_count = [0]
+
+        def make_session():
+            s = MagicMock()
+            s.get.return_value = job
+            call_count[0] += 1
+            current_call = call_count[0]
+
+            # Record the job status visible at each session's commit.
+            def commit_side_effect():
+                status_log.append((current_call, job.status))
+            s.commit.side_effect = commit_side_effect
+            return s
+
+        factory = MagicMock(side_effect=make_session)
+        registry, _ = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker.handle_job(job.id)
+
+        # Session 1 should commit with RUNNING, session 2 with SUCCEEDED
+        assert status_log[0] == (1, JobStatus.RUNNING)
+        assert status_log[1] == (2, JobStatus.SUCCEEDED)
+
+
+class TestBaseWorkerJobNotFoundOnExecute:
+    """Cover line 88: job is None on the execute session."""
+
+    def test_job_disappears_between_sessions(self):
+        job = _make_job()
+        call_count = [0]
+
+        def make_session():
+            s = MagicMock()
+            call_count[0] += 1
+            if call_count[0] == 1:
+                # Claim session finds the job
+                s.get.return_value = job
+            else:
+                # Execute session: job is gone
+                s.get.return_value = None
+            return s
+
+        factory = MagicMock(side_effect=make_session)
+        registry, op = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker.handle_job(job.id)
+
+        op.execute.assert_not_called()
+
+
+class TestBaseWorkerProgressFromResult:
+    """Cover progress update from OperationResult (lines 139-142)."""
+
+    def test_progress_current_only(self):
+        job = _make_job()
+        session = 
MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult( + result={}, progress_current=42 + )) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + worker.handle_job(job.id) + + assert job.progress_current == 42 + + def test_no_progress_fields_leaves_job_unchanged(self): + job = _make_job() + job.progress_current = None + job.progress_total = None + session = MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult(result={})) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + worker.handle_job(job.id) + + # progress_current/total should not be set if result has None + # (succeeded is set, but progress fields are untouched) + assert job.status == JobStatus.SUCCEEDED + + +class TestBaseWorkerDeferredResult: + """Cover deferred result handling (lines 144-153).""" + + def test_deferred_result_does_not_set_succeeded(self): + job = _make_job() + session = MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult( + result={"dispatched": True}, deferred=True + )) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + worker.handle_job(job.id) + + # Deferred: should NOT transition to SUCCEEDED + assert job.status != JobStatus.SUCCEEDED + # Should remain RUNNING (set in claim phase) + assert job.status == JobStatus.RUNNING + + +class TestBaseWorkerPublishAfterCommit: + """Cover publish_after_commit and publish_operations (lines 169-176).""" + + def test_publish_after_commit_publishes_child_jobs(self): + child_id = uuid4() + job = _make_job() + session = MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult( + result={}, + publish_after_commit=[("protea.jobs", child_id)], + )) + + worker = BaseWorker( + factory, registry, WorkerConfig(worker_name="test"), + amqp_url="amqp://localhost/", + ) + + with patch("protea.workers.base_worker.publish_job") as mock_pub: + worker.handle_job(job.id) + + mock_pub.assert_called_once_with("amqp://localhost/", "protea.jobs", child_id) + + def test_publish_operations_publishes_ephemeral_messages(self): + job = _make_job() + session = MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult( + result={}, + publish_operations=[ + ("protea.embeddings.batch", {"batch_data": [1, 2]}), + ], + )) + + worker = BaseWorker( + factory, registry, WorkerConfig(worker_name="test"), + amqp_url="amqp://localhost/", + ) + + with patch("protea.workers.base_worker.publish_operation") as mock_pub: + worker.handle_job(job.id) + + mock_pub.assert_called_once_with( + "amqp://localhost/", "protea.embeddings.batch", {"batch_data": [1, 2]} + ) + + def test_no_amqp_url_skips_publish(self): + """Without amqp_url, publish_after_commit is silently skipped.""" + child_id = uuid4() + job = _make_job() + session = MagicMock() + session.get.return_value = job + factory = MagicMock(return_value=session) + registry, _ = _make_registry(result=OperationResult( + result={}, + publish_after_commit=[("protea.jobs", child_id)], + )) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + + with patch("protea.workers.base_worker.publish_job") as mock_pub: + worker.handle_job(job.id) + 
+
+        mock_pub.assert_not_called()
+
+
+class TestBaseWorkerEmitProgress:
+    """Cover emit callback writing _progress_current/_progress_total (lines 124-129)."""
+
+    def test_emit_with_progress_fields_updates_job(self):
+        job = _make_job()
+
+        def make_session():
+            s = MagicMock()
+            s.get.return_value = job
+            return s
+
+        factory = MagicMock(side_effect=make_session)
+
+        def _execute(sess, payload, *, emit):
+            emit("progress", "step done", {"_progress_current": 5, "_progress_total": 20}, "info")
+            return OperationResult()
+
+        op = MagicMock()
+        op.name = "ping"
+        op.execute.side_effect = _execute
+        registry = OperationRegistry()
+        registry.register(op)
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker.handle_job(job.id)
+
+        # The emit callback runs in its own session (claim, execute, emit);
+        # the job's progress fields should reflect the emitted values.
+        assert job.progress_current == 5
+        assert job.progress_total == 20
+
+
+class TestBaseWorkerForceFailJob:
+    """Cover _force_fail_job (lines 242-263)."""
+
+    def test_force_fail_on_commit_failure(self):
+        """When execute session commit fails, _force_fail_job is called."""
+        job = _make_job()
+        call_count = [0]
+
+        def make_session():
+            s = MagicMock()
+            s.get.return_value = job
+            call_count[0] += 1
+            current = call_count[0]
+            if current == 2:
+                # Execute session: its first commit (the one recording the
+                # operation failure) raises, forcing the fallback path.
+                commit_count = [0]
+                def commit_side():
+                    commit_count[0] += 1
+                    if commit_count[0] == 1:
+                        raise RuntimeError("DB connection dropped")
+                s.commit.side_effect = commit_side
+            return s
+
+        factory = MagicMock(side_effect=make_session)
+        registry, _ = _make_registry(raises=ValueError("op failed"))
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+
+        with pytest.raises(ValueError, match="op failed"):
+            worker.handle_job(job.id)
+
+        # The fallback session (3rd) should have been created
+        assert call_count[0] >= 3
+
+    def test_force_fail_direct_call(self):
+        """Direct test of _force_fail_job method."""
+        job_id = uuid4()
+        session = MagicMock()
+        factory = MagicMock(return_value=session)
+        registry, _ = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        worker._force_fail_job(job_id, ValueError("original"))
+
+        session.execute.assert_called_once()
+        session.commit.assert_called_once()
+        session.close.assert_called_once()
+
+    def test_force_fail_handles_fallback_failure(self):
+        """If the fallback session also fails, it logs but doesn't crash."""
+        job_id = uuid4()
+        session = MagicMock()
+        session.commit.side_effect = RuntimeError("still broken")
+        factory = MagicMock(return_value=session)
+        registry, _ = _make_registry()
+
+        worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test"))
+        # Should not raise
+        worker._force_fail_job(job_id, ValueError("original"))
+
+        session.close.assert_called_once()
+
+
+class TestBaseWorkerMaybeFailParent:
+    """Cover _maybe_fail_parent (lines 267-302)."""
+
+    def test_all_children_failed_marks_parent_failed(self):
+        """When all children are terminal and none succeeded, parent fails."""
+        parent_id = uuid4()
+        job = _make_job(parent_job_id=parent_id)
+
+        session = MagicMock()
+        session.get.return_value = job
+        # First query: non_terminal count = 0
+        # Second query: succeeded count = 0
+        query_results = [0, 0]
+        call_count = [0]
+
+        def scalar_side():
+            idx = call_count[0]
+            call_count[0] += 1
+            return query_results[idx] if idx < len(query_results) 
else 0 + + session.query.return_value.filter.return_value.scalar.side_effect = scalar_side + factory = MagicMock(return_value=session) + registry, _ = _make_registry(raises=RuntimeError("child failed")) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + with pytest.raises(RuntimeError, match="child failed"): + worker.handle_job(job.id) + + # session.execute should have been called for the sa_update on parent + session.execute.assert_called() + + def test_children_still_running_does_not_fail_parent(self): + """If some children are still running, parent is not failed.""" + parent_id = uuid4() + job = _make_job(parent_job_id=parent_id) + + session = MagicMock() + session.get.return_value = job + # non_terminal count = 3 (children still running) + session.query.return_value.filter.return_value.scalar.return_value = 3 + factory = MagicMock(return_value=session) + registry, _ = _make_registry(raises=RuntimeError("child failed")) + + worker = BaseWorker(factory, registry, WorkerConfig(worker_name="test")) + with pytest.raises(RuntimeError, match="child failed"): + worker.handle_job(job.id) + + # session.execute should NOT have been called for parent update + session.execute.assert_not_called() diff --git a/tests/test_core.py b/tests/test_core.py index b995518..f375a5b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -249,3 +249,341 @@ def test_custom_delay(self): def test_is_exception(self): with pytest.raises(RetryLaterError): raise RetryLaterError("test") + + +# --------------------------------------------------------------------------- +# FetchUniProtMetadataOperation +# --------------------------------------------------------------------------- + +import gzip +from io import BytesIO + +from protea.core.operations.fetch_uniprot_metadata import ( + FetchUniProtMetadataOperation, + FetchUniProtMetadataPayload, +) + + +def _noop_emit(*_): + pass + + +def _make_tsv_content(rows: list[dict[str, str]], compressed: bool = True) -> bytes: + """Build a TSV byte string (optionally gzipped) from a list of dicts.""" + if not rows: + header = "Entry\tReviewed\tEntry Name\tOrganism\tGene Names\tLength" + text = header + "\n" + else: + headers = list(rows[0].keys()) + lines = ["\t".join(headers)] + for row in rows: + lines.append("\t".join(row.get(h, "") for h in headers)) + text = "\n".join(lines) + "\n" + + raw = text.encode("utf-8") + if compressed: + buf = BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as f: + f.write(raw) + return buf.getvalue() + return raw + + +class TestFetchUniProtMetadataPayload: + def test_valid_payload(self): + p = FetchUniProtMetadataPayload(search_criteria="organism_id:9606") + assert p.search_criteria == "organism_id:9606" + assert p.page_size == 500 + + def test_empty_search_criteria_raises(self): + with pytest.raises(Exception): + FetchUniProtMetadataPayload(search_criteria=" ") + + def test_empty_user_agent_raises(self): + with pytest.raises(Exception): + FetchUniProtMetadataPayload(search_criteria="test", user_agent="") + + +class TestFetchUniProtMetadataExecute: + def _make_op(self): + op = FetchUniProtMetadataOperation() + op._http = MagicMock() + return op + + def test_execute_empty_page_continues(self): + """Line 108: when rows is empty, continue (skip store).""" + op = self._make_op() + events = [] + + def emit(event, message, fields, level): + events.append(event) + + # Return one page with no data rows, then stop + resp = MagicMock() + resp.status_code = 200 + resp.headers = {"X-Total-Results": "0"} + resp.content = 
_make_tsv_content([], compressed=True)
+        op._http.get.return_value = resp
+
+        session = MagicMock()
+        payload = {"search_criteria": "organism_id:9606", "page_size": 10}
+
+        result = op.execute(session, payload, emit=emit)
+        assert result.result["rows"] == 0
+        assert result.result["pages"] == 1
+
+    def test_execute_total_limit_truncation(self):
+        """Lines 110-113: when total_limit is set and rows exceed it, truncate."""
+        op = self._make_op()
+
+        # Build 5 rows
+        rows = []
+        for i in range(5):
+            row = {"Entry": f"P0000{i}", "Reviewed": "reviewed"}
+            # Add all FIELD_MAP headers as empty
+            for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+                row[header] = ""
+            row["Entry Name"] = ""
+            row["Organism"] = ""
+            row["Gene Names"] = ""
+            row["Length"] = ""
+            rows.append(row)
+
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.headers = {"X-Total-Results": "5"}
+        resp.content = _make_tsv_content(rows, compressed=True)
+        op._http.get.return_value = resp
+
+        session = MagicMock()
+        session.query.return_value.filter.return_value.all.return_value = []
+
+        payload = {
+            "search_criteria": "organism_id:9606",
+            "page_size": 10,
+            "total_limit": 3,
+        }
+
+        result = op.execute(session, payload, emit=_noop_emit)
+        # Should only process 3 rows despite page having 5
+        assert result.result["rows"] == 3
+
+    def test_execute_total_limit_zero_after_truncation(self):
+        """Line 113: if truncation results in empty rows, break."""
+        op = self._make_op()
+
+        rows = [{"Entry": "P00001"}]
+        for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+            rows[0][header] = ""
+        rows[0].update({"Reviewed": "", "Entry Name": "", "Organism": "", "Gene Names": "", "Length": ""})
+
+        # First page returns 1 row, second page returns 1 row
+        resp1 = MagicMock()
+        resp1.status_code = 200
+        # Placeholder next-page URL: only the rel="next" marker matters here,
+        # because _http.get is mocked and never dereferences the URL.
+        resp1.headers = {
+            "X-Total-Results": "2",
+            "link": '<https://rest.uniprot.org/uniprotkb/search?cursor=next>; rel="next"',
+        }
+        resp1.content = _make_tsv_content(rows, compressed=True)
+
+        resp2 = MagicMock()
+        resp2.status_code = 200
+        resp2.headers = {"X-Total-Results": "2"}
+        rows2 = [{"Entry": "P00002"}]
+        for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+            rows2[0][header] = ""
+        rows2[0].update({"Reviewed": "", "Entry Name": "", "Organism": "", "Gene Names": "", "Length": ""})
+        resp2.content = _make_tsv_content(rows2, compressed=True)
+
+        op._http.get.side_effect = [resp1, resp2]
+
+        session = MagicMock()
+        session.query.return_value.filter.return_value.all.return_value = []
+
+        payload = {
+            "search_criteria": "organism_id:9606",
+            "page_size": 1,
+            "total_limit": 1,
+        }
+
+        result = op.execute(session, payload, emit=_noop_emit)
+        # Should stop after first page (total_limit=1, first page gives 1 row)
+        assert result.result["rows"] == 1
+
+    def test_x_total_results_none_on_invalid_header(self):
+        """Line 227: X-Total-Results header with invalid value."""
+        op = self._make_op()
+
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.headers = {"X-Total-Results": "not-a-number"}
+        resp.content = _make_tsv_content([], compressed=True)
+        op._http.get.return_value = resp
+
+        session = MagicMock()
+        payload = {"search_criteria": "test", "page_size": 10}
+
+        result = op.execute(session, payload, emit=_noop_emit)
+        assert op._total_results is None
+
+    def test_decode_response_uncompressed(self):
+        """Lines 241-242: uncompressed response decoding."""
+        op = self._make_op()
+        resp = MagicMock()
+        resp.content = b"Entry\tReviewed\nP00001\treviewed\n"
+        text = op._decode_response(resp, compressed=False)
+        assert "P00001" in text
+
+    def test_store_rows_empty_accession_skipped(self):
+        """Line 275: rows with empty Entry are skipped."""
+        op = self._make_op()
+        session = MagicMock()
+        session.query.return_value.filter.return_value.all.return_value = []
+
+        p = FetchUniProtMetadataPayload(
+            search_criteria="test",
+            update_protein_core=False,
+        )
+
+        rows = [{"Entry": "", "Absorption": "test"}]
+        for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+            if header not in rows[0]:
+                rows[0][header] = ""
+
+        touched, upserted = op._store_rows(session, rows, p, _noop_emit)
+        assert touched == 0
+        assert upserted == 0
+
+    def test_store_rows_update_protein_core_fields(self):
+        """Lines 296-328: update_protein_core fills in missing fields on Protein."""
+        op = self._make_op()
+        session = MagicMock()
+
+        # No existing metadata
+        session.query.return_value.filter.return_value.all.return_value = []
+
+        # Create a mock protein with all None fields
+        protein = MagicMock()
+        protein.accession = "P12345"
+        protein.reviewed = None
+        protein.entry_name = None
+        protein.organism = None
+        protein.gene_name = None
+        protein.length = None
+
+        # Second query().filter().all() returns proteins
+        call_count = [0]
+        def query_side_effect(*args):
+            result = MagicMock()
+            call_count[0] += 1
+            if call_count[0] <= 1:
+                # First call: metadata lookup
+                result.filter.return_value.all.return_value = []
+            else:
+                # Second call: protein lookup
+                result.filter.return_value.all.return_value = [protein]
+            return result
+        session.query.side_effect = query_side_effect
+
+        p = FetchUniProtMetadataPayload(
+            search_criteria="test",
+            update_protein_core=True,
+        )
+
+        row = {"Entry": "P12345", "Reviewed": "reviewed", "Entry Name": "TEST_HUMAN",
+               "Organism": "Homo sapiens", "Gene Names": "TEST GENE2", "Length": "500"}
+        for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+            row.setdefault(header, "")
+
+        touched, upserted = op._store_rows(session, [row], p, _noop_emit)
+        assert protein.reviewed is True
+        assert protein.entry_name == "TEST_HUMAN"
+        assert protein.organism == "Homo sapiens"
+        assert protein.gene_name == "TEST"
+        assert protein.length == 500
+        assert touched == 1
+
+    def test_store_rows_unreviewed_protein(self):
+        """Lines 303-305: reviewed == 'unreviewed' sets pr.reviewed = False."""
+        op = self._make_op()
+        session = MagicMock()
+
+        protein = MagicMock()
+        protein.accession = "Q99999"
+        protein.reviewed = None
+        protein.entry_name = None
+        protein.organism = None
+        protein.gene_name = None
+        protein.length = None
+
+        call_count = [0]
+        def query_side_effect(*args):
+            result = MagicMock()
+            call_count[0] += 1
+            if call_count[0] <= 1:
+                result.filter.return_value.all.return_value = []
+            else:
+                result.filter.return_value.all.return_value = [protein]
+            return result
+        session.query.side_effect = query_side_effect
+
+        p = FetchUniProtMetadataPayload(
+            search_criteria="test",
+            update_protein_core=True,
+        )
+
+        row = {"Entry": "Q99999", "Reviewed": "unreviewed"}
+        for header in FetchUniProtMetadataOperation.FIELD_MAP.values():
+            row.setdefault(header, "")
+        row.setdefault("Entry Name", "")
+        row.setdefault("Organism", "")
+        row.setdefault("Gene Names", "")
+        row.setdefault("Length", "")
+
+        touched, _ = op._store_rows(session, [row], p, _noop_emit)
+        assert protein.reviewed is False
+        assert touched == 1
+
+    def test_store_rows_protein_not_in_db(self):
+        """Lines 294-295: protein not found in protein_map, no core update."""
+        op = self._make_op()
+        session = MagicMock()
+
+        call_count = [0]
+        def query_side_effect(*args):
+            result = 
MagicMock() + call_count[0] += 1 + if call_count[0] <= 1: + result.filter.return_value.all.return_value = [] + else: + result.filter.return_value.all.return_value = [] # No proteins + return result + session.query.side_effect = query_side_effect + + p = FetchUniProtMetadataPayload( + search_criteria="test", + update_protein_core=True, + ) + + row = {"Entry": "UNKNOWN1", "Reviewed": "reviewed"} + for header in FetchUniProtMetadataOperation.FIELD_MAP.values(): + row.setdefault(header, "") + row.setdefault("Entry Name", "") + row.setdefault("Organism", "") + row.setdefault("Gene Names", "") + row.setdefault("Length", "") + + touched, upserted = op._store_rows(session, [row], p, _noop_emit) + assert touched == 0 + # Still upserted metadata + assert upserted == 1 + + def test_load_existing_metadata_chunks(self): + """Line 346: _load_existing_metadata returns existing metadata by canonical.""" + op = self._make_op() + session = MagicMock() + + m1 = MagicMock() + m1.canonical_accession = "P12345" + session.query.return_value.filter.return_value.all.return_value = [m1] + + result = op._load_existing_metadata(session, ["P12345"], chunk_size=10) + assert "P12345" in result + assert result["P12345"] is m1 diff --git a/tests/test_embeddings_router.py b/tests/test_embeddings_router.py index 89522e9..efed185 100644 --- a/tests/test_embeddings_router.py +++ b/tests/test_embeddings_router.py @@ -342,3 +342,677 @@ def test_multiple_rows_all_included(self, client, session): lines = resp.text.splitlines() assert len(lines) == 6 # 1 header + 5 data + + def test_filter_by_accession(self, client, session): + """The accession query param should filter predictions.""" + set_id = uuid4() + pred = _make_go_prediction("P99999") + gt = _make_go_term() + resp = self._get(client, session, set_id, [(pred, gt)], accession="P99999") + assert resp.status_code == 200 + lines = resp.text.splitlines() + assert len(lines) == 2 + assert "P99999" in lines[1] + + def test_filter_by_aspect(self, client, session): + """The aspect query param should filter predictions.""" + set_id = uuid4() + pred = _make_go_prediction() + gt = _make_go_term(aspect="P") + resp = self._get(client, session, set_id, [(pred, gt)], aspect="P") + assert resp.status_code == 200 + + def test_filter_by_max_distance(self, client, session): + """The max_distance query param should filter predictions.""" + set_id = uuid4() + pred = _make_go_prediction(distance=0.05) + gt = _make_go_term() + resp = self._get(client, session, set_id, [(pred, gt)], max_distance=0.5) + assert resp.status_code == 200 + + def test_alignment_fields_formatted(self, client, session): + """Non-null alignment fields should be formatted with _fmt.""" + set_id = uuid4() + pred = _make_go_prediction() + pred.identity_nw = 0.95123456 + pred.similarity_nw = 0.88 + gt = _make_go_term() + resp = self._get(client, session, set_id, [(pred, gt)]) + lines = resp.text.splitlines() + row = lines[1].split("\t") + header = lines[0].split("\t") + identity_nw_idx = header.index("identity_nw") + assert row[identity_nw_idx] == "0.951235" + + +# --------------------------------------------------------------------------- +# _fmt helper +# --------------------------------------------------------------------------- + +class TestFmt: + def test_none_returns_empty(self): + from protea.api.routers.embeddings import _fmt + assert _fmt(None) == "" + + def test_float_returns_formatted(self): + from protea.api.routers.embeddings import _fmt + assert _fmt(0.123456789) == "0.123457" + + def 
test_zero_returns_formatted(self): + from protea.api.routers.embeddings import _fmt + assert _fmt(0.0) == "0" + + +# --------------------------------------------------------------------------- +# get_session_factory / get_amqp_url — RuntimeError when not set +# --------------------------------------------------------------------------- + +class TestDependencyGuards: + def test_session_factory_missing_raises(self): + from protea.api.routers.embeddings import get_session_factory + req = MagicMock() + req.app.state = MagicMock(spec=[]) # no session_factory attr + with pytest.raises(RuntimeError, match="session_factory"): + get_session_factory(req) + + def test_amqp_url_missing_raises(self): + from protea.api.routers.embeddings import get_amqp_url + req = MagicMock() + req.app.state = MagicMock(spec=[]) # no amqp_url attr + with pytest.raises(RuntimeError, match="amqp_url"): + get_amqp_url(req) + + +# --------------------------------------------------------------------------- +# Additional validation edge cases +# --------------------------------------------------------------------------- + +class TestValidationEdgeCases: + def test_normalize_residues_non_bool_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "normalize_residues": "yes"} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("normalize_residues" in str(e) for e in resp.json()["detail"]) + + def test_normalize_non_bool_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "normalize": "yes"} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("normalize" in str(e) for e in resp.json()["detail"]) + + def test_use_chunking_non_bool_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "use_chunking": "yes"} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("use_chunking" in str(e) for e in resp.json()["detail"]) + + def test_chunk_size_non_positive_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "chunk_size": -1} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("chunk_size" in str(e) for e in resp.json()["detail"]) + + def test_chunk_overlap_negative_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "chunk_overlap": -1} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("chunk_overlap" in str(e) for e in resp.json()["detail"]) + + def test_description_non_string_returns_422(self, client, session): + body = {**_VALID_CONFIG_BODY, "description": 42} + resp = client.post("/embeddings/configs", json=body) + assert resp.status_code == 422 + assert any("description" in str(e) for e in resp.json()["detail"]) + + +# --------------------------------------------------------------------------- +# GET /embeddings/configs/{config_id} +# --------------------------------------------------------------------------- + +class TestGetEmbeddingConfig: + def test_returns_config(self, client, session): + cfg = _make_config() + config_id = cfg.id + session.get.return_value = cfg + # Mock the embedding count query + session.query.return_value.filter.return_value.scalar.return_value = 42 + + resp = client.get(f"/embeddings/configs/{config_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == str(config_id) + assert data["model_name"] == "facebook/esm2_t33_650M_UR50D" + assert data["embedding_count"] 
== 42 + + def test_not_found_returns_404(self, client, session): + session.get.return_value = None + resp = client.get(f"/embeddings/configs/{uuid4()}") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /embeddings/configs/{config_id} — with prediction sets +# --------------------------------------------------------------------------- + +class TestDeleteEmbeddingConfigCascade: + def test_delete_with_prediction_sets(self, client, session): + cfg = _make_config() + config_id = cfg.id + session.get.return_value = cfg + + pred_set_id = uuid4() + # query(PredictionSet.id).filter(...).all() returns [(pred_set_id,)] + session.query.return_value.filter.return_value.all.return_value = [(pred_set_id,)] + # Bulk deletes return counts + session.query.return_value.filter.return_value.delete.return_value = 10 + + resp = client.delete(f"/embeddings/configs/{config_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["deleted"] == str(config_id) + + +# --------------------------------------------------------------------------- +# POST /embeddings/predict +# --------------------------------------------------------------------------- + +class TestPredictGoTerms: + def _make_predict_app(self, session): + factory = MagicMock() + app = _make_app(factory) + return app + + def test_predict_success(self, session): + app = self._make_predict_app(session) + + config_id = uuid4() + ann_id = uuid4() + onto_id = uuid4() + + # session.get returns objects for all three lookups + session.get.return_value = MagicMock() + # session.add captures Job and JobEvent + job_mock = MagicMock() + job_mock.id = 42 + added = [] + + def _fake_add(obj): + added.append(obj) + # If it's a Job, set its id + if hasattr(obj, 'operation'): + obj.id = 42 + + session.add.side_effect = _fake_add + session.flush = MagicMock() + + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + with patch("protea.api.routers.embeddings.publish_job") as mock_pub: + client = TestClient(app, raise_server_exceptions=True) + resp = client.post("/embeddings/predict", json={ + "embedding_config_id": str(config_id), + "annotation_set_id": str(ann_id), + "ontology_snapshot_id": str(onto_id), + }) + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + mock_pub.assert_called_once() + + def test_predict_invalid_uuid_returns_422(self, session): + app = self._make_predict_app(session) + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + client = TestClient(app, raise_server_exceptions=True) + resp = client.post("/embeddings/predict", json={ + "embedding_config_id": "not-a-uuid", + "annotation_set_id": str(uuid4()), + "ontology_snapshot_id": str(uuid4()), + }) + assert resp.status_code == 422 + + def test_predict_config_not_found_returns_404(self, session): + app = self._make_predict_app(session) + # session.get returns None for EmbeddingConfig + session.get.return_value = None + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + client = TestClient(app, raise_server_exceptions=True) + resp = client.post("/embeddings/predict", json={ + "embedding_config_id": str(uuid4()), + "annotation_set_id": str(uuid4()), + "ontology_snapshot_id": str(uuid4()), + }) + assert resp.status_code == 404 + + def test_predict_annotation_set_not_found_returns_404(self, session): + app = 
self._make_predict_app(session) + + def _get_side(model_cls, id_val): + from protea.infrastructure.orm.models.embedding.embedding_config import EmbeddingConfig + if model_cls is EmbeddingConfig: + return MagicMock() + return None + + session.get.side_effect = _get_side + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + client = TestClient(app, raise_server_exceptions=True) + resp = client.post("/embeddings/predict", json={ + "embedding_config_id": str(uuid4()), + "annotation_set_id": str(uuid4()), + "ontology_snapshot_id": str(uuid4()), + }) + assert resp.status_code == 404 + + def test_predict_ontology_not_found_returns_404(self, session): + app = self._make_predict_app(session) + + call_count = [0] + def _get_side(model_cls, id_val): + call_count[0] += 1 + from protea.infrastructure.orm.models.annotation.ontology_snapshot import OntologySnapshot + if model_cls is OntologySnapshot: + return None + return MagicMock() + + session.get.side_effect = _get_side + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + client = TestClient(app, raise_server_exceptions=True) + resp = client.post("/embeddings/predict", json={ + "embedding_config_id": str(uuid4()), + "annotation_set_id": str(uuid4()), + "ontology_snapshot_id": str(uuid4()), + }) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# GET /embeddings/prediction-sets +# --------------------------------------------------------------------------- + +class TestListPredictionSets: + def test_returns_list(self, client, session): + ps = _make_prediction_set() + ec = _make_config() + ann = MagicMock() + ann.source = "goa" + ann.source_version = "2024-01" + snap = MagicMock() + snap.obo_version = "2024-01-01" + + # The join chain returns list of tuples + session.query.return_value.join.return_value.join.return_value.join.return_value \ + .order_by.return_value.limit.return_value.all.return_value = [(ps, ec, ann, snap)] + # prediction count per set + session.query.return_value.filter.return_value.scalar.return_value = 100 + + resp = client.get("/embeddings/prediction-sets") + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["id"] == str(ps.id) + assert data[0]["embedding_config_name"] == ec.model_name + assert data[0]["annotation_set_label"] == "goa 2024-01" + assert data[0]["ontology_snapshot_version"] == "2024-01-01" + assert data[0]["prediction_count"] == 100 + + def test_annotation_set_without_version(self, client, session): + ps = _make_prediction_set() + ec = _make_config() + ann = MagicMock() + ann.source = "goa" + ann.source_version = None + snap = MagicMock() + snap.obo_version = "2024-01-01" + + session.query.return_value.join.return_value.join.return_value.join.return_value \ + .order_by.return_value.limit.return_value.all.return_value = [(ps, ec, ann, snap)] + session.query.return_value.filter.return_value.scalar.return_value = 0 + + resp = client.get("/embeddings/prediction-sets") + assert resp.status_code == 200 + assert resp.json()[0]["annotation_set_label"] == "goa" + + def test_empty_list(self, client, session): + session.query.return_value.join.return_value.join.return_value.join.return_value \ + .order_by.return_value.limit.return_value.all.return_value = [] + resp = client.get("/embeddings/prediction-sets") + assert resp.status_code == 200 + assert resp.json() == [] + + +# 
--------------------------------------------------------------------------- +# GET /embeddings/prediction-sets/{set_id} +# --------------------------------------------------------------------------- + +class TestGetPredictionSet: + def test_returns_details(self, client, session): + ps = _make_prediction_set() + ps_id = ps.id + session.get.return_value = ps + session.query.return_value.filter.return_value.scalar.return_value = 50 + session.query.return_value.filter.return_value.group_by.return_value.all.return_value = [ + ("P12345", 30), ("Q67890", 20), + ] + + resp = client.get(f"/embeddings/prediction-sets/{ps_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == str(ps_id) + assert data["prediction_count"] == 50 + assert data["per_protein_counts"]["P12345"] == 30 + assert data["per_protein_counts"]["Q67890"] == 20 + + def test_not_found_returns_404(self, client, session): + session.get.return_value = None + resp = client.get(f"/embeddings/prediction-sets/{uuid4()}") + assert resp.status_code == 404 + + def test_with_query_set_id(self, client, session): + ps = _make_prediction_set() + ps.query_set_id = uuid4() + session.get.return_value = ps + session.query.return_value.filter.return_value.scalar.return_value = 0 + session.query.return_value.filter.return_value.group_by.return_value.all.return_value = [] + + resp = client.get(f"/embeddings/prediction-sets/{ps.id}") + assert resp.status_code == 200 + assert resp.json()["query_set_id"] == str(ps.query_set_id) + + +# --------------------------------------------------------------------------- +# GET /embeddings/prediction-sets/{set_id}/proteins +# --------------------------------------------------------------------------- + +class TestListPredictionSetProteins: + def _setup_proteins_mocks(self, session, ps, rows_data): + """Set up the complex mock chain for the proteins endpoint.""" + # We need to carefully control the mock chain. + # The endpoint does multiple session.query(...) calls with different args. + # Use a side_effect on session.query to return different mocks per call. 
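+        # Assumed call order (a sketch, not a contract): the handler issues four
+        # session.query(...) calls in sequence (main paginated aggregate, Protein
+        # lookup, annotation counts, prediction-match counts), and we hand back
+        # one pre-wired mock per call in that order via the `queries` list below.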
+ call_idx = [0] + main_q = MagicMock() + main_q.filter.return_value = main_q + main_q.group_by.return_value = main_q + main_q.count.return_value = len(rows_data) + main_q.order_by.return_value = main_q + main_q.offset.return_value = main_q + main_q.limit.return_value = main_q + main_q.all.return_value = rows_data + + prot_q = MagicMock() + prot_mock = MagicMock() + prot_mock.accession = rows_data[0][0] if rows_data else "X" + prot_q.filter.return_value = prot_q + prot_q.all.return_value = [prot_mock] if rows_data else [] + + ann_q = MagicMock() + ann_q.filter.return_value = ann_q + ann_q.group_by.return_value = ann_q + ann_q.all.return_value = [(rows_data[0][0], 5)] if rows_data else [] + + match_q = MagicMock() + match_q.join.return_value = match_q + match_q.filter.return_value = match_q + match_q.group_by.return_value = match_q + match_q.all.return_value = [(rows_data[0][0], 3)] if rows_data else [] + + queries = [main_q, prot_q, ann_q, match_q] + + def _query_side(*args, **kwargs): + idx = call_idx[0] + call_idx[0] += 1 + if idx < len(queries): + return queries[idx] + return MagicMock() + + session.query.side_effect = _query_side + + def test_returns_paginated_proteins(self, client, session): + ps = _make_prediction_set() + ps_id = ps.id + session.get.return_value = ps + self._setup_proteins_mocks(session, ps, [("P12345", 10, 0.05)]) + + resp = client.get(f"/embeddings/prediction-sets/{ps_id}/proteins") + assert resp.status_code == 200 + data = resp.json() + assert "total" in data + assert "items" in data + assert len(data["items"]) == 1 + item = data["items"][0] + assert item["accession"] == "P12345" + assert item["go_count"] == 10 + assert item["in_db"] is True + + def test_not_found_returns_404(self, client, session): + session.get.return_value = None + resp = client.get(f"/embeddings/prediction-sets/{uuid4()}/proteins") + assert resp.status_code == 404 + + def test_search_filter(self, client, session): + ps = _make_prediction_set() + session.get.return_value = ps + + call_idx = [0] + main_q = MagicMock() + main_q.filter.return_value = main_q + main_q.group_by.return_value = main_q + main_q.count.return_value = 0 + main_q.order_by.return_value = main_q + main_q.offset.return_value = main_q + main_q.limit.return_value = main_q + main_q.all.return_value = [] + + prot_q = MagicMock() + prot_q.filter.return_value = prot_q + prot_q.all.return_value = [] + + queries = [main_q, prot_q] + + def _query_side(*args, **kwargs): + idx = call_idx[0] + call_idx[0] += 1 + if idx < len(queries): + return queries[idx] + return MagicMock() + + session.query.side_effect = _query_side + + resp = client.get( + f"/embeddings/prediction-sets/{ps.id}/proteins", + params={"search": "P123"}, + ) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# GET /embeddings/prediction-sets/{set_id}/proteins/{accession} +# --------------------------------------------------------------------------- + +class TestGetProteinPredictions: + def test_returns_predictions(self, client, session): + ps = _make_prediction_set() + ps_id = ps.id + session.get.return_value = ps + + pred = _make_go_prediction("P12345", distance=0.1) + gt = _make_go_term("GO:0003824", "catalytic activity", "F") + + session.query.return_value.join.return_value.filter.return_value \ + .order_by.return_value.all.return_value = [(pred, gt)] + + resp = client.get(f"/embeddings/prediction-sets/{ps_id}/proteins/P12345") + assert resp.status_code == 200 + data = resp.json() + assert 
isinstance(data, list) + assert len(data) == 1 + assert data[0]["go_id"] == "GO:0003824" + assert data[0]["name"] == "catalytic activity" + assert data[0]["aspect"] == "F" + assert data[0]["distance"] == pytest.approx(0.1, abs=1e-4) + assert data[0]["ref_protein_accession"] == "QREF01" + # Alignment fields should be None + assert data[0]["identity_nw"] is None + assert data[0]["taxonomic_relation"] is None + + def test_not_found_returns_404(self, client, session): + session.get.return_value = None + resp = client.get(f"/embeddings/prediction-sets/{uuid4()}/proteins/P12345") + assert resp.status_code == 404 + + def test_empty_predictions_returns_empty_list(self, client, session): + ps = _make_prediction_set() + session.get.return_value = ps + session.query.return_value.join.return_value.filter.return_value \ + .order_by.return_value.all.return_value = [] + resp = client.get(f"/embeddings/prediction-sets/{ps.id}/proteins/UNKNOWN") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# GET /embeddings/prediction-sets/{set_id}/go-terms +# --------------------------------------------------------------------------- + +class TestGoTermDistribution: + def test_returns_distribution(self, client, session): + ps = _make_prediction_set() + ps_id = ps.id + session.get.return_value = ps + + # Top terms query + session.query.return_value.join.return_value.filter.return_value \ + .group_by.return_value.order_by.return_value.limit.return_value \ + .all.return_value = [ + ("GO:0003824", "catalytic activity", "F", 50), + ("GO:0005515", "protein binding", "F", 30), + ("GO:0008150", "biological_process", "P", 20), + ] + + # Aspect counts query + session.query.return_value.join.return_value.filter.return_value \ + .group_by.return_value.all.return_value = [ + ("F", 80), ("P", 20), + ] + + resp = client.get(f"/embeddings/prediction-sets/{ps_id}/go-terms") + assert resp.status_code == 200 + data = resp.json() + assert "by_aspect" in data + assert "aspect_totals" in data + assert "top_terms" in data + + def test_not_found_returns_404(self, client, session): + session.get.return_value = None + resp = client.get(f"/embeddings/prediction-sets/{uuid4()}/go-terms") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# GET /embeddings/prediction-sets/{set_id}/predictions-cafa.tsv +# --------------------------------------------------------------------------- + +class TestDownloadPredictionsCafa: + def _get_cafa(self, client, session, set_id, rows, **params): + ps = _make_prediction_set(set_id) + session.get.return_value = ps + + q = MagicMock() + q.filter.return_value = q + q.order_by.return_value = q + q.yield_per.return_value = iter(rows) + session.query.return_value.join.return_value.filter.return_value = q + + return client.get( + f"/embeddings/prediction-sets/{set_id}/predictions-cafa.tsv", + params=params, + ) + + def test_returns_cafa_format(self, client, session): + set_id = uuid4() + pred = _make_go_prediction("P12345", distance=0.3) + gt = _make_go_term("GO:0003824") + resp = self._get_cafa(client, session, set_id, [(pred, gt)]) + assert resp.status_code == 200 + assert "tab-separated" in resp.headers["content-type"] + lines = resp.text.splitlines() + assert len(lines) == 1 + parts = lines[0].split("\t") + assert parts[0] == "P12345" + assert parts[1] == "GO:0003824" + # score = max(0, 1 - 0.3) = 0.7 + assert float(parts[2]) == pytest.approx(0.7, abs=1e-3) + 
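+    # The dedup test below assumes the endpoint keeps the best (lowest-distance)
+    # prediction per (protein, GO term) pair, i.e. roughly:
+    #     key = (accession, go_id)
+    #     best[key] = max(best.get(key, 0.0), 1.0 - distance)
+    # This is a sketch of the assumed behavior, not the router's actual code.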
+ def test_cafa_deduplicates_go_terms(self, client, session): + """Same (protein, GO term) pair should appear only once (best score).""" + set_id = uuid4() + pred1 = _make_go_prediction("P12345", distance=0.2) + pred2 = _make_go_prediction("P12345", distance=0.5) + gt = _make_go_term("GO:0003824") + # Both have the same protein + GO id + resp = self._get_cafa(client, session, set_id, [(pred1, gt), (pred2, gt)]) + assert resp.status_code == 200 + lines = resp.text.splitlines() + assert len(lines) == 1 # deduplicated + + def test_cafa_not_found_returns_404(self, client, session): + session.get.return_value = None + with patch("protea.api.routers.embeddings.session_scope", side_effect=lambda _: _mock_scope(session)): + resp = client.get(f"/embeddings/prediction-sets/{uuid4()}/predictions-cafa.tsv") + assert resp.status_code == 404 + + def test_cafa_content_disposition(self, client, session): + set_id = uuid4() + resp = self._get_cafa(client, session, set_id, []) + disposition = resp.headers.get("content-disposition", "") + assert "attachment" in disposition + assert "cafa" in disposition + + def test_cafa_filter_by_aspect(self, client, session): + set_id = uuid4() + pred = _make_go_prediction("P12345", distance=0.1) + gt = _make_go_term("GO:0003824", aspect="F") + resp = self._get_cafa(client, session, set_id, [(pred, gt)], aspect="F") + assert resp.status_code == 200 + + def test_cafa_filter_by_max_distance(self, client, session): + set_id = uuid4() + pred = _make_go_prediction("P12345", distance=0.05) + gt = _make_go_term("GO:0003824") + resp = self._get_cafa(client, session, set_id, [(pred, gt)], max_distance=0.5) + assert resp.status_code == 200 + + def test_cafa_score_clamps_at_zero(self, client, session): + """When distance > 1.0 the score should be 0.0, not negative.""" + set_id = uuid4() + pred = _make_go_prediction("P12345", distance=2.5) + gt = _make_go_term("GO:0003824") + resp = self._get_cafa(client, session, set_id, [(pred, gt)]) + lines = resp.text.splitlines() + assert len(lines) == 1 + score = float(lines[0].split("\t")[2]) + assert score == 0.0 + + +# --------------------------------------------------------------------------- +# DELETE /embeddings/prediction-sets/{set_id} +# --------------------------------------------------------------------------- + +class TestDeletePredictionSet: + def test_delete_existing_returns_200(self, client, session): + ps = _make_prediction_set() + ps_id = ps.id + session.get.return_value = ps + session.query.return_value.filter.return_value.delete.return_value = 25 + + resp = client.delete(f"/embeddings/prediction-sets/{ps_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["deleted"] == str(ps_id) + assert data["predictions_deleted"] == 25 + session.delete.assert_called_once_with(ps) + + def test_delete_nonexistent_returns_404(self, client, session): + session.get.return_value = None + resp = client.delete(f"/embeddings/prediction-sets/{uuid4()}") + assert resp.status_code == 404 diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index f88a759..cb9a3bd 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -1,7 +1,18 @@ -"""Tests for protea.core.evaluation — pure-Python components.""" +"""Tests for protea.core.evaluation — pure-Python components + mocked DB tests.""" +import uuid +from unittest.mock import MagicMock, patch + import pytest -from protea.core.evaluation import EvaluationData, _get_descendants +from protea.core.evaluation import ( + EvaluationData, + _build_negative_keys, + 
_get_descendants, + _load_children_map, + _load_experimental_annotations_by_ns, + _load_go_maps, + compute_evaluation_data, +) # --------------------------------------------------------------------------- @@ -141,3 +152,416 @@ def test_leaf_node(self): children_map = {1: {2}, 2: set()} result = _get_descendants(1, children_map) assert result == {2} + + +# --------------------------------------------------------------------------- +# _load_children_map — lines 124-137 +# --------------------------------------------------------------------------- + +class TestLoadChildrenMap: + def test_loads_and_groups_by_parent(self): + snap_id = uuid.uuid4() + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + (10, 20), + (10, 30), + (20, 40), + ] + result = _load_children_map(mock_session, snap_id) + assert result == {10: {20, 30}, 20: {40}} + + def test_empty_result(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + result = _load_children_map(mock_session, uuid.uuid4()) + assert result == {} + + def test_passes_snapshot_id(self): + snap_id = uuid.uuid4() + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + _load_children_map(mock_session, snap_id) + call_args = mock_session.execute.call_args + assert call_args[0][1]["snap_id"] == snap_id + + def test_single_relationship(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [(1, 2)] + result = _load_children_map(mock_session, uuid.uuid4()) + assert result == {1: {2}} + + +# --------------------------------------------------------------------------- +# _load_go_maps — lines 161-169 +# --------------------------------------------------------------------------- + +class TestLoadGoMaps: + def test_basic_maps(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + (1, "GO:0001", "F"), + (2, "GO:0002", "P"), + (3, "GO:0003", "C"), + ] + id_map, aspect_map = _load_go_maps(mock_session, uuid.uuid4()) + assert id_map == {1: "GO:0001", 2: "GO:0002", 3: "GO:0003"} + assert aspect_map == {1: "F", 2: "P", 3: "C"} + + def test_null_aspect_excluded_from_aspect_map(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + (1, "GO:0001", "F"), + (2, "GO:0002", None), + ] + id_map, aspect_map = _load_go_maps(mock_session, uuid.uuid4()) + assert id_map == {1: "GO:0001", 2: "GO:0002"} + assert 2 not in aspect_map + assert aspect_map == {1: "F"} + + def test_empty(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + id_map, aspect_map = _load_go_maps(mock_session, uuid.uuid4()) + assert id_map == {} + assert aspect_map == {} + + +# --------------------------------------------------------------------------- +# _build_negative_keys — lines 182-204 +# --------------------------------------------------------------------------- + +class TestBuildNegativeKeys: + def test_no_not_annotations(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + result = _build_negative_keys(mock_session, [uuid.uuid4()], {}) + assert result == set() + + def test_single_not_no_descendants(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [("P1", 100)] + result = _build_negative_keys(mock_session, [uuid.uuid4()], {}) + assert result == {("P1", 100)} + + def test_not_with_descendants(self): + mock_session = 
MagicMock() + mock_session.execute.return_value.fetchall.return_value = [("P1", 100)] + children_map = {100: {200, 300}, 200: {400}} + result = _build_negative_keys(mock_session, [uuid.uuid4()], children_map) + assert result == {("P1", 100), ("P1", 200), ("P1", 300), ("P1", 400)} + + def test_multiple_proteins(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + ("P1", 10), + ("P2", 20), + ] + children_map = {10: {11}} + result = _build_negative_keys(mock_session, [uuid.uuid4()], children_map) + assert ("P1", 10) in result + assert ("P1", 11) in result + assert ("P2", 20) in result + + def test_duplicate_rows_deduplicated(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + ("P1", 10), + ("P1", 10), + ] + result = _build_negative_keys(mock_session, [uuid.uuid4()], {}) + assert result == {("P1", 10)} + + def test_passes_set_ids(self): + ids = [uuid.uuid4(), uuid.uuid4()] + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + _build_negative_keys(mock_session, ids, {}) + call_args = mock_session.execute.call_args + assert call_args[0][1]["set_ids"] == ids + + +# --------------------------------------------------------------------------- +# _load_experimental_annotations_by_ns — lines 219-238 +# --------------------------------------------------------------------------- + +class TestLoadExperimentalAnnotationsByNs: + def _go_id_map(self): + return {100: "GO:0001", 200: "GO:0002", 300: "GO:0003", 400: "GO:0004"} + + def _aspect_map(self): + return {100: "F", 200: "P", 300: "C", 400: "F"} + + def test_groups_by_protein_and_namespace(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + ("P1", 100), + ("P1", 200), + ("P2", 300), + ] + result = _load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), set(), self._go_id_map(), self._aspect_map() + ) + assert result["P1"]["F"] == {"GO:0001"} + assert result["P1"]["P"] == {"GO:0002"} + assert result["P2"]["C"] == {"GO:0003"} + + def test_negative_keys_excluded(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + ("P1", 100), + ("P1", 200), + ] + negative_keys = {("P1", 100)} + result = _load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), negative_keys, self._go_id_map(), self._aspect_map() + ) + assert "F" not in result.get("P1", {}) + assert result["P1"]["P"] == {"GO:0002"} + + def test_missing_go_id_skipped(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [("P1", 999)] + result = _load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), set(), self._go_id_map(), self._aspect_map() + ) + assert result == {} + + def test_missing_aspect_skipped(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [("P1", 100)] + result = _load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), set(), {100: "GO:0001"}, {} + ) + assert result == {} + + def test_empty_rows(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [] + result = _load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), set(), {}, {} + ) + assert result == {} + + def test_multiple_terms_same_namespace(self): + mock_session = MagicMock() + mock_session.execute.return_value.fetchall.return_value = [ + ("P1", 100), + ("P1", 400), # also F namespace + ] + result = 
_load_experimental_annotations_by_ns( + mock_session, uuid.uuid4(), set(), self._go_id_map(), self._aspect_map() + ) + assert result["P1"]["F"] == {"GO:0001", "GO:0004"} + + +# --------------------------------------------------------------------------- +# compute_evaluation_data — lines 265-322 +# --------------------------------------------------------------------------- + +class TestComputeEvaluationData: + def _ids(self): + return uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_nk_protein(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Protein with no old annotations -> NK.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {}, # old + {"P1": {"F": {"GO:0001", "GO:0002"}}}, # new + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {"P1": {"GO:0001", "GO:0002"}} + assert result.lk == {} + assert result.pk == {} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_lk_protein(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Protein had F at t0, gains P (no old P) -> LK in P.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}}}, + {"P1": {"F": {"GO:0001"}, "P": {"GO:0002"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {} + assert result.lk == {"P1": {"GO:0002"}} + assert result.pk == {} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_pk_protein(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Protein had F at t0, gains new F -> PK in F.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}}}, + {"P1": {"F": {"GO:0001", "GO:0002"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {} + assert result.lk == {} + assert result.pk == {"P1": {"GO:0002"}} + assert result.pk_known == {"P1": {"GO:0001"}} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_mixed_lk_and_pk(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Same protein: PK in F, LK in C.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}}}, + {"P1": {"F": {"GO:0001", "GO:0002"}, "C": {"GO:0003"}}}, + ] + result = compute_evaluation_data(MagicMock(), 
old_id, new_id, snap_id) + assert result.pk == {"P1": {"GO:0002"}} + assert result.lk == {"P1": {"GO:0003"}} + assert result.pk_known == {"P1": {"GO:0001"}} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_no_new_annotations(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Protein only in old -> skipped (no new_all).""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}}}, + {}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {} + assert result.lk == {} + assert result.pk == {} + assert result.delta_proteins == 0 + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_no_delta_same_terms(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Old and new identical -> no delta.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}}}, + {"P1": {"F": {"GO:0001"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {} + assert result.lk == {} + assert result.pk == {} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_known_includes_all_old(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """known dict contains all old experimental annotations flattened.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:0001"}, "P": {"GO:0002"}}, "P2": {"C": {"GO:0003"}}}, + {"P1": {"F": {"GO:0001"}, "P": {"GO:0002", "GO:0099"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.known == {"P1": {"GO:0001", "GO:0002"}, "P2": {"GO:0003"}} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_multiple_proteins(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Multiple proteins with different categories.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P_old": {"F": {"GO:0001"}}}, + {"P_old": {"F": {"GO:0001", "GO:0002"}}, "P_nk": {"P": {"GO:0010"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {"P_nk": {"GO:0010"}} + assert result.pk == {"P_old": {"GO:0002"}} + assert result.pk_known == {"P_old": {"GO:0001"}} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + 
@patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_protein_with_empty_new_namespaces(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Protein key in new but no namespace data -> new_all empty -> skip.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {}, + {"P1": {}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.nk == {} + assert result.lk == {} + assert result.pk == {} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_all_three_namespaces_pk(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """All three namespaces (F, P, C) gain new terms -> PK in all.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [ + {"P1": {"F": {"GO:F1"}, "P": {"GO:P1"}, "C": {"GO:C1"}}}, + {"P1": {"F": {"GO:F1", "GO:F2"}, "P": {"GO:P1", "GO:P2"}, "C": {"GO:C1", "GO:C2"}}}, + ] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.pk == {"P1": {"GO:F2", "GO:P2", "GO:C2"}} + assert result.pk_known == {"P1": {"GO:F1", "GO:P1", "GO:C1"}} + + @patch("protea.core.evaluation._load_experimental_annotations_by_ns") + @patch("protea.core.evaluation._build_negative_keys") + @patch("protea.core.evaluation._load_children_map") + @patch("protea.core.evaluation._load_go_maps") + def test_both_empty(self, mock_go_maps, mock_children, mock_neg, mock_annots): + """Both old and new empty -> empty result.""" + old_id, new_id, snap_id = self._ids() + mock_go_maps.return_value = ({}, {}) + mock_children.return_value = {} + mock_neg.return_value = set() + mock_annots.side_effect = [{}, {}] + result = compute_evaluation_data(MagicMock(), old_id, new_id, snap_id) + assert result.delta_proteins == 0 + assert result.known == {} diff --git a/tests/test_infrastructure.py b/tests/test_infrastructure.py index 9cbd0a2..a92c508 100644 --- a/tests/test_infrastructure.py +++ b/tests/test_infrastructure.py @@ -85,7 +85,14 @@ def test_returns_engine(self): with patch("protea.infrastructure.database.engine.create_engine") as mock_create: mock_create.return_value = MagicMock() engine = build_engine("sqlite:///:memory:") - mock_create.assert_called_once_with("sqlite:///:memory:", future=True, pool_pre_ping=True) + mock_create.assert_called_once_with( + "sqlite:///:memory:", + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=40, + pool_recycle=3600, + ) assert engine is mock_create.return_value @@ -132,3 +139,150 @@ def test_jobs_router_is_registered(self): routes = [r.path for r in app.routes] assert any("/jobs" in p for p in routes) + + def test_health_endpoint_registered(self): + from protea.api.app import create_app + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=MagicMock()): + app = create_app(Path("/fake/root")) + + routes = [r.path for r in app.routes] + assert "/health" in routes + assert 
"/health/ready" in routes + + def test_health_endpoint_returns_ok(self): + """GET /health returns 200 with status ok.""" + from fastapi.testclient import TestClient + from protea.api.app import create_app + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=MagicMock()): + app = create_app(Path("/fake/root")) + + client = TestClient(app) + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + def test_readiness_check_succeeds(self): + """GET /health/ready returns 200 when DB and RabbitMQ are reachable.""" + from fastapi.testclient import TestClient + from protea.api.app import create_app + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + mock_factory = MagicMock() + mock_session = MagicMock() + mock_factory.return_value = mock_session + mock_session.__enter__ = lambda s: s + mock_session.__exit__ = MagicMock(return_value=False) + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=mock_factory): + app = create_app(Path("/fake/root")) + + mock_conn = MagicMock() + with patch("protea.infrastructure.session.session_scope") as mock_scope, \ + patch("pika.BlockingConnection", return_value=mock_conn) as mock_pika: + mock_scope.return_value.__enter__ = lambda s: mock_session + mock_scope.return_value.__exit__ = MagicMock(return_value=False) + client = TestClient(app) + resp = client.get("/health/ready") + + assert resp.status_code == 200 + assert resp.json() == {"status": "ready"} + + def test_readiness_check_fails_when_rabbitmq_down(self): + """GET /health/ready returns 503 when RabbitMQ is unreachable.""" + from fastapi.testclient import TestClient + from protea.api.app import create_app + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + mock_factory = MagicMock() + mock_session = MagicMock() + mock_factory.return_value = mock_session + mock_session.__enter__ = lambda s: s + mock_session.__exit__ = MagicMock(return_value=False) + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=mock_factory): + app = create_app(Path("/fake/root")) + + with patch("protea.infrastructure.session.session_scope") as mock_scope, \ + patch("pika.BlockingConnection", side_effect=Exception("Connection refused")): + mock_scope.return_value.__enter__ = lambda s: mock_session + mock_scope.return_value.__exit__ = MagicMock(return_value=False) + client = TestClient(app) + resp = client.get("/health/ready") + + assert resp.status_code == 503 + assert "RabbitMQ unreachable" in resp.json()["detail"] + + def test_project_root_defaults_to_parents_2(self): + """When project_root is None, it defaults to Path(__file__).parents[2].""" + from protea.api.app import create_app + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + with patch("protea.api.app.load_settings", return_value=mock_settings) as mock_load, \ + patch("protea.api.app.build_session_factory", return_value=MagicMock()): + app = create_app() # project_root=None + + # 
load_settings should have been called with the resolved parents[2] path + called_root = mock_load.call_args[0][0] + assert isinstance(called_root, Path) + assert called_root.is_absolute() + + def test_sphinx_mount_when_directory_exists(self, tmp_path): + """When docs/build/html exists, /sphinx is mounted.""" + from protea.api.app import create_app + + sphinx_dir = tmp_path / "docs" / "build" / "html" + sphinx_dir.mkdir(parents=True) + (sphinx_dir / "index.html").write_text("") + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=MagicMock()): + app = create_app(project_root=tmp_path) + + route_paths = [r.path for r in app.routes] + assert any("/sphinx" in p for p in route_paths) + + def test_static_mount_when_directory_exists(self, tmp_path): + """When static/ exists, /static is mounted.""" + from protea.api.app import create_app + + static_dir = tmp_path / "static" + static_dir.mkdir() + (static_dir / "test.txt").write_text("hello") + + mock_settings = MagicMock() + mock_settings.db_url = "sqlite:///:memory:" + mock_settings.amqp_url = "amqp://guest:guest@localhost/" + + with patch("protea.api.app.load_settings", return_value=mock_settings), \ + patch("protea.api.app.build_session_factory", return_value=MagicMock()): + app = create_app(project_root=tmp_path) + + route_paths = [r.path for r in app.routes] + assert any("/static" in p for p in route_paths) diff --git a/tests/test_insert_proteins.py b/tests/test_insert_proteins.py index 9922641..5a4bd50 100644 --- a/tests/test_insert_proteins.py +++ b/tests/test_insert_proteins.py @@ -160,6 +160,231 @@ def test_sequence_hash_is_set(self): assert records[0]["sequence_hash"] is not None assert len(records[0]["sequence_hash"]) == 32 # MD5 hex + def test_empty_sequence_skipped(self): + """Lines 231-233: header with no sequence lines is skipped.""" + fasta = ">sp|P12345|TEST_HUMAN Test OS=Homo sapiens OX=9606\n\n" + records = self.op._parse_fasta(fasta) + assert records == [] + + def test_header_without_pipe_separators(self): + """Lines 264-265: header without | uses first word as accession.""" + fasta = ">SIMPLE_ACC some description\nMKTAYIAK\n" + records = self.op._parse_fasta(fasta) + assert len(records) == 1 + assert records[0]["accession"] == "SIMPLE_ACC" + assert records[0]["entry_name"] is None + + def test_isoform_accession_parsed(self): + fasta = ( + ">sp|P12345-3|TEST_HUMAN Isoform 3 OS=Homo sapiens OX=9606 GN=TEST\n" + "MKTAYIAK\n" + ) + records = self.op._parse_fasta(fasta) + r = records[0] + assert r["accession"] == "P12345-3" + assert r["canonical_accession"] == "P12345" + assert r["is_canonical"] is False + assert r["isoform_index"] == 3 + + def test_canonical_accession_flagged(self): + records = self.op._parse_fasta(FASTA_ONE) + r = records[0] + assert r["canonical_accession"] == "P12345" + assert r["is_canonical"] is True + assert r["isoform_index"] is None + + def test_reviewed_vs_unreviewed(self): + records = self.op._parse_fasta(FASTA_TWO) + assert records[0]["reviewed"] is True # sp| + assert records[1]["reviewed"] is False # tr| + + def test_sequence_deduplication_by_hash(self): + """Two identical sequences produce the same hash.""" + fasta = ( + ">sp|P11111|A_HUMAN Prot A OS=Homo sapiens OX=9606\nMKTAYIAK\n" + ">sp|P22222|B_HUMAN Prot B OS=Homo sapiens OX=9606\nMKTAYIAK\n" + ) + records = 
self.op._parse_fasta(fasta) + assert len(records) == 2 + assert records[0]["sequence_hash"] == records[1]["sequence_hash"] + + def test_multiline_sequence(self): + fasta = ( + ">sp|P12345|TEST_HUMAN Test OS=Homo sapiens OX=9606\n" + "MKTAY\n" + "IAKQR\n" + ) + records = self.op._parse_fasta(fasta) + assert records[0]["sequence"] == "MKTAYIAKQR" + assert records[0]["length"] == 10 + + +# --------------------------------------------------------------------------- +# Unit tests — _decode_response +# --------------------------------------------------------------------------- + +class TestDecodeResponse: + def setup_method(self): + self.op = InsertProteinsOperation() + + def test_decode_uncompressed(self): + """Line 217: uncompressed path.""" + resp = MagicMock() + resp.content = b"hello world" + result = self.op._decode_response(resp, compressed=False) + assert result == "hello world" + + def test_decode_compressed(self): + """Lines 215-216: gzip decompression path.""" + import gzip + from io import BytesIO + + raw = b"compressed content" + buf = BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as f: + f.write(raw) + resp = MagicMock() + resp.content = buf.getvalue() + result = self.op._decode_response(resp, compressed=True) + assert result == "compressed content" + + +# --------------------------------------------------------------------------- +# Unit tests — _store_records +# --------------------------------------------------------------------------- + +class TestStoreRecords: + def setup_method(self): + self.op = InsertProteinsOperation() + + def test_empty_records_returns_zeros(self): + """Line 300: empty records early return.""" + session = _make_mock_session() + result = self.op._store_records(session, [], _noop_emit) + assert result == (0, 0, 0, 0) + session.query.assert_not_called() + + def test_updates_existing_protein(self): + """Lines 350-394: existing protein gets conservative updates.""" + from protea.infrastructure.orm.models.sequence.sequence import ( + Sequence as SequenceModel, + ) + + seq_hash = SequenceModel.compute_hash("MKTAYIAK") + record = { + "accession": "P12345", + "entry_name": "TEST_HUMAN", + "canonical_accession": "P12345", + "is_canonical": True, + "isoform_index": None, + "organism": "Homo sapiens", + "taxonomy_id": "9606", + "gene_name": "TEST", + "reviewed": True, + "sequence": "MKTAYIAK", + "length": 8, + "sequence_hash": seq_hash, + } + + # Existing protein with missing fields (triggers updates) + existing_prot = MagicMock() + existing_prot.accession = "P12345" + existing_prot.sequence_id = None # will be updated + existing_prot.entry_name = None # will be updated + existing_prot.canonical_accession = "OLD_ACC" # will be updated + existing_prot.is_canonical = False # will be updated + existing_prot.isoform_index = 2 # will be updated + existing_prot.reviewed = None # will be updated + existing_prot.taxonomy_id = None # will be updated + existing_prot.organism = None # will be updated + existing_prot.gene_name = None # will be updated + existing_prot.length = None # will be updated + + session = MagicMock(spec=Session) + + # _load_existing_sequences returns the hash → id map + seq_query = MagicMock() + seq_query.filter.return_value.all.return_value = [(seq_hash, 42)] + + # _load_existing_proteins returns the existing protein + prot_query = MagicMock() + prot_query.filter.return_value.all.return_value = [existing_prot] + + call_idx = {"n": 0} + + def query_side_effect(*args): + call_idx["n"] += 1 + if call_idx["n"] == 1: + return seq_query + return 
prot_query
+
+        session.query.side_effect = query_side_effect
+
+        ins_p, upd_p, ins_s, re_s = self.op._store_records(session, [record], _noop_emit)
+
+        assert ins_p == 0
+        assert upd_p == 1  # existing protein was updated
+        assert re_s == 1  # sequence was reused from DB
+        assert ins_s == 0
+        # Verify fields were updated
+        assert existing_prot.sequence_id == 42
+        assert existing_prot.entry_name == "TEST_HUMAN"
+        assert existing_prot.canonical_accession == "P12345"
+        assert existing_prot.is_canonical is True
+        assert existing_prot.isoform_index is None
+        assert existing_prot.reviewed is True
+
+    def test_inserts_new_sequence_when_missing(self):
+        """Lines 318-334: new sequence inserted when hash not in DB."""
+        from protea.infrastructure.orm.models.sequence.sequence import (
+            Sequence as SequenceModel,
+        )
+
+        seq_hash = SequenceModel.compute_hash("MKTAYIAK")
+        record = {
+            "accession": "P12345",
+            "entry_name": "TEST_HUMAN",
+            "canonical_accession": "P12345",
+            "is_canonical": True,
+            "isoform_index": None,
+            "organism": "Homo sapiens",
+            "taxonomy_id": "9606",
+            "gene_name": "TEST",
+            "reviewed": True,
+            "sequence": "MKTAYIAK",
+            "length": 8,
+            "sequence_hash": seq_hash,
+        }
+
+        session = MagicMock(spec=Session)
+
+        # No existing sequences
+        seq_query = MagicMock()
+        seq_query.filter.return_value.all.return_value = []
+
+        # No existing proteins
+        prot_query = MagicMock()
+        prot_query.filter.return_value.all.return_value = []
+
+        call_idx = {"n": 0}
+
+        def query_side_effect(*args):
+            call_idx["n"] += 1
+            if call_idx["n"] == 1:
+                return seq_query
+            return prot_query
+
+        session.query.side_effect = query_side_effect
+
+        ins_p, upd_p, ins_s, re_s = self.op._store_records(session, [record], _noop_emit)
+
+        assert ins_p == 1
+        assert upd_p == 0
+        assert ins_s == 1
+        assert re_s == 0
+        # add_all called twice: once for sequences, once for proteins
+        assert session.add_all.call_count == 2
+

 # ---------------------------------------------------------------------------
 # Unit tests — execute() with mocked HTTP and session
@@ -242,6 +467,234 @@ def test_two_records_counts_correctly(self):
         assert result.result["retrieved_records"] == 2
         assert result.result["proteins_inserted"] == 2

+    def test_empty_page_continues(self):
+        """Line 93: empty records list triggers continue."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+        # First response is empty FASTA, no link header → single page with 0 records
+        empty_resp = _make_mock_response("")
+        with patch.object(self.op._http, "get", return_value=empty_resp):
+            result = self.op.execute(
+                session,
+                {"search_criteria": "q", "compressed": False},
+                emit=emit,
+            )
+        assert result.result["retrieved_records"] == 0
+        assert result.result["pages"] == 1
+
+    def test_total_limit_trims_to_zero_breaks(self):
+        """Lines 96-98: when total_limit is already reached, records trimmed to empty → break."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        # Page 1 returns 2 records and advertises a next page; with total_limit=2
+        # any further records are trimmed to zero, so the run contributes exactly
+        # 2 records. The next-page URL is arbitrary (HTTP is mocked); only the
+        # rel="next" part matters here.
+        page1_resp = _make_mock_response(
+            FASTA_TWO,
+            link_header='<https://rest.uniprot.org/uniprotkb/search?cursor=next>; rel="next"',
+        )
+        page2_resp = _make_mock_response(FASTA_ONE)
+
+        call_count = {"n": 0}
+
+        def get_side_effect(*args, **kwargs):
+            call_count["n"] += 1
+            if call_count["n"] == 1:
+                return page1_resp
+            return page2_resp
+
+        with patch.object(self.op._http, "get", side_effect=get_side_effect):
+            result = self.op.execute(
+                session,
+                {"search_criteria": "q", "total_limit": 2, "compressed": False},
+                emit=emit,
+            )
+
+        assert result.result["retrieved_records"] == 2
+
+    def test_compressed_param_appended(self):
+        """Line 180: compressed=true adds compressed=true to URL params."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        import gzip
+        from io import BytesIO
+
+        buf = BytesIO()
+        with gzip.GzipFile(fileobj=buf, mode="wb") as f:
+            f.write(FASTA_ONE.encode("utf-8"))
+        compressed_content = buf.getvalue()
+
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.content = compressed_content
+        resp.headers = {"link": ""}
+        resp.raise_for_status = MagicMock()
+
+        with patch.object(self.op._http, "get", return_value=resp) as mock_get:
+            self.op.execute(
+                session,
+                {"search_criteria": "q", "compressed": True},
+                emit=emit,
+            )
+
+        called_url = mock_get.call_args[0][0]
+        assert "compressed=true" in called_url
+
+    def test_total_results_from_header(self):
+        """Line 200: X-Total-Results header is captured."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        resp = _make_mock_response(FASTA_ONE)
+        resp.headers["X-Total-Results"] = "42"
+
+        op = InsertProteinsOperation()
+        with patch.object(op._http, "get", return_value=resp):
+            op.execute(
+                session,
+                {"search_criteria": "q", "compressed": False},
+                emit=emit,
+            )
+
+        assert op._total_results == 42
+
+    def test_total_results_invalid_header_ignored(self):
+        """Line 200: non-numeric X-Total-Results doesn't crash."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        resp = _make_mock_response(FASTA_ONE)
+        resp.headers["X-Total-Results"] = "not-a-number"
+
+        op = InsertProteinsOperation()
+        with patch.object(op._http, "get", return_value=resp):
+            op.execute(
+                session,
+                {"search_criteria": "q", "compressed": False},
+                emit=emit,
+            )
+
+        assert op._total_results is None
+
+    def test_cursor_pagination(self):
+        """Lines 208-210: cursor-based pagination follows link headers."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        # Page 1 advertises a cursor URL via the link header; the operation must
+        # follow it on the next request (the host is arbitrary, HTTP is mocked;
+        # the assertion below only checks that cursor=abc123 is followed).
+        page1_resp = _make_mock_response(
+            FASTA_ONE,
+            link_header='<https://rest.uniprot.org/uniprotkb/search?cursor=abc123>; rel="next"',
+        )
+        page2_resp = _make_mock_response(FASTA_ONE)  # no link header → last page
+
+        call_count = {"n": 0}
+        called_urls: list[str] = []
+
+        def get_side_effect(url, **kwargs):
+            call_count["n"] += 1
+            called_urls.append(url)
+            if call_count["n"] == 1:
+                return page1_resp
+            return page2_resp
+
+        op = InsertProteinsOperation()
+        with patch.object(op._http, "get", side_effect=get_side_effect):
+            result = op.execute(
+                session,
+                {"search_criteria": "q", "compressed": False},
+                emit=emit,
+            )
+
+        assert result.result["pages"] == 2
+        assert result.result["retrieved_records"] == 2
+        # Second call URL should contain cursor
+        assert "cursor=abc123" in called_urls[1]
+
+    def test_network_failure_propagates(self):
+        """HTTP errors propagate to caller."""
+        import requests as req
+
+        session = _make_mock_session()
+        op = InsertProteinsOperation()
+
+        with patch.object(
+            op._http,
+            "get",
+            side_effect=req.ConnectionError("network down"),
+        ):
+            with pytest.raises(req.ConnectionError):
+                op.execute(
+                    session,
+                    {
+                        "search_criteria": "q",
+                        "compressed": False,
+                        "max_retries": 1,
+                        "backoff_base_seconds": 0.0,
+                        "backoff_max_seconds": 0.0,
+                        "jitter_seconds": 0.0,
+                    },
+                    emit=_noop_emit,
+                )
+
+    def test_isoform_records_counted(self):
+        """Isoform records are counted in the result."""
+        session = _make_mock_session()
+        emit = _capturing_emit()
+
+        fasta_with_isoform = (
+            ">sp|P12345|TEST_HUMAN Test OS=Homo sapiens OX=9606\nMKTAYIAK\n"
+            ">sp|P12345-2|TEST_HUMAN Isoform 2 OS=Homo sapiens OX=9606\nMKTAYIAKQR\n"
+        )
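+        # One canonical record (P12345) plus one isoform (P12345-2); only the
+        # "-2" entry should count toward isoform_records.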
+ resp = _make_mock_response(fasta_with_isoform) + op = InsertProteinsOperation() + with patch.object(op._http, "get", return_value=resp): + result = op.execute( + session, + {"search_criteria": "q", "compressed": False}, + emit=emit, + ) + + assert result.result["isoform_records"] == 1 + + def test_progress_emission_with_total(self): + """Progress events include _progress_current and _progress_total.""" + session = _make_mock_session() + emit = _capturing_emit() + + resp = _make_mock_response(FASTA_ONE) + resp.headers["X-Total-Results"] = "100" + + op = InsertProteinsOperation() + with patch.object(op._http, "get", return_value=resp): + op.execute( + session, + {"search_criteria": "q", "compressed": False}, + emit=emit, + ) + + page_done_events = [ + c for c in emit.calls if c["event"] == "insert_proteins.page_done" + ] + assert len(page_done_events) == 1 + fields = page_done_events[0]["fields"] + assert fields["_progress_current"] == 1 + assert fields["_progress_total"] == 100 + + def test_include_isoforms_false_omits_param(self): + """include_isoforms=False does not add includeIsoform to URL.""" + session = _make_mock_session() + resp = _make_mock_response(FASTA_ONE) + op = InsertProteinsOperation() + with patch.object(op._http, "get", return_value=resp) as mock_get: + op.execute( + session, + {"search_criteria": "q", "compressed": False, "include_isoforms": False}, + emit=_noop_emit, + ) + called_url = mock_get.call_args[0][0] + assert "includeIsoform" not in called_url + # --------------------------------------------------------------------------- # Integration test — full round-trip against real Postgres diff --git a/tests/test_load_goa_annotations.py b/tests/test_load_goa_annotations.py index a2bfde8..6ec133e 100644 --- a/tests/test_load_goa_annotations.py +++ b/tests/test_load_goa_annotations.py @@ -1,10 +1,16 @@ +""" +Unit tests for LoadGOAAnnotationsOperation. +No DB or network required — everything is mocked. 
+""" from __future__ import annotations +import io import uuid from unittest.mock import MagicMock, patch import pytest +from protea.core.contracts.operation import OperationResult from protea.core.operations.load_goa_annotations import ( LoadGOAAnnotationsOperation, LoadGOAAnnotationsPayload, @@ -13,14 +19,49 @@ _noop_emit = lambda *_: None # noqa: E731 _SNAPSHOT_ID = str(uuid.uuid4()) +_ANNOTATION_SET_ID = uuid.uuid4() -_GAF_SAMPLE = """\ -!gaf-version: 2.2 -!Generated by GO -UniProtKB\tP12345\tproteinA\tenables\tGO:0003824\tPMID:123\tIDA\t\t\t\t\tprotein\t\t20240101\tUniProt\t\t -UniProtKB\tQ67890\tproteinB\tinvolved_in\tGO:0008150\tPMID:456\tIEA\t\t\t\t\tprotein\t\t20240101\tUniProt\t\t -UniProtKB\tXXXXXX\tunknown\tenables\tGO:0003824\tPMID:789\tIDA\t\t\t\t\tprotein\t\t20240101\tUniProt\t\t -""" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_emit(): + """Return a recording emit function and its event list.""" + events = [] + + def emit(event, message, fields, level): + events.append({"event": event, "fields": fields, "level": level}) + + return emit, events + + +def _gaf_line( + accession="P12345", + go_id="GO:0003674", + qualifier="enables", + evidence="IDA", + db_ref="PMID:1234", + with_from="", + date="20240101", + assigned_by="UniProt", +): + """Build a valid 15-column GAF line.""" + cols = ["UniProtKB"] + [""] * 14 + cols[1] = accession + cols[3] = qualifier + cols[4] = go_id + cols[5] = db_ref + cols[6] = evidence + cols[7] = with_from + cols[13] = date + cols[14] = assigned_by + return "\t".join(cols) + + +# --------------------------------------------------------------------------- +# Payload validation +# --------------------------------------------------------------------------- class TestLoadGOAAnnotationsPayload: @@ -48,26 +89,91 @@ def test_empty_snapshot_id_raises(self) -> None: "source_version": "2024-03", }) + def test_empty_gaf_url_raises(self) -> None: + with pytest.raises(ValueError): + LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url="", + source_version="v1", + ) + + def test_empty_source_version_raises(self) -> None: + with pytest.raises(ValueError): + LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url="https://example.com/goa.gaf.gz", + source_version="", + ) + + def test_page_size_must_be_positive(self) -> None: + with pytest.raises(ValueError): + LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url="https://example.com/goa.gaf.gz", + source_version="v1", + page_size=0, + ) + + def test_strings_are_stripped(self) -> None: + p = LoadGOAAnnotationsPayload( + ontology_snapshot_id=f" {_SNAPSHOT_ID} ", + gaf_url=" https://example.com/goa.gaf.gz ", + source_version=" v1 ", + ) + assert p.ontology_snapshot_id == _SNAPSHOT_ID + assert p.gaf_url == "https://example.com/goa.gaf.gz" + assert p.source_version == "v1" + + def test_defaults(self) -> None: + p = LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url="https://example.com/goa.gaf.gz", + source_version="v1", + ) + assert p.timeout_seconds == 300 + assert p.commit_every_page is True + assert p.total_limit is None + + +# --------------------------------------------------------------------------- +# _store_buffer +# --------------------------------------------------------------------------- + class TestStoreBuffer: def _op(self) -> LoadGOAAnnotationsOperation: return 
LoadGOAAnnotationsOperation() + def _make_record(self, accession="P12345", go_id="GO:0003824", evidence="IDA"): + return { + "accession": accession, + "go_id": go_id, + "qualifier": "enables", + "evidence_code": evidence, + "db_reference": "PMID:1", + "with_from": "", + "assigned_by": "UniProt", + "annotation_date": "20240101", + } + def test_skips_unknown_accession(self) -> None: op = self._op() session = MagicMock() - records = [ - { - "accession": "UNKNOWN", - "go_id": "GO:0003824", - "qualifier": "enables", - "evidence_code": "IDA", - "db_reference": "PMID:1", - "with_from": "", - "assigned_by": "UniProt", - "annotation_date": "20240101", - } - ] + records = [self._make_record(accession="UNKNOWN")] + inserted, skipped = op._store_buffer( + session, + records, + uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + ) + assert inserted == 0 + assert skipped == 1 + + def test_skips_empty_accession(self) -> None: + op = self._op() + session = MagicMock() + records = [self._make_record(accession=" ")] inserted, skipped = op._store_buffer( session, records, @@ -77,23 +183,11 @@ def test_skips_unknown_accession(self) -> None: ) assert inserted == 0 assert skipped == 1 - session.add_all.assert_not_called() def test_skips_unknown_go_term(self) -> None: op = self._op() session = MagicMock() - records = [ - { - "accession": "P12345", - "go_id": "GO:9999999", - "qualifier": "enables", - "evidence_code": "IDA", - "db_reference": "PMID:1", - "with_from": "", - "assigned_by": "UniProt", - "annotation_date": "20240101", - } - ] + records = [self._make_record(go_id="GO:9999999")] inserted, skipped = op._store_buffer( session, records, @@ -108,26 +202,8 @@ def test_inserts_valid_records(self) -> None: op = self._op() session = MagicMock() records = [ - { - "accession": "P12345", - "go_id": "GO:0003824", - "qualifier": "enables", - "evidence_code": "IDA", - "db_reference": "PMID:123", - "with_from": "", - "assigned_by": "UniProt", - "annotation_date": "20240101", - }, - { - "accession": "Q67890", - "go_id": "GO:0008150", - "qualifier": "involved_in", - "evidence_code": "IEA", - "db_reference": "PMID:456", - "with_from": "", - "assigned_by": "UniProt", - "annotation_date": "20240101", - }, + self._make_record(accession="P12345", go_id="GO:0003824"), + self._make_record(accession="Q67890", go_id="GO:0008150", evidence="IEA"), ] inserted, skipped = op._store_buffer( session, @@ -138,5 +214,573 @@ def test_inserts_valid_records(self) -> None: ) assert inserted == 2 assert skipped == 0 - # Uses bulk pg_insert().on_conflict_do_nothing() instead of add_all session.execute.assert_called() + + def test_deduplicates_within_buffer(self) -> None: + op = self._op() + session = MagicMock() + rec = self._make_record() + records = [rec.copy(), rec.copy(), rec.copy()] + inserted, skipped = op._store_buffer( + session, + records, + uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + ) + assert inserted == 1 + assert skipped == 2 + + def test_different_evidence_codes_not_deduplicated(self) -> None: + op = self._op() + session = MagicMock() + records = [ + self._make_record(evidence="IDA"), + self._make_record(evidence="IEA"), + ] + inserted, skipped = op._store_buffer( + session, + records, + uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + ) + assert inserted == 2 + assert skipped == 0 + + def test_mixed_valid_and_invalid(self) -> None: + op = self._op() + session = MagicMock() + records = [ + 
self._make_record(accession="P12345"), + self._make_record(accession="UNKNOWN"), + self._make_record(accession="Q67890", go_id="GO:0008150"), + self._make_record(go_id="GO:INVALID"), + ] + inserted, skipped = op._store_buffer( + session, + records, + uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345", "Q67890"}, + go_term_map={"GO:0003824": 1, "GO:0008150": 2}, + ) + assert inserted == 2 + assert skipped == 2 + + def test_empty_buffer(self) -> None: + op = self._op() + session = MagicMock() + inserted, skipped = op._store_buffer( + session, [], uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, go_term_map={"GO:0003824": 1}, + ) + assert inserted == 0 + assert skipped == 0 + session.execute.assert_not_called() + + def test_empty_evidence_treated_as_none_for_dedup(self) -> None: + """Empty string evidence_code becomes None; two such records are duplicates.""" + op = self._op() + session = MagicMock() + records = [ + self._make_record(evidence=""), + self._make_record(evidence=""), + ] + inserted, skipped = op._store_buffer( + session, + records, + uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + ) + assert inserted == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# _stream_gaf +# --------------------------------------------------------------------------- + + +class TestStreamGaf: + def setup_method(self): + self.op = LoadGOAAnnotationsOperation() + + def _stream_from_text(self, text: str, url="https://example.com/goa.gaf"): + """Mock requests.get and stream GAF text through _stream_gaf.""" + payload = LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url=url, + source_version="v1", + ) + emit, _ = _make_emit() + + raw = io.BytesIO(text.encode("utf-8")) + mock_resp = MagicMock() + mock_resp.raw = raw + mock_resp.raise_for_status = MagicMock() + + with patch("protea.core.operations.load_goa_annotations.requests.get", return_value=mock_resp): + return list(self.op._stream_gaf(payload, emit)) + + def test_parses_valid_gaf_line(self): + line = _gaf_line(accession="P12345", go_id="GO:0003674", evidence="IDA") + records = self._stream_from_text(line + "\n") + assert len(records) == 1 + assert records[0]["accession"] == "P12345" + assert records[0]["go_id"] == "GO:0003674" + assert records[0]["evidence_code"] == "IDA" + + def test_skips_comment_lines(self): + text = "!this is a comment\n" + _gaf_line() + "\n" + records = self._stream_from_text(text) + assert len(records) == 1 + + def test_skips_empty_lines(self): + text = "\n\n" + _gaf_line() + "\n\n" + records = self._stream_from_text(text) + assert len(records) == 1 + + def test_skips_short_lines(self): + text = "col1\tcol2\tcol3\n" + _gaf_line() + "\n" + records = self._stream_from_text(text) + assert len(records) == 1 + + def test_multiple_records(self): + lines = [ + _gaf_line(accession="A1"), + _gaf_line(accession="A2"), + _gaf_line(accession="A3"), + ] + records = self._stream_from_text("\n".join(lines) + "\n") + assert len(records) == 3 + assert [r["accession"] for r in records] == ["A1", "A2", "A3"] + + def test_extracts_all_fields(self): + line = _gaf_line( + accession="Q99999", + go_id="GO:0005575", + qualifier="located_in", + evidence="IEA", + db_ref="GO_REF:001", + with_from="InterPro:IPR000001", + date="20230615", + assigned_by="InterPro", + ) + records = self._stream_from_text(line + "\n") + r = records[0] + assert r["accession"] == "Q99999" + assert r["go_id"] == "GO:0005575" + assert r["qualifier"] == 
"located_in" + assert r["evidence_code"] == "IEA" + assert r["db_reference"] == "GO_REF:001" + assert r["with_from"] == "InterPro:IPR000001" + assert r["annotation_date"] == "20230615" + assert r["assigned_by"] == "InterPro" + + def test_gzip_url_uses_gzip_decompression(self): + import gzip as gzip_mod + + line = _gaf_line() + "\n" + compressed = gzip_mod.compress(line.encode("utf-8")) + + payload = LoadGOAAnnotationsPayload( + ontology_snapshot_id=_SNAPSHOT_ID, + gaf_url="https://example.com/goa.gaf.gz", + source_version="v1", + ) + emit, _ = _make_emit() + + raw = io.BytesIO(compressed) + mock_resp = MagicMock() + mock_resp.raw = raw + mock_resp.raise_for_status = MagicMock() + + with patch("protea.core.operations.load_goa_annotations.requests.get", return_value=mock_resp): + records = list(self.op._stream_gaf(payload, emit)) + assert len(records) == 1 + + def test_empty_file_returns_no_records(self): + records = self._stream_from_text("") + assert records == [] + + def test_file_with_only_comments(self): + text = "!comment1\n!comment2\n" + records = self._stream_from_text(text) + assert records == [] + + +# --------------------------------------------------------------------------- +# _load_accessions +# --------------------------------------------------------------------------- + + +class TestLoadAccessions: + def setup_method(self): + self.op = LoadGOAAnnotationsOperation() + + def test_returns_set_of_accessions(self): + session = MagicMock() + session.scalars.return_value = iter(["P12345", "Q99999"]) + emit, events = _make_emit() + + result = self.op._load_accessions(session, emit) + assert result == {"P12345", "Q99999"} + event_names = [e["event"] for e in events] + assert "load_goa_annotations.load_accessions_start" in event_names + assert "load_goa_annotations.load_accessions_done" in event_names + + def test_returns_empty_set(self): + session = MagicMock() + session.scalars.return_value = iter([]) + emit, _ = _make_emit() + + result = self.op._load_accessions(session, emit) + assert result == set() + + def test_emits_count_in_done_event(self): + session = MagicMock() + session.scalars.return_value = iter(["A", "B", "C"]) + emit, events = _make_emit() + + self.op._load_accessions(session, emit) + done = [e for e in events if e["event"] == "load_goa_annotations.load_accessions_done"] + assert len(done) == 1 + assert done[0]["fields"]["canonical_accessions"] == 3 + + +# --------------------------------------------------------------------------- +# _load_go_term_map +# --------------------------------------------------------------------------- + + +class TestLoadGoTermMap: + def setup_method(self): + self.op = LoadGOAAnnotationsOperation() + + def _mock_session(self, rows): + session = MagicMock() + query_mock = MagicMock() + session.query.return_value = query_mock + query_mock.filter.return_value = query_mock + query_mock.all.return_value = rows + return session + + def test_returns_mapping(self): + session = self._mock_session([("GO:0003674", 1), ("GO:0005575", 2)]) + emit, events = _make_emit() + + result = self.op._load_go_term_map(session, uuid.uuid4(), emit) + assert result == {"GO:0003674": 1, "GO:0005575": 2} + event_names = [e["event"] for e in events] + assert "load_goa_annotations.load_go_terms_start" in event_names + assert "load_goa_annotations.load_go_terms_done" in event_names + + def test_empty_ontology(self): + session = self._mock_session([]) + emit, _ = _make_emit() + + result = self.op._load_go_term_map(session, uuid.uuid4(), emit) + assert result == {} + + def 
test_emits_count_in_done_event(self): + session = self._mock_session([("GO:0003674", 1)]) + emit, events = _make_emit() + + self.op._load_go_term_map(session, uuid.uuid4(), emit) + done = [e for e in events if e["event"] == "load_goa_annotations.load_go_terms_done"] + assert len(done) == 1 + assert done[0]["fields"]["go_terms"] == 1 + + +# --------------------------------------------------------------------------- +# execute (full integration of all pieces, mocked) +# --------------------------------------------------------------------------- + + +class TestExecute: + def setup_method(self): + self.op = LoadGOAAnnotationsOperation() + self.snapshot_id = uuid.uuid4() + + def _make_session(self, accessions, go_terms): + session = MagicMock() + # session.get(OntologySnapshot, id) returns a truthy mock + session.get.return_value = MagicMock() + # _load_accessions uses session.scalars + session.scalars.return_value = iter(accessions) + # _load_go_term_map uses session.query + query_mock = MagicMock() + session.query.return_value = query_mock + query_mock.filter.return_value = query_mock + query_mock.all.return_value = list(go_terms.items()) + return session + + def _run(self, gaf_text, accessions, go_terms, + page_size=10000, total_limit=None, commit_every_page=True, + store_buffer_side_effect=None): + session = self._make_session(accessions, go_terms) + emit, events = _make_emit() + + ann_set_mock = MagicMock() + ann_set_mock.id = _ANNOTATION_SET_ID + + payload = { + "ontology_snapshot_id": str(self.snapshot_id), + "gaf_url": "https://example.com/goa.gaf", + "source_version": "v1", + "page_size": page_size, + "commit_every_page": commit_every_page, + } + if total_limit is not None: + payload["total_limit"] = total_limit + + raw = io.BytesIO(gaf_text.encode("utf-8")) + mock_resp = MagicMock() + mock_resp.raw = raw + mock_resp.raise_for_status = MagicMock() + + # _store_buffer does a lazy import of pg_insert which needs a real + # SQLAlchemy Table object. We mock the whole method and count + # inserted/skipped via the records passed to it, using the real + # filtering logic from the valid_accessions and go_terms sets. 
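+        # The fake reproduces the real dedup key, assumed to be
+        # (annotation_set_id, accession, go_term_id, evidence_code or None),
+        # so duplicate counts match what the DB-backed path would report.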
+ real_valid = set(accessions) + real_go = dict(go_terms) + + def fake_store_buffer(_session, records, _ann_set_id, _valid, _go_map): + inserted = 0 + skipped = 0 + seen = set() + for rec in records: + acc = rec["accession"].strip() + if not acc or acc not in real_valid: + skipped += 1 + continue + go_id = rec["go_id"].strip() + go_term_id = real_go.get(go_id) + if go_term_id is None: + skipped += 1 + continue + ev = rec["evidence_code"] or None + key = (_ann_set_id, acc, go_term_id, ev) + if key in seen: + skipped += 1 + continue + seen.add(key) + inserted += 1 + return inserted, skipped + + if store_buffer_side_effect is not None: + fake_store_buffer = store_buffer_side_effect + + with patch( + "protea.core.operations.load_goa_annotations.requests.get", + return_value=mock_resp, + ), patch( + "protea.core.operations.load_goa_annotations.AnnotationSet", + return_value=ann_set_mock, + ), patch.object( + self.op, "_store_buffer", side_effect=fake_store_buffer, + ): + result = self.op.execute(session, payload, emit=emit) + + return result, events, session + + def test_basic_execution(self): + gaf = _gaf_line(accession="P12345", go_id="GO:0003674") + "\n" + result, events, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert isinstance(result, OperationResult) + assert result.result["annotations_inserted"] == 1 + assert result.result["annotations_skipped"] == 0 + event_names = [e["event"] for e in events] + assert "load_goa_annotations.start" in event_names + assert "load_goa_annotations.done" in event_names + + def test_snapshot_not_found_raises(self): + session = MagicMock() + session.get.return_value = None + emit, _ = _make_emit() + + payload = { + "ontology_snapshot_id": str(self.snapshot_id), + "gaf_url": "https://example.com/goa.gaf", + "source_version": "v1", + } + with pytest.raises(ValueError, match="not found"): + self.op.execute(session, payload, emit=emit) + + def test_no_proteins_returns_zero(self): + gaf = _gaf_line() + "\n" + result, events, _ = self._run( + gaf, accessions=[], go_terms={"GO:0003674": 1}, + ) + assert result.result == {"annotations_inserted": 0} + event_names = [e["event"] for e in events] + assert "load_goa_annotations.no_proteins" in event_names + + def test_skips_unmatched_accessions(self): + gaf = _gaf_line(accession="UNKNOWN") + "\n" + result, _, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert result.result["annotations_inserted"] == 0 + assert result.result["annotations_skipped"] == 1 + + def test_skips_unmatched_go_ids(self): + gaf = _gaf_line(accession="P12345", go_id="GO:UNKNOWN") + "\n" + result, _, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert result.result["annotations_inserted"] == 0 + assert result.result["annotations_skipped"] == 1 + + def test_pagination_emits_page_done(self): + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(5)] + gaf = "\n".join(lines) + "\n" + result, events, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=2, + ) + page_events = [e for e in events if e["event"] == "load_goa_annotations.page_done"] + # 5 records, page_size=2 -> 2 full pages emitted (remainder flushed separately) + assert len(page_events) == 2 + assert result.result["annotations_inserted"] == 5 + assert result.result["pages"] == 3 + + def test_commit_every_page(self): + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(4)] + 
gaf = "\n".join(lines) + "\n" + _, _, session = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=2, commit_every_page=True, + ) + # 4 records, page_size=2 -> 2 full pages -> 2 commits + assert session.commit.call_count == 2 + + def test_no_commit_when_disabled(self): + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(4)] + gaf = "\n".join(lines) + "\n" + _, _, session = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=2, commit_every_page=False, + ) + session.commit.assert_not_called() + + def test_total_limit_stops_early(self): + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(10)] + gaf = "\n".join(lines) + "\n" + result, events, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=3, total_limit=3, + ) + assert result.result["annotations_inserted"] == 3 + event_names = [e["event"] for e in events] + assert "load_goa_annotations.limit_reached" in event_names + + def test_empty_file(self): + result, _, _ = self._run( + "", accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert result.result["annotations_inserted"] == 0 + assert result.result["total_lines_read"] == 0 + assert result.result["pages"] == 0 + + def test_result_contains_elapsed_seconds(self): + gaf = _gaf_line() + "\n" + result, _, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert "elapsed_seconds" in result.result + assert result.result["elapsed_seconds"] >= 0 + + def test_result_contains_annotation_set_id(self): + gaf = _gaf_line() + "\n" + result, _, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert result.result["annotation_set_id"] == str(_ANNOTATION_SET_ID) + + def test_duplicate_annotations_in_file(self): + line = _gaf_line(accession="P12345", go_id="GO:0003674", evidence="IDA") + gaf = (line + "\n") * 5 + result, _, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + assert result.result["annotations_inserted"] == 1 + assert result.result["annotations_skipped"] == 4 + + def test_comments_and_short_lines_not_counted(self): + text = ( + "!GAF header comment\n" + "!another comment\n" + "short\tline\n" + + _gaf_line(accession="P12345", go_id="GO:0003674") + "\n" + ) + result, _, _ = self._run( + text, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + # Only valid GAF lines are counted as total_lines_read + assert result.result["total_lines_read"] == 1 + assert result.result["annotations_inserted"] == 1 + + def test_annotation_set_created_event(self): + gaf = _gaf_line() + "\n" + _, events, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + event_names = [e["event"] for e in events] + assert "load_goa_annotations.annotation_set_created" in event_names + created = [e for e in events if e["event"] == "load_goa_annotations.annotation_set_created"] + assert created[0]["fields"]["annotation_set_id"] == str(_ANNOTATION_SET_ID) + + def test_page_done_event_fields(self): + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(3)] + gaf = "\n".join(lines) + "\n" + _, events, _ = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=2, + ) + page_events = [e for e in events if e["event"] == "load_goa_annotations.page_done"] + assert len(page_events) == 1 + fields = page_events[0]["fields"] + assert fields["page"] == 1 + assert 
fields["total_lines"] == 2 + assert fields["total_inserted"] == 2 + + def test_session_flush_called_after_annotation_set_add(self): + gaf = _gaf_line() + "\n" + _, _, session = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + ) + session.flush.assert_called() + + def test_multiple_pages_with_remainder(self): + """7 records with page_size=3 -> 2 full pages + 1 remainder = 3 pages total.""" + lines = [_gaf_line(accession="P12345", go_id="GO:0003674", evidence=f"E{i}") + for i in range(7)] + gaf = "\n".join(lines) + "\n" + result, events, session = self._run( + gaf, accessions=["P12345"], go_terms={"GO:0003674": 1}, + page_size=3, + ) + assert result.result["pages"] == 3 + assert result.result["annotations_inserted"] == 7 + page_events = [e for e in events if e["event"] == "load_goa_annotations.page_done"] + assert len(page_events) == 2 # only full pages emit page_done + + +# --------------------------------------------------------------------------- +# Operation name +# --------------------------------------------------------------------------- + + +class TestOperationName: + def test_name(self): + assert LoadGOAAnnotationsOperation.name == "load_goa_annotations" diff --git a/tests/test_load_ontology_snapshot.py b/tests/test_load_ontology_snapshot.py index 611713c..ad71230 100644 --- a/tests/test_load_ontology_snapshot.py +++ b/tests/test_load_ontology_snapshot.py @@ -103,6 +103,177 @@ def test_typedef_not_included(self) -> None: assert "part_of" not in go_ids +_OBO_WITH_RELATIONSHIPS = """\ +format-version: 1.2 +data-version: releases/2024-06-01 + +[Term] +id: GO:0008150 +name: biological_process +namespace: biological_process +def: "Root biological process." [GOC:go_curators] + +[Term] +id: GO:0009987 +name: cellular process +namespace: biological_process +def: "Any process that is carried out at the cellular level." [GOC:go_curators] +is_a: GO:0008150 ! biological_process + +[Term] +id: GO:0044237 +name: cellular metabolic process +namespace: biological_process +def: "The chemical reactions involving a cell." [GOC:go_curators] +is_a: GO:0009987 ! cellular process +relationship: part_of GO:0008150 ! 
biological_process +""" + + +class TestParseTermsRelationships: + """Tests for is_a and relationship: parsing (lines 275-287).""" + + def _op(self) -> LoadOntologySnapshotOperation: + return LoadOntologySnapshotOperation() + + def test_is_a_relationship_parsed(self) -> None: + op = self._op() + terms = {t["go_id"]: t for t in op._parse_terms(_OBO_WITH_RELATIONSHIPS)} + cellular = terms["GO:0009987"] + assert ("is_a", "GO:0008150") in cellular["relationships"] + + def test_part_of_relationship_parsed(self) -> None: + op = self._op() + terms = {t["go_id"]: t for t in op._parse_terms(_OBO_WITH_RELATIONSHIPS)} + metabolic = terms["GO:0044237"] + assert ("part_of", "GO:0008150") in metabolic["relationships"] + + def test_multiple_relationships_on_single_term(self) -> None: + op = self._op() + terms = {t["go_id"]: t for t in op._parse_terms(_OBO_WITH_RELATIONSHIPS)} + metabolic = terms["GO:0044237"] + assert len(metabolic["relationships"]) == 2 + assert ("is_a", "GO:0009987") in metabolic["relationships"] + assert ("part_of", "GO:0008150") in metabolic["relationships"] + + def test_root_term_has_no_relationships(self) -> None: + op = self._op() + terms = {t["go_id"]: t for t in op._parse_terms(_OBO_WITH_RELATIONSHIPS)} + root = terms["GO:0008150"] + assert root["relationships"] == [] + + def test_all_supported_relationship_types(self) -> None: + """Each of the 7 supported relationship types is captured.""" + op = self._op() + for rt in [ + "part_of", "regulates", "negatively_regulates", + "positively_regulates", "occurs_in", "capable_of", + "capable_of_part_of", + ]: + obo = ( + "format-version: 1.2\ndata-version: releases/2024-01-01\n\n" + "[Term]\nid: GO:0000001\nname: child\nnamespace: biological_process\n" + f"relationship: {rt} GO:0000002 ! parent\n" + ) + terms = op._parse_terms(obo) + assert (rt, "GO:0000002") in terms[0]["relationships"], f"Failed for {rt}" + + def test_unsupported_relationship_type_ignored(self) -> None: + op = self._op() + obo = ( + "format-version: 1.2\ndata-version: releases/2024-01-01\n\n" + "[Term]\nid: GO:0000001\nname: child\nnamespace: biological_process\n" + "relationship: has_part GO:0000002 ! parent\n" + ) + terms = op._parse_terms(obo) + assert terms[0].get("relationships", []) == [] + + def test_relationship_line_with_no_go_prefix_ignored(self) -> None: + """relationship: part_of SOMETHING (not GO:) is skipped.""" + op = self._op() + obo = ( + "format-version: 1.2\ndata-version: releases/2024-01-01\n\n" + "[Term]\nid: GO:0000001\nname: child\nnamespace: biological_process\n" + "relationship: part_of CHEBI:12345 ! 
not a GO term\n" + ) + terms = op._parse_terms(obo) + assert terms[0].get("relationships", []) == [] + + def test_definition_without_quotes_gives_none(self) -> None: + """def: line that doesn't match the quoted pattern yields None.""" + op = self._op() + obo = ( + "format-version: 1.2\ndata-version: releases/2024-01-01\n\n" + "[Term]\nid: GO:0000001\nname: test\nnamespace: biological_process\n" + "def: no quotes here\n" + ) + terms = op._parse_terms(obo) + assert terms[0]["definition"] is None + + +class TestDownload: + """Tests for _download (lines 202-207).""" + + def test_download_success(self) -> None: + op = LoadOntologySnapshotOperation() + payload = LoadOntologySnapshotPayload.model_validate( + {"obo_url": "http://example.org/go.obo"} + ) + emit = MagicMock() + + mock_resp = MagicMock() + mock_resp.text = _OBO_SAMPLE + mock_resp.raise_for_status = MagicMock() + + with patch( + "protea.core.operations.load_ontology_snapshot.requests.get", + return_value=mock_resp, + ) as mock_get: + result = op._download(payload, emit) + + assert result == _OBO_SAMPLE + mock_get.assert_called_once_with( + "http://example.org/go.obo", timeout=120, stream=True + ) + # Should emit download_start and download_done + assert emit.call_count == 2 + assert emit.call_args_list[0][0][0] == "load_ontology_snapshot.download_start" + assert emit.call_args_list[1][0][0] == "load_ontology_snapshot.download_done" + assert emit.call_args_list[1][0][2]["bytes"] == len(_OBO_SAMPLE) + + def test_download_http_error_propagates(self) -> None: + import requests as req + + op = LoadOntologySnapshotOperation() + payload = LoadOntologySnapshotPayload.model_validate( + {"obo_url": "http://example.org/go.obo"} + ) + mock_resp = MagicMock() + mock_resp.raise_for_status.side_effect = req.HTTPError("404 Not Found") + + with patch( + "protea.core.operations.load_ontology_snapshot.requests.get", + return_value=mock_resp, + ): + with pytest.raises(req.HTTPError): + op._download(payload, MagicMock()) + + def test_download_connection_error_propagates(self) -> None: + import requests as req + + op = LoadOntologySnapshotOperation() + payload = LoadOntologySnapshotPayload.model_validate( + {"obo_url": "http://example.org/go.obo"} + ) + + with patch( + "protea.core.operations.load_ontology_snapshot.requests.get", + side_effect=req.ConnectionError("DNS failure"), + ): + with pytest.raises(req.ConnectionError): + op._download(payload, MagicMock()) + + class TestLoadOntologySnapshotExecute: def _mock_session(self, existing_snapshot=None, rel_count=0): session = MagicMock() @@ -161,3 +332,207 @@ def add_side_effect(obj): assert session.add_all.call_count == 2 terms_call_args = session.add_all.call_args_list[0][0][0] assert len(terms_call_args) == 4 + + def test_new_snapshot_inserts_relationships(self) -> None: + """Lines 163-167: relationship GOTermRelationship objects are created for new snapshots.""" + session = self._mock_session(existing_snapshot=None) + + _id_counter = {"n": 0} + + def add_side_effect(obj): + from protea.infrastructure.orm.models.annotation.ontology_snapshot import ( + OntologySnapshot, + ) + if isinstance(obj, OntologySnapshot): + obj.id = "snap-id" + + session.add.side_effect = add_side_effect + + def add_all_side_effect(items): + """Simulate DB flush assigning IDs to GOTerm objects.""" + for item in items: + from protea.infrastructure.orm.models.annotation.go_term import GOTerm + if isinstance(item, GOTerm) and item.id is None: + _id_counter["n"] += 1 + item.id = _id_counter["n"] + + session.add_all.side_effect = 
add_all_side_effect + + with patch.object( + LoadOntologySnapshotOperation, + "_download", + return_value=_OBO_WITH_RELATIONSHIPS, + ): + op = LoadOntologySnapshotOperation() + result = op.execute( + session, + {"obo_url": "http://example.org/go.obo"}, + emit=_noop_emit, + ) + + # 3 terms, 3 relationships (1 is_a on GO:0009987, 1 is_a + 1 part_of on GO:0044237) + assert result.result["terms_inserted"] == 3 + assert result.result["relationships_inserted"] == 3 + # Second add_all call is the relationships + rel_call_args = session.add_all.call_args_list[1][0][0] + assert len(rel_call_args) == 3 + + def test_new_snapshot_skips_relationship_with_missing_parent(self) -> None: + """Lines 164-166: if parent GO ID not in go_id_to_db_id, relationship is skipped.""" + obo = ( + "format-version: 1.2\ndata-version: releases/2024-01-01\n\n" + "[Term]\nid: GO:0000001\nname: child\nnamespace: biological_process\n" + "is_a: GO:9999999 ! nonexistent parent\n" + ) + session = self._mock_session(existing_snapshot=None) + + _id_counter = {"n": 0} + + def add_side_effect(obj): + from protea.infrastructure.orm.models.annotation.ontology_snapshot import ( + OntologySnapshot, + ) + if isinstance(obj, OntologySnapshot): + obj.id = "snap-id" + + session.add.side_effect = add_side_effect + + def add_all_side_effect(items): + for item in items: + from protea.infrastructure.orm.models.annotation.go_term import GOTerm + if isinstance(item, GOTerm) and item.id is None: + _id_counter["n"] += 1 + item.id = _id_counter["n"] + + session.add_all.side_effect = add_all_side_effect + + with patch.object( + LoadOntologySnapshotOperation, "_download", return_value=obo + ): + op = LoadOntologySnapshotOperation() + result = op.execute( + session, + {"obo_url": "http://example.org/go.obo"}, + emit=_noop_emit, + ) + + # Parent GO:9999999 doesn't exist in terms, so relationship is skipped + assert result.result["relationships_inserted"] == 0 + + def test_emits_progress_events(self) -> None: + session = self._mock_session(existing_snapshot=None) + emit = MagicMock() + + with patch.object( + LoadOntologySnapshotOperation, "_download", return_value=_OBO_SAMPLE + ): + op = LoadOntologySnapshotOperation() + op.execute(session, {"obo_url": "http://x.org/go.obo"}, emit=emit) + + events = [c.args[0] for c in emit.call_args_list] + assert "load_ontology_snapshot.start" in events + assert "load_ontology_snapshot.version" in events + assert "load_ontology_snapshot.parsed" in events + assert "load_ontology_snapshot.done" in events + + def test_done_event_includes_elapsed(self) -> None: + session = self._mock_session(existing_snapshot=None) + emit = MagicMock() + + with patch.object( + LoadOntologySnapshotOperation, "_download", return_value=_OBO_SAMPLE + ): + op = LoadOntologySnapshotOperation() + result = op.execute(session, {"obo_url": "http://x.org/go.obo"}, emit=emit) + + assert "elapsed_seconds" in result.result + assert result.result["elapsed_seconds"] >= 0 + + def test_backfill_relationships_when_zero(self) -> None: + """Lines 87-125: snapshot exists but has 0 relationships — backfill them.""" + existing = MagicMock() + existing.id = "existing-uuid" + + call_idx = {"n": 0} + + def query_side_effect(*args): + call_idx["n"] += 1 + m = MagicMock() + if call_idx["n"] == 1: + # OntologySnapshot filter_by query + m.filter_by.return_value.first.return_value = existing + elif call_idx["n"] == 2: + # func.count(GOTermRelationship.id) → 0 + m.filter.return_value.scalar.return_value = 0 + elif call_idx["n"] == 3: + # GOTerm (go_id, id) query for the 
backfill map + m.filter.return_value.all.return_value = [ + ("GO:0003674", 1), + ("GO:0008150", 2), + ("GO:0005575", 3), + ("GO:0003824", 4), + ] + return m + + session = MagicMock() + session.query.side_effect = query_side_effect + emit = MagicMock() + + with patch.object( + LoadOntologySnapshotOperation, "_download", return_value=_OBO_SAMPLE + ): + op = LoadOntologySnapshotOperation() + result = op.execute( + session, + {"obo_url": "http://example.org/go.obo"}, + emit=emit, + ) + + assert result.result["skipped"] is False + assert result.result["ontology_snapshot_id"] == "existing-uuid" + assert "relationships_inserted" in result.result + session.add_all.assert_called_once() + session.flush.assert_called_once() + + events = [c.args[0] for c in emit.call_args_list] + assert "load_ontology_snapshot.backfill_relationships" in events + assert "load_ontology_snapshot.backfill_done" in events + + def test_backfill_skips_unknown_go_ids(self) -> None: + """Lines 103-107: during backfill, terms with no DB ID are skipped.""" + existing = MagicMock() + existing.id = "existing-uuid" + + call_idx = {"n": 0} + + def query_side_effect(*args): + call_idx["n"] += 1 + m = MagicMock() + if call_idx["n"] == 1: + m.filter_by.return_value.first.return_value = existing + elif call_idx["n"] == 2: + m.filter.return_value.scalar.return_value = 0 + elif call_idx["n"] == 3: + # Return only one term — the others won't be in the map + m.filter.return_value.all.return_value = [("GO:0003674", 1)] + return m + + session = MagicMock() + session.query.side_effect = query_side_effect + + with patch.object( + LoadOntologySnapshotOperation, "_download", return_value=_OBO_SAMPLE + ): + op = LoadOntologySnapshotOperation() + result = op.execute( + session, + {"obo_url": "http://example.org/go.obo"}, + emit=_noop_emit, + ) + + assert result.result["relationships_inserted"] == 0 + + def test_invalid_payload_raises(self) -> None: + op = LoadOntologySnapshotOperation() + with pytest.raises(Exception): + op.execute(MagicMock(), {}, emit=_noop_emit) diff --git a/tests/test_load_quickgo_annotations.py b/tests/test_load_quickgo_annotations.py index c50346a..48b90f8 100644 --- a/tests/test_load_quickgo_annotations.py +++ b/tests/test_load_quickgo_annotations.py @@ -1,9 +1,10 @@ from __future__ import annotations import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +import requests from protea.core.operations.load_quickgo_annotations import ( LoadQuickGOAnnotationsOperation, @@ -155,3 +156,670 @@ def test_raw_eco_stored_when_no_mapping(self) -> None: from sqlalchemy.dialects.postgresql import dialect as pg_dialect compiled = call_stmt.compile(dialect=pg_dialect()) assert compiled.params["evidence_code_m0"] == "ECO:0000314" + + def test_empty_eco_id_becomes_none(self) -> None: + op = self._op() + session = MagicMock() + row = dict(_QUICKGO_ROWS[0]) + row["ECO ID"] = "" + inserted, _ = op._store_buffer( + session, [row], uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + eco_map={}, + ) + assert inserted == 1 + + def test_empty_accession_skipped(self) -> None: + op = self._op() + session = MagicMock() + row = dict(_QUICKGO_ROWS[0]) + row["GENE PRODUCT ID"] = " " + inserted, skipped = op._store_buffer( + session, [row], uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + eco_map={}, + ) + assert inserted == 0 + assert skipped == 1 + + def test_chunked_insert_large_buffer(self) -> None: + """When to_add > 
5000, session.execute is called multiple times.""" + op = self._op() + session = MagicMock() + records = [dict(_QUICKGO_ROWS[0])] * 5001 + inserted, skipped = op._store_buffer( + session, records, uuid.UUID(_SNAPSHOT_ID), + valid_accessions={"P12345"}, + go_term_map={"GO:0003824": 1}, + eco_map={}, + ) + assert inserted == 5001 + assert skipped == 0 + assert session.execute.call_count == 2 + + +# --------------------------------------------------------------------------- +# _load_accessions +# --------------------------------------------------------------------------- + +class TestLoadAccessions: + def test_returns_canonical_and_protein_sets(self) -> None: + op = LoadQuickGOAnnotationsOperation() + session = MagicMock() + session.scalars.side_effect = [ + iter({"P12345", "Q99999"}), + iter({"P12345", "P12345-2", "Q99999"}), + ] + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + + canon, prots = op._load_accessions(session, emit) + assert canon == {"P12345", "Q99999"} + assert prots == {"P12345", "P12345-2", "Q99999"} + assert "load_quickgo_annotations.load_accessions_start" in events + assert "load_quickgo_annotations.load_accessions_done" in events + + def test_emits_counts(self) -> None: + op = LoadQuickGOAnnotationsOperation() + session = MagicMock() + session.scalars.side_effect = [iter({"A", "B"}), iter({"A", "B", "C"})] + fields_log: list[dict] = [] + emit = lambda event, msg, fields, level: fields_log.append(fields) + + op._load_accessions(session, emit) + done_fields = fields_log[-1] + assert done_fields["canonical_accessions"] == 2 + assert done_fields["protein_accessions"] == 3 + + +# --------------------------------------------------------------------------- +# _load_go_term_map +# --------------------------------------------------------------------------- + +class TestLoadGoTermMap: + def test_returns_mapping(self) -> None: + op = LoadQuickGOAnnotationsOperation() + session = MagicMock() + sid = uuid.uuid4() + query_mock = MagicMock() + query_mock.filter.return_value.all.return_value = [ + ("GO:0005634", 1), ("GO:0008150", 2), + ] + session.query.return_value = query_mock + + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + + result = op._load_go_term_map(session, sid, emit) + assert result == {"GO:0005634": 1, "GO:0008150": 2} + assert "load_quickgo_annotations.load_go_terms_start" in events + assert "load_quickgo_annotations.load_go_terms_done" in events + + def test_empty_terms(self) -> None: + op = LoadQuickGOAnnotationsOperation() + session = MagicMock() + query_mock = MagicMock() + query_mock.filter.return_value.all.return_value = [] + session.query.return_value = query_mock + + result = op._load_go_term_map(session, uuid.uuid4(), _noop_emit) + assert result == {} + + +# --------------------------------------------------------------------------- +# _load_eco_mapping +# --------------------------------------------------------------------------- + +class TestLoadEcoMapping: + def test_no_url_returns_empty(self) -> None: + op = LoadQuickGOAnnotationsOperation() + p = LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + }) + assert op._load_eco_mapping(p, _noop_emit) == {} + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_parses_mapping_file(self, mock_get) -> None: + resp = MagicMock() + resp.text = "ECO:0000314 IDA\nECO:0000501 IEA\n# comment\nbadline\n" + resp.raise_for_status = MagicMock() + 
mock_get.return_value = resp + + op = LoadQuickGOAnnotationsOperation() + p = LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + "eco_mapping_url": "https://eco.test/map.txt", + }) + result = op._load_eco_mapping(p, _noop_emit) + assert result == {"ECO:0000314": "IDA", "ECO:0000501": "IEA"} + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_http_error_raises(self, mock_get) -> None: + resp = MagicMock() + resp.raise_for_status.side_effect = requests.HTTPError("404") + mock_get.return_value = resp + + op = LoadQuickGOAnnotationsOperation() + p = LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + "eco_mapping_url": "https://eco.test/bad", + }) + with pytest.raises(requests.HTTPError): + op._load_eco_mapping(p, _noop_emit) + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_emits_start_and_done(self, mock_get) -> None: + resp = MagicMock() + resp.text = "ECO:0000314 IDA\n" + resp.raise_for_status = MagicMock() + mock_get.return_value = resp + + op = LoadQuickGOAnnotationsOperation() + p = LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + "eco_mapping_url": "https://eco.test/map.txt", + }) + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + op._load_eco_mapping(p, emit) + assert "load_quickgo_annotations.eco_mapping_start" in events + assert "load_quickgo_annotations.eco_mapping_done" in events + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_ignores_non_eco_lines(self, mock_get) -> None: + resp = MagicMock() + resp.text = "ECO:0000314 IDA\nNOT_ECO stuff\n \nECO:0000501 IEA\n" + resp.raise_for_status = MagicMock() + mock_get.return_value = resp + + op = LoadQuickGOAnnotationsOperation() + p = LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + "eco_mapping_url": "https://eco.test/map.txt", + }) + result = op._load_eco_mapping(p, _noop_emit) + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# _fetch_quickgo_page — TSV stream parsing +# --------------------------------------------------------------------------- + +import io as _io + +QUICKGO_HEADER_LINE = ( + "GENE PRODUCT ID\tGO TERM\tQUALIFIER\tECO ID\tREFERENCE\tWITH/FROM\tASSIGNED BY\tDATE" +) + + +def _tsv_row_str( + accession: str = "P12345", + go_term: str = "GO:0005634", + qualifier: str = "enables", + eco_id: str = "ECO:0000314", + reference: str = "PMID:12345", + with_from: str = "", + assigned_by: str = "UniProt", + date: str = "20240101", +) -> str: + return f"{accession}\t{go_term}\t{qualifier}\t{eco_id}\t{reference}\t{with_from}\t{assigned_by}\t{date}" + + +def _make_tsv_text(*data_rows: str) -> str: + return "\n".join([QUICKGO_HEADER_LINE] + list(data_rows)) + "\n" + + +def _make_stream_response(text: str, status_code: int = 200) -> MagicMock: + resp = MagicMock() + resp.status_code = status_code + resp.raise_for_status = MagicMock() + if status_code >= 400: + resp.raise_for_status.side_effect = requests.HTTPError(f"{status_code}") + raw = _io.BytesIO(text.encode("utf-8")) + resp.raw = raw + resp.raw.decode_content = True + return resp + + +class TestFetchQuickgoPage: + def _payload(self, **kw): + return LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": 
_SNAPSHOT_ID, + "source_version": "v1", + **kw, + }) + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_parses_rows(self, mock_get) -> None: + tsv = _make_tsv_text( + _tsv_row_str("P12345", "GO:0005634"), + _tsv_row_str("Q99999", "GO:0008150"), + ) + mock_get.return_value = _make_stream_response(tsv) + + op = LoadQuickGOAnnotationsOperation() + records = list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=["P12345"], batch_index=0, total_batches=1) + ) + assert len(records) == 2 + assert records[0]["GENE PRODUCT ID"] == "P12345" + assert records[1]["GO TERM"] == "GO:0008150" + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_skips_empty_lines(self, mock_get) -> None: + tsv = QUICKGO_HEADER_LINE + "\n\n" + _tsv_row_str() + "\n\n" + mock_get.return_value = _make_stream_response(tsv) + + op = LoadQuickGOAnnotationsOperation() + records = list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=None, batch_index=0, total_batches=1) + ) + assert len(records) == 1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_skips_short_rows(self, mock_get) -> None: + tsv = QUICKGO_HEADER_LINE + "\ntoo\tfew\n" + _tsv_row_str() + "\n" + mock_get.return_value = _make_stream_response(tsv) + + op = LoadQuickGOAnnotationsOperation() + records = list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=None, batch_index=0, total_batches=1) + ) + assert len(records) == 1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_http_error_raises(self, mock_get) -> None: + mock_get.return_value = _make_stream_response("", status_code=500) + + op = LoadQuickGOAnnotationsOperation() + with pytest.raises(requests.HTTPError): + list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=None, batch_index=0, total_batches=1) + ) + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_sends_correct_params_with_gp_ids(self, mock_get) -> None: + mock_get.return_value = _make_stream_response(_make_tsv_text()) + + op = LoadQuickGOAnnotationsOperation() + list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=["P12345", "Q99999"], batch_index=0, total_batches=1) + ) + _, kwargs = mock_get.call_args + assert kwargs["params"]["geneProductId"] == "P12345,Q99999" + assert kwargs["params"]["geneProductType"] == "protein" + assert kwargs["headers"]["Accept"] == "text/tsv" + assert kwargs["stream"] is True + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_no_gp_ids_omits_gene_product_param(self, mock_get) -> None: + mock_get.return_value = _make_stream_response(_make_tsv_text()) + + op = LoadQuickGOAnnotationsOperation() + list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=None, batch_index=0, total_batches=1) + ) + _, kwargs = mock_get.call_args + assert "geneProductId" not in kwargs["params"] + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_emits_download_start_with_progress(self, mock_get) -> None: + mock_get.return_value = _make_stream_response(_make_tsv_text()) + events: list[tuple[str, dict]] = [] + emit = lambda event, msg, fields, level: events.append((event, fields)) + + op = LoadQuickGOAnnotationsOperation() + list( + op._fetch_quickgo_page(self._payload(), emit, gp_ids=["X"], batch_index=2, total_batches=5) + ) + start_events = [e for e in events if e[0] == "load_quickgo_annotations.download_start"] + assert 
len(start_events) == 1 + assert start_events[0][1]["batch"] == 3 + assert start_events[0][1]["of"] == 5 + assert start_events[0][1]["_progress_current"] == 3 + assert start_events[0][1]["_progress_total"] == 5 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_header_only_yields_nothing(self, mock_get) -> None: + tsv = QUICKGO_HEADER_LINE + "\n" + mock_get.return_value = _make_stream_response(tsv) + + op = LoadQuickGOAnnotationsOperation() + records = list( + op._fetch_quickgo_page(self._payload(), _noop_emit, gp_ids=None, batch_index=0, total_batches=1) + ) + assert records == [] + + +# --------------------------------------------------------------------------- +# _stream_quickgo — batching logic +# --------------------------------------------------------------------------- + +class TestStreamQuickgo: + def _payload(self, **kw): + return LoadQuickGOAnnotationsPayload.model_validate({ + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "v1", + **kw, + }) + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_batches_accessions(self, mock_get) -> None: + mock_get.side_effect = lambda *a, **kw: _make_stream_response(_make_tsv_text()) + + op = LoadQuickGOAnnotationsOperation() + p = self._payload(gene_product_batch_size=2) + list(op._stream_quickgo(p, _noop_emit, gene_product_ids=["A", "B", "C", "D", "E"])) + assert mock_get.call_count == 3 # 2+2+1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_no_ids_single_request(self, mock_get) -> None: + mock_get.return_value = _make_stream_response(_make_tsv_text()) + + op = LoadQuickGOAnnotationsOperation() + p = self._payload(use_db_accessions=False) + list(op._stream_quickgo(p, _noop_emit, gene_product_ids=None)) + assert mock_get.call_count == 1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_emits_batching_event(self, mock_get) -> None: + mock_get.side_effect = lambda *a, **kw: _make_stream_response(_make_tsv_text()) + + events: list[tuple[str, dict]] = [] + emit = lambda event, msg, fields, level: events.append((event, fields)) + + op = LoadQuickGOAnnotationsOperation() + p = self._payload(gene_product_batch_size=2) + list(op._stream_quickgo(p, emit, gene_product_ids=["A", "B", "C"])) + batching = [e for e in events if e[0] == "load_quickgo_annotations.batching"] + assert len(batching) == 1 + assert batching[0][1]["total_accessions"] == 3 + assert batching[0][1]["total_batches"] == 2 + assert batching[0][1]["batch_size"] == 2 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_yields_records_from_all_batches(self, mock_get) -> None: + tsv = _make_tsv_text(_tsv_row_str("P12345")) + mock_get.side_effect = lambda *a, **kw: _make_stream_response(tsv) + + op = LoadQuickGOAnnotationsOperation() + p = self._payload(gene_product_batch_size=1) + records = list(op._stream_quickgo(p, _noop_emit, gene_product_ids=["A", "B"])) + # Each batch returns 1 record, 2 batches + assert len(records) == 2 + + +# --------------------------------------------------------------------------- +# Full execute flow +# --------------------------------------------------------------------------- + +def _mock_session( + canonical_accessions: set[str] | None = None, + protein_accessions: set[str] | None = None, + go_terms: list[tuple[str, int]] | None = None, + snapshot_exists: bool = True, +) -> MagicMock: + session = MagicMock() + if snapshot_exists: + session.get.return_value = MagicMock() + else: 
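+            # Simulate the OntologySnapshot lookup finding no row.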
+ session.get.return_value = None + + canon = canonical_accessions if canonical_accessions is not None else {"P12345"} + prots = protein_accessions if protein_accessions is not None else {"P12345"} + session.scalars.side_effect = [iter(canon), iter(prots)] + + terms = go_terms or [("GO:0003824", 1), ("GO:0008150", 2)] + query_mock = MagicMock() + query_mock.filter.return_value.all.return_value = terms + session.query.return_value = query_mock + + def _set_id(obj): + obj.id = uuid.uuid4() + session.add.side_effect = _set_id + + return session + + +def _base_payload(**overrides) -> dict: + d = { + "ontology_snapshot_id": _SNAPSHOT_ID, + "source_version": "2024-01-01", + "quickgo_base_url": "https://quickgo.test/annotation/downloadSearch", + "use_db_accessions": True, + "eco_mapping_url": None, + "page_size": 100, + "timeout_seconds": 10, + "commit_every_page": False, + "gene_product_batch_size": 200, + } + d.update(overrides) + return d + + +class TestExecute: + def test_snapshot_not_found_raises(self) -> None: + session = _mock_session(snapshot_exists=False) + op = LoadQuickGOAnnotationsOperation() + with pytest.raises(ValueError, match="not found"): + op.execute(session, _base_payload(), emit=_noop_emit) + + def test_no_proteins_returns_zero(self) -> None: + session = _mock_session(canonical_accessions=set()) + session.scalars.side_effect = [iter(set()), iter(set())] + op = LoadQuickGOAnnotationsOperation() + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + result = op.execute(session, _base_payload(), emit=emit) + assert result.result["annotations_inserted"] == 0 + assert "load_quickgo_annotations.no_proteins" in events + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_full_run_inserts_and_skips(self, mock_get) -> None: + tsv = _make_tsv_text( + _tsv_row_str("P12345", "GO:0003824"), + _tsv_row_str("UNKNOWN", "GO:0003824"), + _tsv_row_str("P12345", "GO:9999999"), + ) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + go_terms=[("GO:0003824", 1)], + ) + + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute(session, _base_payload(), emit=emit) + assert result.result["annotations_inserted"] == 1 + assert result.result["annotations_skipped"] == 2 + assert "load_quickgo_annotations.start" in events + assert "load_quickgo_annotations.done" in events + assert "load_quickgo_annotations.annotation_set_created" in events + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_total_limit_stops_early(self, mock_get) -> None: + tsv = _make_tsv_text( + _tsv_row_str("P12345", "GO:0003824"), + _tsv_row_str("P12345", "GO:0008150"), + _tsv_row_str("P12345", "GO:0003824"), + ) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + events: list[str] = [] + emit = lambda event, msg, fields, level: events.append(event) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute( + session, _base_payload(total_limit=1, page_size=1), emit=emit, + ) + assert "load_quickgo_annotations.limit_reached" in events + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_commit_every_page(self, mock_get) -> None: + tsv = _make_tsv_text( + _tsv_row_str("P12345", "GO:0003824"), + 
_tsv_row_str("P12345", "GO:0008150"), + ) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + op = LoadQuickGOAnnotationsOperation() + op.execute( + session, _base_payload(commit_every_page=True, page_size=1), emit=_noop_emit, + ) + assert session.commit.call_count >= 2 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_no_commit_when_disabled(self, mock_get) -> None: + tsv = _make_tsv_text(_tsv_row_str("P12345", "GO:0003824")) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + op = LoadQuickGOAnnotationsOperation() + op.execute( + session, _base_payload(commit_every_page=False, page_size=1), emit=_noop_emit, + ) + session.commit.assert_not_called() + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_page_done_emitted(self, mock_get) -> None: + tsv = _make_tsv_text( + _tsv_row_str("P12345", "GO:0003824"), + _tsv_row_str("P12345", "GO:0008150"), + _tsv_row_str("P12345", "GO:0003824"), + ) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + events: list[tuple[str, dict]] = [] + emit = lambda event, msg, fields, level: events.append((event, fields)) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute(session, _base_payload(page_size=2), emit=emit) + page_done = [e for e in events if e[0] == "load_quickgo_annotations.page_done"] + assert len(page_done) >= 1 + assert result.result["pages"] == 2 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_result_contains_elapsed_seconds(self, mock_get) -> None: + tsv = _make_tsv_text(_tsv_row_str("P12345", "GO:0003824")) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute(session, _base_payload(), emit=_noop_emit) + assert "elapsed_seconds" in result.result + assert result.result["elapsed_seconds"] >= 0 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_use_db_accessions_false(self, mock_get) -> None: + tsv = _make_tsv_text(_tsv_row_str("X00001", "GO:0003824")) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345", "X00001"}, + go_terms=[("GO:0003824", 1)], + ) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute( + session, + _base_payload(use_db_accessions=False, gene_product_ids=["X00001"]), + emit=_noop_emit, + ) + _, kwargs = mock_get.call_args + assert "X00001" in kwargs["params"]["geneProductId"] + assert result.result["annotations_inserted"] == 1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_eco_mapping_integrated_in_execute(self, mock_get) -> None: + eco_resp = MagicMock() + eco_resp.text = "ECO:0000314 IDA\n" + eco_resp.raise_for_status = MagicMock() + + tsv_resp = _make_stream_response( + _make_tsv_text(_tsv_row_str("P12345", "GO:0003824", eco_id="ECO:0000314")) + ) + + mock_get.side_effect = [eco_resp, tsv_resp] + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + go_terms=[("GO:0003824", 1)], + ) + + op = 
LoadQuickGOAnnotationsOperation() + result = op.execute( + session, + _base_payload(eco_mapping_url="https://eco.test/map.txt"), + emit=_noop_emit, + ) + assert result.result["annotations_inserted"] == 1 + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_result_has_annotation_set_id(self, mock_get) -> None: + tsv = _make_tsv_text(_tsv_row_str("P12345", "GO:0003824")) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + op = LoadQuickGOAnnotationsOperation() + result = op.execute(session, _base_payload(), emit=_noop_emit) + assert "annotation_set_id" in result.result + + @patch("protea.core.operations.load_quickgo_annotations.requests.get") + def test_remainder_buffer_flushed(self, mock_get) -> None: + """Records that don't fill a full page are still flushed at the end.""" + tsv = _make_tsv_text(_tsv_row_str("P12345", "GO:0003824")) + mock_get.return_value = _make_stream_response(tsv) + + session = _mock_session( + canonical_accessions={"P12345"}, + protein_accessions={"P12345"}, + ) + + op = LoadQuickGOAnnotationsOperation() + # page_size much larger than record count → only remainder flush + result = op.execute(session, _base_payload(page_size=10000), emit=_noop_emit) + assert result.result["annotations_inserted"] == 1 + assert result.result["pages"] == 1 + + def test_operation_name(self) -> None: + assert LoadQuickGOAnnotationsOperation().name == "load_quickgo_annotations" diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..496fa9e --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,157 @@ +"""Tests for protea/infrastructure/logging.py""" +from __future__ import annotations + +import json +import logging +from unittest.mock import patch + +import pytest + +from protea.infrastructure.logging import JSONFormatter, configure_logging + + +class TestJSONFormatter: + """Tests for the JSONFormatter class.""" + + def _make_record(self, msg="hello", level=logging.INFO, name="test.logger", **kwargs): + record = logging.LogRecord( + name=name, + level=level, + pathname="test.py", + lineno=1, + msg=msg, + args=(), + exc_info=kwargs.pop("exc_info", None), + ) + for k, v in kwargs.items(): + setattr(record, k, v) + return record + + def test_formats_valid_json_with_expected_keys(self): + formatter = JSONFormatter() + record = self._make_record("test message") + output = formatter.format(record) + data = json.loads(output) + + assert "timestamp" in data + assert data["level"] == "INFO" + assert data["message"] == "test message" + assert data["logger"] == "test.logger" + + def test_timestamp_is_utc_iso_format(self): + formatter = JSONFormatter() + record = self._make_record() + data = json.loads(formatter.format(record)) + # UTC ISO timestamps end with +00:00 + assert "+00:00" in data["timestamp"] + + def test_includes_exc_info_when_present(self): + formatter = JSONFormatter() + try: + raise ValueError("boom") + except ValueError: + import sys + exc_info = sys.exc_info() + + record = self._make_record("error occurred", exc_info=exc_info) + data = json.loads(formatter.format(record)) + + assert "exception" in data + assert "ValueError" in data["exception"] + assert "boom" in data["exception"] + + def test_exc_info_absent_when_no_exception(self): + formatter = JSONFormatter() + record = self._make_record("all good") + data = json.loads(formatter.format(record)) + assert "exception" not in data + + def 
test_includes_extra_fields(self): + formatter = JSONFormatter() + record = self._make_record("with extras", queue="protea.jobs", batch_size=100) + data = json.loads(formatter.format(record)) + + assert data["queue"] == "protea.jobs" + assert data["batch_size"] == 100 + + def test_builtin_attrs_excluded_from_extras(self): + formatter = JSONFormatter() + record = self._make_record("check builtins") + data = json.loads(formatter.format(record)) + + # Standard LogRecord attributes should not appear as top-level keys + for attr in ("args", "exc_info", "exc_text", "lineno", "pathname", "thread"): + assert attr not in data + + def test_stack_info_included_when_present(self): + formatter = JSONFormatter() + record = self._make_record("with stack") + record.stack_info = "Stack trace here" + data = json.loads(formatter.format(record)) + assert data["stack_info"] == "Stack trace here" + + def test_non_serializable_extra_uses_default_str(self): + formatter = JSONFormatter() + record = self._make_record("non-serializable", obj=object()) + # Should not raise — json.dumps(default=str) handles it + output = formatter.format(record) + data = json.loads(output) + assert "obj" in data + + +class TestConfigureLogging: + """Tests for the configure_logging function.""" + + def setup_method(self): + """Save root logger state before each test.""" + self._root = logging.getLogger() + self._original_handlers = list(self._root.handlers) + self._original_level = self._root.level + + def teardown_method(self): + """Restore root logger state after each test.""" + self._root.handlers = self._original_handlers + self._root.setLevel(self._original_level) + + def test_json_true_sets_json_formatter(self): + configure_logging(json=True, level="WARNING") + root = logging.getLogger() + assert len(root.handlers) == 1 + assert isinstance(root.handlers[0].formatter, JSONFormatter) + + def test_json_false_uses_standard_formatter(self): + configure_logging(json=False, level="INFO") + root = logging.getLogger() + assert len(root.handlers) == 1 + formatter = root.handlers[0].formatter + assert not isinstance(formatter, JSONFormatter) + assert isinstance(formatter, logging.Formatter) + + def test_respects_level_parameter(self): + configure_logging(json=True, level="DEBUG") + assert logging.getLogger().level == logging.DEBUG + + configure_logging(json=True, level="ERROR") + assert logging.getLogger().level == logging.ERROR + + def test_level_is_case_insensitive(self): + configure_logging(json=True, level="warning") + assert logging.getLogger().level == logging.WARNING + + def test_clears_existing_handlers(self): + root = logging.getLogger() + root.addHandler(logging.StreamHandler()) + root.addHandler(logging.StreamHandler()) + assert len(root.handlers) >= 2 + + configure_logging(json=True) + assert len(root.handlers) == 1 + + def test_invalid_level_falls_back_to_info(self): + configure_logging(json=True, level="NONEXISTENT") + assert logging.getLogger().level == logging.INFO + + def test_handler_is_stream_handler(self): + configure_logging(json=True) + root = logging.getLogger() + assert isinstance(root.handlers[0], logging.StreamHandler) diff --git a/tests/test_proteins_router.py b/tests/test_proteins_router.py new file mode 100644 index 0000000..ea53c89 --- /dev/null +++ b/tests/test_proteins_router.py @@ -0,0 +1,353 @@ +"""Unit tests for the /proteins router. + +Database is fully mocked -- no real infrastructure required. 
+""" +from __future__ import annotations + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from protea.api.routers.proteins import router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_app(session_factory): + app = FastAPI() + app.state.session_factory = session_factory + app.include_router(router) + return app + + +@contextmanager +def _mock_scope(session): + yield session + + +def _make_protein(**overrides): + defaults = { + "accession": "P12345", + "entry_name": "TEST_HUMAN", + "gene_name": "TEST", + "organism": "Homo sapiens", + "taxonomy_id": 9606, + "length": 100, + "reviewed": True, + "is_canonical": True, + "canonical_accession": "P12345", + "isoform_index": None, + "sequence_id": 1, + } + defaults.update(overrides) + p = MagicMock() + for k, v in defaults.items(): + setattr(p, k, v) + return p + + +def _make_metadata(): + meta = MagicMock() + for attr in ( + "function_cc", "ec_number", "catalytic_activity", "pathway", + "keywords", "cofactor", "activity_regulation", "absorption", + "kinetics", "ph_dependence", "redox_potential", "temperature_dependence", + "active_site", "binding_site", "dna_binding", "rhea_id", "site", "features", + ): + setattr(meta, attr, f"mock_{attr}") + return meta + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def session(): + return MagicMock() + + +@pytest.fixture() +def factory(session): + return MagicMock() + + +@pytest.fixture() +def client(session, factory): + app = _make_app(factory) + with patch( + "protea.api.routers.proteins.session_scope", + side_effect=lambda _: _mock_scope(session), + ): + with TestClient(app) as c: + yield c, session + + +# --------------------------------------------------------------------------- +# GET /proteins/stats +# --------------------------------------------------------------------------- + +class TestProteinStats: + def test_returns_all_stat_keys(self, client): + c, session = client + # Each scalar() call returns a value in order: + # total, canonical, reviewed, with_metadata, with_embeddings, with_go + session.query.return_value.scalar.return_value = 10 + session.query.return_value.filter.return_value.scalar.return_value = 5 + session.query.return_value.join.return_value.scalar.return_value = 3 + + resp = c.get("/proteins/stats") + assert resp.status_code == 200 + data = resp.json() + for key in ( + "total", "canonical", "isoforms", "reviewed", + "unreviewed", "with_metadata", "with_embeddings", "with_go_annotations", + ): + assert key in data + + def test_stats_zero_values(self, client): + c, session = client + session.query.return_value.scalar.return_value = 0 + session.query.return_value.filter.return_value.scalar.return_value = 0 + session.query.return_value.join.return_value.scalar.return_value = 0 + + resp = c.get("/proteins/stats") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 0 + assert data["isoforms"] == 0 + + +# --------------------------------------------------------------------------- +# GET /proteins +# --------------------------------------------------------------------------- + +class TestListProteins: + def 
test_returns_paginated_list(self, client): + c, session = client + p = _make_protein() + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 1 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [p] + + resp = c.get("/proteins") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert len(data["items"]) == 1 + assert data["items"][0]["accession"] == "P12345" + + def test_search_filter(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 0 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] + + resp = c.get("/proteins", params={"search": "kinase"}) + assert resp.status_code == 200 + assert resp.json()["total"] == 0 + assert resp.json()["items"] == [] + + def test_reviewed_filter(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 0 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] + + resp = c.get("/proteins", params={"reviewed": "true"}) + assert resp.status_code == 200 + + def test_canonical_only_false(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 0 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] + + resp = c.get("/proteins", params={"canonical_only": "false"}) + assert resp.status_code == 200 + + def test_pagination_params(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 100 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] + + resp = c.get("/proteins", params={"limit": 10, "offset": 20}) + assert resp.status_code == 200 + data = resp.json() + assert data["limit"] == 10 + assert data["offset"] == 20 + + def test_empty_list(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.count.return_value = 0 + q_mock.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] + + resp = c.get("/proteins") + assert resp.status_code == 200 + assert resp.json()["items"] == [] + + +# --------------------------------------------------------------------------- +# GET /proteins/{accession} +# --------------------------------------------------------------------------- + +class TestGetProtein: + def test_returns_protein_with_metadata(self, client): + c, session = client + p = _make_protein() + meta = _make_metadata() + session.get.return_value = p + session.query.return_value.filter.return_value.first.return_value = meta + session.query.return_value.filter.return_value.scalar.return_value = 2 + session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + + resp = c.get("/proteins/P12345") + assert resp.status_code == 200 + data = resp.json() + assert data["accession"] == "P12345" + assert data["metadata"] is not None + assert data["metadata"]["function_cc"] == "mock_function_cc" + + def test_returns_protein_without_metadata(self, client): + c, session = client + p = _make_protein() + 
session.get.return_value = p + session.query.return_value.filter.return_value.first.return_value = None + session.query.return_value.filter.return_value.scalar.return_value = 0 + session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + + resp = c.get("/proteins/P12345") + assert resp.status_code == 200 + assert resp.json()["metadata"] is None + + def test_not_found_returns_404(self, client): + c, session = client + session.get.return_value = None + + resp = c.get("/proteins/UNKNOWN") + assert resp.status_code == 404 + + def test_canonical_lists_isoforms(self, client): + c, session = client + p = _make_protein(is_canonical=True) + meta = _make_metadata() + session.get.return_value = p + session.query.return_value.filter.return_value.first.return_value = meta + session.query.return_value.filter.return_value.scalar.return_value = 0 + + iso1 = MagicMock() + iso1.accession = "P12345-2" + iso2 = MagicMock() + iso2.accession = "P12345-3" + session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [iso1, iso2] + + resp = c.get("/proteins/P12345") + assert resp.status_code == 200 + assert resp.json()["isoforms"] == ["P12345-2", "P12345-3"] + + def test_non_canonical_no_isoform_list(self, client): + c, session = client + p = _make_protein(is_canonical=False, accession="P12345-2", sequence_id=None) + session.get.return_value = p + session.query.return_value.filter.return_value.first.return_value = None + session.query.return_value.filter.return_value.scalar.return_value = 0 + + resp = c.get("/proteins/P12345-2") + assert resp.status_code == 200 + data = resp.json() + assert data["isoforms"] == [] + assert data["embedding_count"] == 0 + + +# --------------------------------------------------------------------------- +# GET /proteins/{accession}/annotations +# --------------------------------------------------------------------------- + +class TestGetProteinAnnotations: + def _make_annotation_row(self, go_id="GO:0003674", name="molecular_function", + aspect="F", qualifier="enables", evidence="IDA", + assigned_by="UniProt", db_ref="PMID:123", + ann_set_id=None, source="goa", version="2024-01"): + ann = MagicMock() + ann.qualifier = qualifier + ann.evidence_code = evidence + ann.assigned_by = assigned_by + ann.db_reference = db_ref + ann.annotation_set_id = ann_set_id or uuid4() + + gt = MagicMock() + gt.go_id = go_id + gt.name = name + gt.aspect = aspect + + aset = MagicMock() + aset.source = source + aset.source_version = version + + return (ann, gt, aset) + + def test_returns_annotations(self, client): + c, session = client + row = self._make_annotation_row() + q_mock = MagicMock() + session.query.return_value.join.return_value.join.return_value.filter.return_value = q_mock + q_mock.order_by.return_value.all.return_value = [row] + + resp = c.get("/proteins/P12345/annotations") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["go_id"] == "GO:0003674" + assert data[0]["evidence_code"] == "IDA" + + def test_empty_annotations(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value.join.return_value.join.return_value.filter.return_value = q_mock + q_mock.order_by.return_value.all.return_value = [] + + resp = c.get("/proteins/P12345/annotations") + assert resp.status_code == 200 + assert resp.json() == [] + + def test_filter_by_annotation_set_id(self, client): + c, session = client + ann_set_id = uuid4() + row = self._make_annotation_row(ann_set_id=ann_set_id) + 
q_mock = MagicMock() + session.query.return_value.join.return_value.join.return_value.filter.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.order_by.return_value.all.return_value = [row] + + resp = c.get("/proteins/P12345/annotations", params={"annotation_set_id": str(ann_set_id)}) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + def test_invalid_annotation_set_id_returns_422(self, client): + c, session = client + q_mock = MagicMock() + session.query.return_value.join.return_value.join.return_value.filter.return_value = q_mock + q_mock.filter.side_effect = ValueError("bad uuid") + + resp = c.get("/proteins/P12345/annotations", params={"annotation_set_id": "not-a-uuid"}) + assert resp.status_code == 422 diff --git a/tests/test_queue.py b/tests/test_queue.py index 5912620..15024f8 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -5,6 +5,7 @@ from __future__ import annotations import json +import threading from unittest.mock import MagicMock, patch from uuid import UUID, uuid4 @@ -142,10 +143,20 @@ def _mock_pika(self, consumer): return conn, channel - def test_run_declares_queue(self): + def test_run_declares_queue_with_dlx(self): consumer = _consumer() conn, channel = self._mock_pika(consumer) - channel.queue_declare.assert_called_once_with(queue="test.jobs", durable=True) + # DLQ + main queue + assert channel.queue_declare.call_count == 2 + channel.queue_declare.assert_any_call(queue="protea.dead-letter", durable=True) + channel.queue_declare.assert_any_call( + queue="test.jobs", + durable=True, + arguments={"x-dead-letter-exchange": "protea.dlx"}, + ) + channel.exchange_declare.assert_called_once_with( + exchange="protea.dlx", exchange_type="fanout", durable=True + ) def test_run_sets_prefetch(self): consumer = _consumer() @@ -180,7 +191,8 @@ def test_publishes_correct_body(self): conn.channel.return_value = channel conn.is_open = True - with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn): + with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn), \ + patch("protea.infrastructure.queue.publisher._local", threading.local()): publish_job("amqp://localhost/", "test.jobs", job_id) channel.basic_publish.assert_called_once() @@ -189,16 +201,18 @@ def test_publishes_correct_body(self): body = json.loads(kwargs["body"].decode()) assert body["job_id"] == str(job_id) - def test_closes_connection_on_success(self): + def test_reuses_connection_on_success(self): + """With thread-local connection reuse, conn is NOT closed after a successful publish.""" conn = MagicMock() channel = MagicMock() conn.channel.return_value = channel conn.is_open = True - with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn): + with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn), \ + patch("protea.infrastructure.queue.publisher._local", threading.local()): publish_job("amqp://localhost/", "q", uuid4()) - conn.close.assert_called_once() + conn.close.assert_not_called() def test_closes_connection_on_exception(self): conn = MagicMock() @@ -208,12 +222,13 @@ def test_closes_connection_on_exception(self): conn.is_open = True with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn), \ - patch("protea.infrastructure.queue.publisher.time.sleep"): + patch("protea.infrastructure.queue.publisher.time.sleep"), \ + patch("protea.infrastructure.queue.publisher._local", 
threading.local()): with pytest.raises(RuntimeError, match="Failed to publish to queue"): publish_job("amqp://localhost/", "q", uuid4()) - # close() is called once per retry attempt (4 total: 1 initial + 3 retries) - assert conn.close.call_count == 4 + # _close_cached_connection calls conn.close() once per failed attempt (5 total) + assert conn.close.call_count == 5 def test_declares_durable_queue(self): conn = MagicMock() @@ -221,7 +236,545 @@ def test_declares_durable_queue(self): conn.channel.return_value = channel conn.is_open = False - with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn): + with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn), \ + patch("protea.infrastructure.queue.publisher._local", threading.local()): publish_job("amqp://localhost/", "my.queue", uuid4()) channel.queue_declare.assert_called_once_with(queue="my.queue", durable=True) + + def test_exponential_backoff_delays(self): + """Verify that the publisher uses exponential backoff between retries.""" + conn = MagicMock() + channel = MagicMock() + channel.basic_publish.side_effect = RuntimeError("broker down") + conn.channel.return_value = channel + conn.is_open = True + + sleep_calls = [] + with patch("protea.infrastructure.queue.publisher.pika.BlockingConnection", return_value=conn), \ + patch("protea.infrastructure.queue.publisher.time.sleep", side_effect=lambda d: sleep_calls.append(d)), \ + patch("protea.infrastructure.queue.publisher._local", threading.local()): + with pytest.raises(RuntimeError, match="Failed to publish"): + publish_job("amqp://localhost/", "q", uuid4()) + + # 5 attempts → 4 sleeps: 1, 2, 4, 8 + assert sleep_calls == [1, 2, 4, 8] + + +# --------------------------------------------------------------------------- +# OperationConsumer — emit writes to parent job +# --------------------------------------------------------------------------- + +class TestOperationConsumerEmit: + """Verify that OperationConsumer's emit writes JobEvent rows to the parent job.""" + + def test_emit_writes_job_event_on_parent(self): + from protea.infrastructure.queue.consumer import OperationConsumer + from protea.core.contracts.operation import OperationResult + + parent_job_id = uuid4() + + # Mock registry and operation + op = MagicMock() + op.execute.return_value = OperationResult() + registry = MagicMock() + registry.get.return_value = op + + # Track sessions created by the factory + sessions = [] + def make_session(): + s = MagicMock() + sessions.append(s) + return s + factory = MagicMock(side_effect=make_session) + + consumer = OperationConsumer( + amqp_url="amqp://localhost/", + queue_name="test.queue", + registry=registry, + session_factory=factory, + ) + + # Build a valid message with a parent job_id + body = json.dumps({ + "operation": "test_op", + "job_id": str(parent_job_id), + "payload": {"key": "value"}, + }).encode() + + channel = MagicMock() + method = _make_method() + props = MagicMock() + + consumer._on_message(channel, method, props, body) + + # Operation should have been called + op.execute.assert_called_once() + channel.basic_ack.assert_called_once() + + def test_emit_records_failure_on_parent(self): + from protea.infrastructure.queue.consumer import OperationConsumer + + parent_job_id = uuid4() + + # Operation that raises + op = MagicMock() + op.execute.side_effect = ValueError("boom") + registry = MagicMock() + registry.get.return_value = op + + sessions = [] + def make_session(): + s = MagicMock() + 
sessions.append(s) + return s + factory = MagicMock(side_effect=make_session) + + consumer = OperationConsumer( + amqp_url="amqp://localhost/", + queue_name="test.queue", + registry=registry, + session_factory=factory, + ) + + body = json.dumps({ + "operation": "test_op", + "job_id": str(parent_job_id), + "payload": {}, + }).encode() + + channel = MagicMock() + method = _make_method() + props = MagicMock() + + consumer._on_message(channel, method, props, body) + + # Should nack (not requeue by default) + channel.basic_nack.assert_called_once() + # Should have created a session to write the error event + # At least: 1 execution session + 1 error event session + assert len(sessions) >= 2 + # The error event session should have had .add() called with a JobEvent + error_session = sessions[-1] + error_session.add.assert_called_once() + error_session.commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# OperationConsumer._on_message — extended coverage +# --------------------------------------------------------------------------- + +class TestOperationConsumerOnMessage: + """Cover uncovered lines in OperationConsumer._on_message.""" + + def _make_consumer(self, op=None, raises=None, requeue_on_failure=False): + from protea.infrastructure.queue.consumer import OperationConsumer + from protea.core.contracts.operation import OperationResult + + if op is None: + op = MagicMock() + if raises: + op.execute.side_effect = raises + else: + op.execute.return_value = OperationResult() + + registry = MagicMock() + registry.get.return_value = op + + sessions = [] + def make_session(): + s = MagicMock() + sessions.append(s) + return s + + factory = MagicMock(side_effect=make_session) + + consumer = OperationConsumer( + amqp_url="amqp://localhost/", + queue_name="test.ops", + registry=registry, + session_factory=factory, + requeue_on_failure=requeue_on_failure, + ) + return consumer, sessions, factory, op + + def _body(self, operation="test_op", job_id=None, payload=None): + msg = { + "operation": operation, + "payload": payload or {}, + } + if job_id is not None: + msg["job_id"] = str(job_id) + return json.dumps(msg).encode() + + def test_successful_operation_acks(self): + consumer, sessions, _, op = self._make_consumer() + channel = MagicMock() + method = _make_method(10) + + consumer._on_message(channel, method, MagicMock(), self._body()) + + op.execute.assert_called_once() + channel.basic_ack.assert_called_once_with(delivery_tag=10) + channel.basic_nack.assert_not_called() + + def test_failed_operation_nacks_without_requeue(self): + consumer, sessions, _, _ = self._make_consumer(raises=ValueError("oops")) + channel = MagicMock() + method = _make_method(20) + + consumer._on_message(channel, method, MagicMock(), self._body()) + + channel.basic_nack.assert_called_once_with(delivery_tag=20, requeue=False) + channel.basic_ack.assert_not_called() + + def test_failed_operation_nacks_with_requeue_when_flag_set(self): + consumer, sessions, _, _ = self._make_consumer( + raises=ValueError("oops"), requeue_on_failure=True + ) + channel = MagicMock() + method = _make_method(21) + + consumer._on_message(channel, method, MagicMock(), self._body()) + + channel.basic_nack.assert_called_once_with(delivery_tag=21, requeue=True) + + def test_cuda_oom_clears_cache_and_requeues(self): + exc = RuntimeError("CUDA out of memory. 
Tried to allocate 2 GiB")
+        consumer, sessions, _, _ = self._make_consumer(raises=exc)
+        channel = MagicMock()
+        method = _make_method(30)
+
+        # The consumer imports torch lazily inside its OOM handler, so
+        # injecting a mock module into sys.modules is enough to intercept
+        # that import even when torch is not installed in the test env.
+        import sys
+        mock_torch = MagicMock()
+        with patch.dict(sys.modules, {"torch": mock_torch}):
+            consumer._on_message(channel, method, MagicMock(), self._body())
+
+        # Should requeue regardless of requeue_on_failure flag
+        channel.basic_nack.assert_called_once()
+        call_kwargs = channel.basic_nack.call_args.kwargs
+        assert call_kwargs["requeue"] is True
+
+    def test_unparseable_message_nacks_without_requeue(self):
+        consumer, _, _, _ = self._make_consumer()
+        channel = MagicMock()
+        method = _make_method(40)
+
+        consumer._on_message(channel, method, MagicMock(), b"not json")
+
+        channel.basic_nack.assert_called_once_with(delivery_tag=40, requeue=False)
+        channel.basic_ack.assert_not_called()
+
+    def test_missing_operation_key_nacks(self):
+        consumer, _, _, _ = self._make_consumer()
+        channel = MagicMock()
+        method = _make_method(41)
+        body = json.dumps({"payload": {}}).encode()
+
+        consumer._on_message(channel, method, MagicMock(), body)
+
+        channel.basic_nack.assert_called_once_with(delivery_tag=41, requeue=False)
+
+    def test_stop_flag_nacks_with_requeue(self):
+        consumer, _, _, _ = self._make_consumer()
+        consumer._stop = True
+        channel = MagicMock()
+        method = _make_method(50)
+
+        consumer._on_message(channel, method, MagicMock(), self._body())
+
+        channel.basic_nack.assert_called_once_with(delivery_tag=50, requeue=True)
+        channel.basic_ack.assert_not_called()
+
+    def test_emit_writes_job_event_to_parent_session(self):
+        """When operation calls emit, a JobEvent is written to a separate session."""
+        from protea.core.contracts.operation import OperationResult
+
+        parent_id = uuid4()
+
+        def _execute(session, payload, *, emit):
+            emit("progress", "doing stuff", {"step": 1}, "info")
+            return OperationResult()
+
+        op = MagicMock()
+        op.execute.side_effect = _execute
+
+        consumer, sessions, _, _ = self._make_consumer(op=op)
+        channel = MagicMock()
+        method = _make_method()
+
+        consumer._on_message(channel, method, MagicMock(), self._body(job_id=parent_id))
+
+        # sessions: [0]=execution session, [1]=emit event session
+        assert len(sessions) >= 2
+        emit_session = sessions[1]
+        emit_session.add.assert_called_once()
+        emit_session.commit.assert_called_once()
+        emit_session.close.assert_called_once()
+
+    def test_emit_without_parent_job_id_only_logs(self):
+        """When no job_id in message, emit should not create an event session."""
+        from protea.core.contracts.operation import OperationResult
+
+        def _execute(session, payload, *, emit):
+            emit("progress", "no parent", {}, "info")
+            return OperationResult()
+
+        op = MagicMock()
+        op.execute.side_effect = _execute
+
+        consumer, sessions, _, _ = self._make_consumer(op=op)
+        channel = MagicMock()
+        method = _make_method()
+
+        # Message without job_id
+        body = json.dumps({"operation": "test_op", "payload": {}}).encode()
+        consumer._on_message(channel, method, MagicMock(), body)
+
+        # Only the execution session should have been created (no event session)
+        assert len(sessions) == 1
+
+    def test_emit_session_failure_is_handled_gracefully(self):
+        """If writing the event to DB fails, the operation should still complete."""
+        from protea.core.contracts.operation import OperationResult
+
+        parent_id = uuid4()
+
+        def _execute(session, payload, *, emit):
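+            # Minimal stand-in operation: it emits a single progress event,
+            # and the DB write for that event is what this test forces to fail.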
+            emit("progress", "msg", {}, "info")
+            return OperationResult()
+
+        op = MagicMock()
+        op.execute.side_effect = _execute
+
+        sessions_created = []
+        def make_session():
+            s = MagicMock()
+            sessions_created.append(s)
+            # Make the second session (emit session) fail on commit
+            if len(sessions_created) == 2:
+                s.commit.side_effect = RuntimeError("DB down")
+            return s
+
+        from protea.infrastructure.queue.consumer import OperationConsumer
+        registry = MagicMock()
+        registry.get.return_value = op
+        factory = MagicMock(side_effect=make_session)
+
+        consumer = OperationConsumer(
+            amqp_url="amqp://localhost/",
+            queue_name="test.ops",
+            registry=registry,
+            session_factory=factory,
+        )
+        channel = MagicMock()
+        method = _make_method()
+
+        consumer._on_message(channel, method, MagicMock(), self._body(job_id=parent_id))
+
+        # Should still ack despite emit failure
+        channel.basic_ack.assert_called_once()
+
+    def test_publish_operations_forwarded(self):
+        """Downstream publish_operations from result are forwarded via publish_operation."""
+        from protea.core.contracts.operation import OperationResult
+
+        result = OperationResult(
+            publish_operations=[
+                ("protea.embeddings.write", {"batch": [1, 2]}),
+                ("protea.predictions.write", {"batch": [3, 4]}),
+            ]
+        )
+        op = MagicMock()
+        op.execute.return_value = result
+
+        consumer, sessions, _, _ = self._make_consumer(op=op)
+        channel = MagicMock()
+        method = _make_method()
+
+        with patch("protea.infrastructure.queue.consumer.publish_operation") as mock_pub:
+            consumer._on_message(channel, method, MagicMock(), self._body())
+
+        assert mock_pub.call_count == 2
+        mock_pub.assert_any_call("amqp://localhost/", "protea.embeddings.write", {"batch": [1, 2]})
+        mock_pub.assert_any_call("amqp://localhost/", "protea.predictions.write", {"batch": [3, 4]})
+
+    def test_failed_operation_writes_error_event_to_parent(self):
+        """On failure with parent_job_id, a child.failed event is written."""
+        parent_id = uuid4()
+        consumer, sessions, _, _ = self._make_consumer(raises=TypeError("bad type"))
+        channel = MagicMock()
+        method = _make_method()
+
+        consumer._on_message(channel, method, MagicMock(), self._body(job_id=parent_id))
+
+        # Find the error event session (last one created besides execution session)
+        # sessions: [0]=execution, [1]=error event
+        assert len(sessions) >= 2
+        err_session = sessions[-1]
+        err_session.add.assert_called_once()
+        added_event = err_session.add.call_args[0][0]
+        assert added_event.job_id == parent_id
+        assert added_event.event == "child.failed"
+        assert added_event.level == "error"
+        assert "bad type" in added_event.message
+
+    def test_invalid_job_id_in_message_is_ignored(self):
+        """If job_id is not a valid UUID, parent_job_id should be None (no crash)."""
+        from protea.core.contracts.operation import OperationResult
+
+        op = MagicMock()
+        op.execute.return_value = OperationResult()
+
+        consumer, sessions, _, _ = self._make_consumer(op=op)
+        channel = MagicMock()
+        method = _make_method()
+
+        body = json.dumps({
+            "operation": "test_op",
+            "job_id": "not-a-uuid",
+            "payload": {},
+        }).encode()
+
+        consumer._on_message(channel, method, MagicMock(), body)
+
+        # Should still succeed — only 1 session (execution), no event sessions
+        channel.basic_ack.assert_called_once()
+        assert len(sessions) == 1
+
+    def test_error_event_session_rollback_on_commit_failure(self):
+        """If the error event session commit fails, rollback is called."""
+        parent_id = uuid4()
+
+        sessions_created = []
+        def make_session():
+            s = MagicMock()
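+            # Each MagicMock stands in for a SQLAlchemy session; tracking
+            # creation order lets the test target the error-event session.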
+            sessions_created.append(s)
+            # Make the error-event session (the 2nd one created, after the
+            # execution session) fail on commit
+            if len(sessions_created) == 2:
+                s.commit.side_effect = RuntimeError("DB gone")
+            return s
+
+        from protea.infrastructure.queue.consumer import OperationConsumer
+        op = MagicMock()
+        op.execute.side_effect = ValueError("boom")
+        registry = MagicMock()
+        registry.get.return_value = op
+        factory = MagicMock(side_effect=make_session)
+
+        consumer = OperationConsumer(
+            amqp_url="amqp://localhost/",
+            queue_name="test.ops",
+            registry=registry,
+            session_factory=factory,
+        )
+        channel = MagicMock()
+        method = _make_method()
+
+        consumer._on_message(channel, method, MagicMock(), self._body(job_id=parent_id))
+
+        # Error event session should have rollback called
+        err_session = sessions_created[1]
+        err_session.rollback.assert_called_once()
+        err_session.close.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# QueueConsumer._on_message — RetryLaterError handling
+# ---------------------------------------------------------------------------
+
+class TestQueueConsumerRetryLater:
+    """Cover RetryLaterError handling in QueueConsumer._on_message (lines 142-151)."""
+
+    def test_retry_later_sleeps_and_republishes(self):
+        from protea.core.contracts.operation import RetryLaterError
+
+        job_id = uuid4()
+        worker = _make_worker(raises=RetryLaterError("GPU busy", delay_seconds=30))
+        consumer = _consumer(worker)
+
+        channel = MagicMock()
+        method = _make_method(99)
+        props = MagicMock()
+
+        consumer._on_message(channel, method, props, _encode(job_id))
+
+        # The original delivery is acked; the retry goes out as a fresh publish
+        channel.basic_ack.assert_called_once_with(delivery_tag=99)
+        # Should sleep on the connection
+        channel.connection.sleep.assert_called_once_with(30)
+        # Should re-publish
+        channel.basic_publish.assert_called_once()
+        pub_kwargs = channel.basic_publish.call_args.kwargs
+        assert pub_kwargs["routing_key"] == "test.jobs"
+        body = json.loads(pub_kwargs["body"].decode())
+        assert body["job_id"] == str(job_id)
+
+    def test_shutdown_draining_nacks_with_requeue(self):
+        """When _stop is set, messages are nacked with requeue=True."""
+        consumer = _consumer()
+        consumer._stop = True
+
+        channel = MagicMock()
+        method = _make_method(77)
+
+        consumer._on_message(channel, method, MagicMock(), _encode(uuid4()))
+
+        channel.basic_nack.assert_called_once_with(delivery_tag=77, requeue=True)
+        channel.basic_ack.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# OperationConsumer._handle_stop
+# ---------------------------------------------------------------------------
+
+class TestOperationConsumerHandleStop:
+    def test_handle_stop_sets_flag(self):
+        from protea.infrastructure.queue.consumer import OperationConsumer
+
+        consumer = OperationConsumer(
+            amqp_url="amqp://localhost/",
+            queue_name="test.ops",
+            registry=MagicMock(),
+            session_factory=MagicMock(),
+        )
+        assert consumer._stop is False
+        consumer._handle_stop()
+        assert consumer._stop is True
+
+
+# ---------------------------------------------------------------------------
+# OperationConsumer.run (pika fully mocked)
+# ---------------------------------------------------------------------------
+
+class TestOperationConsumerRun:
+    def test_run_declares_queue_and_starts_consuming(self):
+        from protea.infrastructure.queue.consumer import OperationConsumer
+
+        consumer = OperationConsumer(
+            amqp_url="amqp://localhost/",
+            queue_name="test.ops",
+            registry=MagicMock(),
+            
session_factory=MagicMock(), + prefetch_count=4, + ) + + conn = MagicMock() + channel = MagicMock() + conn.channel.return_value = channel + conn.is_open = False + + with patch("protea.infrastructure.queue.consumer.pika.BlockingConnection", return_value=conn): + consumer.run() + + channel.queue_declare.assert_any_call( + queue="test.ops", + durable=True, + arguments={"x-dead-letter-exchange": "protea.dlx"}, + ) + channel.basic_qos.assert_called_once_with(prefetch_count=4) + channel.basic_consume.assert_called_once() + channel.start_consuming.assert_called_once() diff --git a/tests/test_run_cafa_evaluation.py b/tests/test_run_cafa_evaluation.py new file mode 100644 index 0000000..3b2c7a2 --- /dev/null +++ b/tests/test_run_cafa_evaluation.py @@ -0,0 +1,1167 @@ +"""Unit tests for RunCafaEvaluationOperation. + +No real DB, network, or cafaeval binary required — everything is mocked. +""" +from __future__ import annotations + +import gzip +import os +import tempfile +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, call, patch + +import pandas as pd +import pytest +from pydantic import ValidationError + +from protea.core.evaluation import EvaluationData +from protea.core.operations.run_cafa_evaluation import ( + RunCafaEvaluationOperation, + RunCafaEvaluationPayload, + _NS_LABELS, + _NS_SHORT, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +EVAL_SET_ID = str(uuid.uuid4()) +PRED_SET_ID = str(uuid.uuid4()) +OLD_ANN_SET_ID = uuid.uuid4() +NEW_ANN_SET_ID = uuid.uuid4() +SNAP_ID = uuid.uuid4() +SCORING_CONFIG_ID = str(uuid.uuid4()) + + +def _make_emit(): + """Return a mock emit function that records all calls.""" + return MagicMock() + + +def _make_eval_set(eval_set_id=None): + es = MagicMock() + es.id = uuid.UUID(eval_set_id or EVAL_SET_ID) + es.old_annotation_set_id = OLD_ANN_SET_ID + es.new_annotation_set_id = NEW_ANN_SET_ID + return es + + +def _make_pred_set(pred_set_id=None): + ps = MagicMock() + ps.id = uuid.UUID(pred_set_id or PRED_SET_ID) + return ps + + +def _make_ann_old(): + ann = MagicMock() + ann.ontology_snapshot_id = SNAP_ID + return ann + + +def _make_snapshot(obo_url="https://example.com/go.obo", ia_url=None): + snap = MagicMock() + snap.obo_url = obo_url + snap.ia_url = ia_url + return snap + + +def _make_eval_data(nk=None, lk=None, pk=None, known=None, pk_known=None): + return EvaluationData( + nk=nk or {"P1": {"GO:0000001"}}, + lk=lk or {"P2": {"GO:0000002"}}, + pk=pk or {}, + known=known or {}, + pk_known=pk_known or {}, + ) + + +def _make_scoring_config(): + sc = MagicMock() + sc.formula = "linear" + sc.weights = {"embedding_similarity": 1.0} + return sc + + +def _dfs_best_fixture(): + """Build a dfs_best dict matching cafaeval output format.""" + df_f = pd.DataFrame( + [ + { + "ns": "biological_process", + "f": 0.45, + "pr": 0.51, + "rc": 0.40, + "tau": 0.32, + "cov_max": 0.95, + "n": 100, + }, + { + "ns": "molecular_function", + "f": 0.60, + "pr": 0.65, + "rc": 0.55, + "tau": 0.20, + "cov_max": 0.88, + "n": 50, + }, + { + "ns": "cellular_component", + "f": 0.70, + "pr": 0.72, + "rc": 0.68, + "tau": 0.15, + "cov_max": 0.92, + "n": 75, + }, + ] + ) + return {"f": df_f} + + +# --------------------------------------------------------------------------- +# Payload validation +# --------------------------------------------------------------------------- + + +class TestRunCafaEvaluationPayload: + def 
test_valid_payload(self): + p = RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id=PRED_SET_ID, + ) + assert p.evaluation_set_id == EVAL_SET_ID + assert p.prediction_set_id == PRED_SET_ID + assert p.max_distance is None + assert p.artifacts_dir is None + assert p.scoring_config_id is None + assert p.ia_file is None + + def test_valid_payload_all_fields(self): + p = RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id=PRED_SET_ID, + max_distance=1.5, + artifacts_dir="/tmp/artifacts", + scoring_config_id=SCORING_CONFIG_ID, + ia_file="/tmp/ia.tsv", + ) + assert p.max_distance == 1.5 + assert p.artifacts_dir == "/tmp/artifacts" + assert p.scoring_config_id == SCORING_CONFIG_ID + assert p.ia_file == "/tmp/ia.tsv" + + def test_empty_evaluation_set_id_raises(self): + with pytest.raises(ValidationError, match="non-empty"): + RunCafaEvaluationPayload( + evaluation_set_id=" ", + prediction_set_id=PRED_SET_ID, + ) + + def test_empty_prediction_set_id_raises(self): + with pytest.raises(ValidationError, match="non-empty"): + RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id="", + ) + + def test_non_string_evaluation_set_id_raises(self): + with pytest.raises(ValidationError): + RunCafaEvaluationPayload( + evaluation_set_id=123, + prediction_set_id=PRED_SET_ID, + ) + + def test_max_distance_out_of_range(self): + with pytest.raises(ValidationError): + RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id=PRED_SET_ID, + max_distance=3.0, + ) + + def test_max_distance_negative(self): + with pytest.raises(ValidationError): + RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id=PRED_SET_ID, + max_distance=-0.1, + ) + + def test_strips_whitespace(self): + p = RunCafaEvaluationPayload( + evaluation_set_id=f" {EVAL_SET_ID} ", + prediction_set_id=f" {PRED_SET_ID} ", + ) + assert p.evaluation_set_id == EVAL_SET_ID + assert p.prediction_set_id == PRED_SET_ID + + def test_frozen_payload(self): + p = RunCafaEvaluationPayload( + evaluation_set_id=EVAL_SET_ID, + prediction_set_id=PRED_SET_ID, + ) + with pytest.raises(ValidationError): + p.evaluation_set_id = "new_value" + + +# --------------------------------------------------------------------------- +# Operation name +# --------------------------------------------------------------------------- + + +class TestOperationName: + def test_name(self): + op = RunCafaEvaluationOperation() + assert op.name == "run_cafa_evaluation" + + +# --------------------------------------------------------------------------- +# _parse_results +# --------------------------------------------------------------------------- + + +class TestParseResults: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + + def test_parse_all_namespaces(self): + dfs_best = _dfs_best_fixture() + result = self.op._parse_results(dfs_best) + assert set(result.keys()) == {"BPO", "MFO", "CCO"} + + def test_parse_bpo_values(self): + dfs_best = _dfs_best_fixture() + result = self.op._parse_results(dfs_best) + bpo = result["BPO"] + assert bpo["fmax"] == 0.45 + assert bpo["precision"] == 0.51 + assert bpo["recall"] == 0.40 + assert bpo["tau"] == 0.32 + assert bpo["coverage"] == 0.95 + assert bpo["n_proteins"] == 100 + + def test_parse_mfo_values(self): + dfs_best = _dfs_best_fixture() + result = self.op._parse_results(dfs_best) + mfo = result["MFO"] + assert mfo["fmax"] == 0.60 + assert mfo["precision"] == 0.65 + assert mfo["recall"] == 0.55 + + def 
test_parse_empty_dfs_best(self): + result = self.op._parse_results({}) + assert result == {} + + def test_parse_none_df_f(self): + result = self.op._parse_results({"f": None}) + assert result == {} + + def test_parse_empty_df_f(self): + result = self.op._parse_results({"f": pd.DataFrame()}) + assert result == {} + + def test_parse_ignores_unknown_namespaces(self): + df_f = pd.DataFrame( + [{"ns": "unknown_namespace", "f": 0.5, "pr": 0.5, "rc": 0.5, "tau": 0.1, "cov_max": 0.9, "n": 10}] + ) + result = self.op._parse_results({"f": df_f}) + assert result == {} + + def test_parse_uses_cov_fallback_when_no_cov_max(self): + df_f = pd.DataFrame( + [{"ns": "biological_process", "f": 0.5, "pr": 0.5, "rc": 0.5, "tau": 0.1, "cov": 0.85, "n": 10}] + ) + result = self.op._parse_results({"f": df_f}) + assert result["BPO"]["coverage"] == 0.85 + + def test_parse_missing_n_column(self): + df_f = pd.DataFrame( + [{"ns": "biological_process", "f": 0.5, "pr": 0.5, "rc": 0.5, "tau": 0.1, "cov_max": 0.9}] + ) + result = self.op._parse_results({"f": df_f}) + assert result["BPO"]["n_proteins"] is None + + +# --------------------------------------------------------------------------- +# _write_gt +# --------------------------------------------------------------------------- + + +class TestWriteGt: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + + def test_write_gt_basic(self): + annotations = { + "P2": {"GO:0000002", "GO:0000003"}, + "P1": {"GO:0000001"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_gt(annotations, path) + with open(path) as f: + lines = f.read().strip().split("\n") + # Sorted by protein then by GO ID + assert lines[0] == "P1\tGO:0000001" + assert lines[1] == "P2\tGO:0000002" + assert lines[2] == "P2\tGO:0000003" + assert len(lines) == 3 + finally: + os.unlink(path) + + def test_write_gt_empty(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_gt({}, path) + with open(path) as f: + content = f.read() + assert content == "" + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# _download_obo +# --------------------------------------------------------------------------- + + +class TestDownloadObo: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + + @patch("protea.core.operations.run_cafa_evaluation.requests.get") + def test_download_plain(self, mock_get): + mock_resp = MagicMock() + mock_resp.text = "format-version: 1.2\n" + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp + + with tempfile.NamedTemporaryFile(suffix=".obo", delete=False) as f: + path = f.name + try: + self.op._download_obo("https://example.com/go.obo", path) + with open(path) as f: + assert f.read() == "format-version: 1.2\n" + finally: + os.unlink(path) + + @patch("protea.core.operations.run_cafa_evaluation.requests.get") + def test_download_gzip(self, mock_get): + original = b"format-version: 1.2\n" + compressed = gzip.compress(original) + mock_resp = MagicMock() + mock_resp.content = compressed + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp + + with tempfile.NamedTemporaryFile(suffix=".obo", delete=False) as f: + path = f.name + try: + self.op._download_obo("https://example.com/go.obo.gz", path) + with open(path, "rb") as f: + assert f.read() == original + finally: + os.unlink(path) + + +# 
--------------------------------------------------------------------------- +# _download_tsv +# --------------------------------------------------------------------------- + + +class TestDownloadTsv: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + + def test_local_absolute_path(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as src: + src.write("GO:0001\t0.5\n") + src_path = src.name + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as dst: + dst_path = dst.name + try: + self.op._download_tsv(src_path, dst_path) + with open(dst_path) as f: + assert f.read() == "GO:0001\t0.5\n" + finally: + os.unlink(src_path) + os.unlink(dst_path) + + def test_local_file_scheme(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as src: + src.write("GO:0002\t0.8\n") + src_path = src.name + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as dst: + dst_path = dst.name + try: + self.op._download_tsv(f"file://{src_path}", dst_path) + with open(dst_path) as f: + assert f.read() == "GO:0002\t0.8\n" + finally: + os.unlink(src_path) + os.unlink(dst_path) + + def test_local_gzip_path(self): + original = b"GO:0003\t0.3\n" + with tempfile.NamedTemporaryFile(suffix=".tsv.gz", delete=False) as src: + src.write(gzip.compress(original)) + src_path = src.name + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as dst: + dst_path = dst.name + try: + self.op._download_tsv(src_path, dst_path) + with open(dst_path, "rb") as f: + assert f.read() == original + finally: + os.unlink(src_path) + os.unlink(dst_path) + + @patch("protea.core.operations.run_cafa_evaluation.requests.get") + def test_http_download(self, mock_get): + mock_resp = MagicMock() + mock_resp.text = "GO:0004\t0.9\n" + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as dst: + dst_path = dst.name + try: + self.op._download_tsv("https://example.com/ia.tsv", dst_path) + with open(dst_path) as f: + assert f.read() == "GO:0004\t0.9\n" + finally: + os.unlink(dst_path) + + @patch("protea.core.operations.run_cafa_evaluation.requests.get") + def test_http_gzip_download(self, mock_get): + original = b"GO:0005\t0.6\n" + mock_resp = MagicMock() + mock_resp.content = gzip.compress(original) + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as dst: + dst_path = dst.name + try: + self.op._download_tsv("https://example.com/ia.tsv.gz", dst_path) + with open(dst_path, "rb") as f: + assert f.read() == original + finally: + os.unlink(dst_path) + + +# --------------------------------------------------------------------------- +# _write_predictions +# --------------------------------------------------------------------------- + + +class TestWritePredictions: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + + def test_write_predictions_without_scoring_config(self): + pred_mock = MagicMock() + pred_mock.protein_accession = "P1" + pred_mock.distance = 0.4 + pred_mock.identity_nw = None + pred_mock.identity_sw = None + pred_mock.evidence_code = None + pred_mock.taxonomic_distance = None + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + 
query.yield_per.return_value = [(pred_mock, gt_mock)] + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, None, path, None + ) + with open(path) as f: + line = f.read().strip() + # score = max(0, 1 - 0.4/2) = 0.8 + assert line == "P1\tGO:0000001\t0.8000" + finally: + os.unlink(path) + + def test_write_predictions_deduplicates(self): + pred1 = MagicMock() + pred1.protein_accession = "P1" + pred1.distance = 0.2 + + pred2 = MagicMock() + pred2.protein_accession = "P1" + pred2.distance = 0.6 + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [(pred1, gt_mock), (pred2, gt_mock)] + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, None, path, None + ) + with open(path) as f: + lines = f.read().strip().split("\n") + # Only the first (closest) prediction should be written + assert len(lines) == 1 + finally: + os.unlink(path) + + @patch("protea.core.operations.run_cafa_evaluation.compute_score") + def test_write_predictions_with_scoring_config(self, mock_compute_score): + mock_compute_score.return_value = 0.75 + + pred_mock = MagicMock() + pred_mock.protein_accession = "P1" + pred_mock.distance = 0.4 + pred_mock.identity_nw = 0.8 + pred_mock.identity_sw = 0.9 + pred_mock.evidence_code = "IDA" + pred_mock.taxonomic_distance = 2.0 + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [(pred_mock, gt_mock)] + + scoring_config = _make_scoring_config() + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, None, path, scoring_config + ) + with open(path) as f: + line = f.read().strip() + assert line == "P1\tGO:0000001\t0.7500" + mock_compute_score.assert_called_once() + finally: + os.unlink(path) + + def test_write_predictions_zero_distance(self): + pred_mock = MagicMock() + pred_mock.protein_accession = "P1" + pred_mock.distance = 0.0 + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [(pred_mock, gt_mock)] + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, None, path, None + ) + with open(path) as f: + line = f.read().strip() + # score = max(0, 1 - 0/2) = 1.0 + assert line == "P1\tGO:0000001\t1.0000" + finally: + os.unlink(path) + + def test_write_predictions_with_max_distance(self): + """When max_distance is provided, query should include the filter.""" + pred_mock = MagicMock() + pred_mock.protein_accession = "P1" + pred_mock.distance = 0.3 + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value 
= query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [(pred_mock, gt_mock)] + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, 0.5, path, None + ) + with open(path) as f: + line = f.read().strip() + assert line == "P1\tGO:0000001\t0.8500" + # filter should have been called 3 times: + # pred_set_id, protein_accession IN, distance <= + assert query.filter.call_count == 3 + finally: + os.unlink(path) + + def test_write_predictions_none_distance_fallback(self): + pred_mock = MagicMock() + pred_mock.protein_accession = "P1" + pred_mock.distance = None + + gt_mock = MagicMock() + gt_mock.go_id = "GO:0000001" + + session = MagicMock() + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [(pred_mock, gt_mock)] + + with tempfile.NamedTemporaryFile(suffix=".tsv", delete=False) as f: + path = f.name + try: + self.op._write_predictions( + session, uuid.uuid4(), {"P1"}, None, path, None + ) + with open(path) as f: + line = f.read().strip() + # score = max(0, 1 - 0/2) = 1.0 (None → 0.0) + assert line == "P1\tGO:0000001\t1.0000" + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# execute — error paths +# --------------------------------------------------------------------------- + + +class TestExecuteErrors: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + self.emit = _make_emit() + + def test_missing_evaluation_set(self): + session = MagicMock() + session.get.return_value = None + + with pytest.raises(ValueError, match="EvaluationSet.*not found"): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + def test_missing_prediction_set(self): + session = MagicMock() + eval_set = _make_eval_set() + # First call returns eval_set, second returns None (pred_set missing) + session.get.side_effect = [eval_set, None] + + with pytest.raises(ValueError, match="PredictionSet.*not found"): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_no_delta_proteins(self, mock_compute): + mock_compute.return_value = EvaluationData( + nk={}, lk={}, pk={}, known={}, pk_known={} + ) + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + with pytest.raises(ValueError, match="No delta proteins"): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_missing_scoring_config(self, mock_compute): + mock_compute.return_value = _make_eval_data() + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + # get calls: eval_set, pred_set, ann_old, snapshot, scoring_config (None) + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot, None] + + with pytest.raises(ValueError, match="ScoringConfig.*not 
found"): + self.op.execute( + session, + { + "evaluation_set_id": EVAL_SET_ID, + "prediction_set_id": PRED_SET_ID, + "scoring_config_id": SCORING_CONFIG_ID, + }, + emit=self.emit, + ) + + +# --------------------------------------------------------------------------- +# execute — happy path +# --------------------------------------------------------------------------- + + +class TestExecuteHappyPath: + def setup_method(self): + self.op = RunCafaEvaluationOperation() + self.emit = _make_emit() + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_full_run(self, mock_compute): + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + # Mock the DB query for _write_predictions + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + dfs_best = _dfs_best_fixture() + + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), dfs_best), + ) as mock_cafa: + result = self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + assert "evaluation_result_id" in result.result + assert "results" in result.result + # cafa_eval called 3 times: NK, LK, PK + assert mock_cafa.call_count == 3 + # session.add called for EvaluationResult + session.add.assert_called_once() + session.flush.assert_called_once() + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_emit_events(self, mock_compute): + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + dfs_best = _dfs_best_fixture() + + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), dfs_best), + ): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + # Verify key emit events were fired + emit_events = [c[0][0] for c in self.emit.call_args_list] + assert "run_cafa_evaluation.start" in emit_events + assert "run_cafa_evaluation.computing_delta" in emit_events + assert "run_cafa_evaluation.delta_done" in emit_events + assert "run_cafa_evaluation.downloading_obo" in emit_events + assert "run_cafa_evaluation.writing_predictions" in emit_events + assert "run_cafa_evaluation.done" in emit_events + # 3 evaluating events (NK, LK, PK) + assert emit_events.count("run_cafa_evaluation.evaluating") == 3 + assert emit_events.count("run_cafa_evaluation.setting_done") == 3 + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_cafa_eval_failure_catches_exception(self, mock_compute): + """When cafa_eval raises for one setting, it should log warning and continue.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + 
eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + side_effect=RuntimeError("cafa_eval exploded"), + ): + result = self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + # All three settings should be empty dicts (all failed) + results = result.result["results"] + assert results["NK"] == {} + assert results["LK"] == {} + assert results["PK"] == {} + + # Emit should have 3 setting_failed events + emit_events = [c[0][0] for c in self.emit.call_args_list] + assert emit_events.count("run_cafa_evaluation.setting_failed") == 3 + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_ia_missing_warning(self, mock_compute): + """When no IA file and no ia_url, a warning should be emitted.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot(ia_url=None) # no ia_url + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), _dfs_best_fixture()), + ): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + emit_events = [c[0][0] for c in self.emit.call_args_list] + assert "run_cafa_evaluation.ia_missing" in emit_events + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_ia_url_download(self, mock_compute): + """When snapshot has ia_url, _download_tsv should be called.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot(ia_url="https://example.com/ia.tsv") + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with patch.object(self.op, "_download_obo"), \ + patch.object(self.op, "_download_tsv") as mock_dl_tsv, \ + patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), _dfs_best_fixture()), + ): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + mock_dl_tsv.assert_called_once() + assert mock_dl_tsv.call_args[0][0] == "https://example.com/ia.tsv" + + emit_events = [c[0][0] for c in self.emit.call_args_list] + assert "run_cafa_evaluation.downloading_ia" in emit_events + assert "run_cafa_evaluation.ia_resolved" in emit_events + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def 
test_explicit_ia_file_takes_precedence(self, mock_compute): + """Explicit ia_file in payload overrides snapshot ia_url.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot(ia_url="https://example.com/ia.tsv") + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with patch.object(self.op, "_download_obo"), \ + patch.object(self.op, "_download_tsv") as mock_dl_tsv, \ + patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), _dfs_best_fixture()), + ): + self.op.execute( + session, + { + "evaluation_set_id": EVAL_SET_ID, + "prediction_set_id": PRED_SET_ID, + "ia_file": "/custom/ia.tsv", + }, + emit=self.emit, + ) + + # _download_tsv should NOT be called because ia_file overrides ia_url + mock_dl_tsv.assert_not_called() + + emit_events = [c[0][0] for c in self.emit.call_args_list] + assert "run_cafa_evaluation.ia_resolved" in emit_events + assert "run_cafa_evaluation.downloading_ia" not in emit_events + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_session_commit_before_cafa_eval(self, mock_compute): + """Session should be committed before cafa_eval to release DB connection.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + call_order = [] + session.commit.side_effect = lambda: call_order.append("commit") + + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + side_effect=lambda *a, **kw: (call_order.append("cafa_eval"), (MagicMock(), _dfs_best_fixture()))[-1], + ): + self.op.execute( + session, + {"evaluation_set_id": EVAL_SET_ID, "prediction_set_id": PRED_SET_ID}, + emit=self.emit, + ) + + assert call_order[0] == "commit" + assert "cafa_eval" in call_order + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_artifacts_dir(self, mock_compute): + """When artifacts_dir is set, artifact directory should be created.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with tempfile.TemporaryDirectory() as tmpdir: + with patch.object(self.op, "_download_obo"): + with patch( + "cafaeval.evaluation.cafa_eval", + return_value=(None, _dfs_best_fixture()), + ): + result = self.op.execute( + session, + { + "evaluation_set_id": EVAL_SET_ID, + "prediction_set_id": PRED_SET_ID, + "artifacts_dir": tmpdir, + }, + emit=self.emit, + ) + + result_id = result.result["evaluation_result_id"] + 
assert os.path.isdir(os.path.join(tmpdir, result_id)) + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_artifacts_dir_with_write_results(self, mock_compute): + """When artifacts_dir is set and df is not None, write_results is called.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + df_mock = MagicMock() # non-None df triggers write_results + dfs_best = _dfs_best_fixture() + + with tempfile.TemporaryDirectory() as tmpdir: + with patch.object(self.op, "_download_obo"), \ + patch( + "cafaeval.evaluation.cafa_eval", + return_value=(df_mock, dfs_best), + ), \ + patch( + "cafaeval.evaluation.write_results" + ) as mock_write: + result = self.op.execute( + session, + { + "evaluation_set_id": EVAL_SET_ID, + "prediction_set_id": PRED_SET_ID, + "artifacts_dir": tmpdir, + }, + emit=self.emit, + ) + + # write_results called 3 times (NK, LK, PK) + assert mock_write.call_count == 3 + result_id = result.result["evaluation_result_id"] + # Check setting subdirectories were created + for setting in ("NK", "LK", "PK"): + setting_dir = os.path.join(tmpdir, result_id, setting) + assert os.path.isdir(setting_dir) + + @patch("protea.core.operations.run_cafa_evaluation.compute_evaluation_data") + def test_scoring_config_snapshot(self, mock_compute): + """When scoring_config_id is provided and found, it snapshots the config.""" + mock_compute.return_value = _make_eval_data() + + session = MagicMock() + eval_set = _make_eval_set() + pred_set = _make_pred_set() + ann_old = _make_ann_old() + snapshot = _make_snapshot() + scoring_cfg = MagicMock() + scoring_cfg.formula = "linear" + scoring_cfg.weights = {"embedding_similarity": 1.0} + session.get.side_effect = [eval_set, pred_set, ann_old, snapshot, scoring_cfg] + + query = MagicMock() + session.query.return_value = query + query.join.return_value = query + query.filter.return_value = query + query.order_by.return_value = query + query.yield_per.return_value = [] + + with patch.object(self.op, "_download_obo"), \ + patch( + "cafaeval.evaluation.cafa_eval", + return_value=(MagicMock(), _dfs_best_fixture()), + ), \ + patch( + "protea.core.operations.run_cafa_evaluation.ScoringConfig" + ) as mock_sc_cls: + mock_sc_cls.return_value = MagicMock() + result = self.op.execute( + session, + { + "evaluation_set_id": EVAL_SET_ID, + "prediction_set_id": PRED_SET_ID, + "scoring_config_id": SCORING_CONFIG_ID, + }, + emit=self.emit, + ) + + # ScoringConfig constructor was called for snapshotting + mock_sc_cls.assert_called_once_with( + formula="linear", + weights={"embedding_similarity": 1.0}, + ) + assert "evaluation_result_id" in result.result + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + + +class TestConstants: + def test_ns_labels_mapping(self): + assert _NS_LABELS["biological_process"] == "BPO" + assert _NS_LABELS["molecular_function"] == "MFO" + assert _NS_LABELS["cellular_component"] == "CCO" + + def test_ns_short_set(self): + assert _NS_SHORT == {"BPO", "MFO", "CCO"} diff --git 
a/tests/test_scoring_router.py b/tests/test_scoring_router.py index eb037e6..15f36b4 100644 --- a/tests/test_scoring_router.py +++ b/tests/test_scoring_router.py @@ -226,7 +226,6 @@ def test_prediction_set_not_found(self, client, session): assert resp.status_code == 404 def test_scoring_config_not_found(self, client, session): - from unittest.mock import call from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet # First get (PredictionSet) found, second (ScoringConfig) not found session.get.side_effect = [MagicMock(), None] @@ -236,6 +235,155 @@ def test_scoring_config_not_found(self, client, session): ) assert resp.status_code == 404 + @patch("protea.api.routers.scoring.compute_score", return_value=0.85) + def test_streams_tsv_with_data(self, mock_score, session): + """Full streaming path: header + data rows.""" + set_id = uuid4() + config_id = uuid4() + cfg = _make_config("stream", formula="linear") + cfg.id = config_id + pred_set = MagicMock() + + pred = MagicMock() + pred.protein_accession = "P12345" + pred.distance = 0.1 + pred.ref_protein_accession = "Q99999" + pred.evidence_code = "IDA" + pred.qualifier = "enables" + pred.identity_nw = 0.9 + pred.identity_sw = 0.8 + pred.taxonomic_distance = 2 + + def get_side(model, id_): + from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet + from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig + if model is PredictionSet: + return pred_set + if model is ScoringConfig: + return cfg + return None + + session.get.side_effect = get_side + q_mock = MagicMock() + session.query.return_value.join.return_value.filter.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.yield_per.return_value = [(pred, "GO:0003674")] + + app = FastAPI() + factory = MagicMock() + app.state.session_factory = factory + app.include_router(router) + with patch("protea.api.routers.scoring.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app) as c: + resp = c.get( + f"/scoring/prediction-sets/{set_id}/score.tsv" + f"?scoring_config_id={config_id}" + ) + assert resp.status_code == 200 + assert "text/tab-separated-values" in resp.headers["content-type"] + lines = resp.text.strip().split("\n") + assert len(lines) == 2 + assert lines[0].startswith("protein_accession") + assert "P12345" in lines[1] + assert "GO:0003674" in lines[1] + + @patch("protea.api.routers.scoring.compute_score", return_value=0.3) + def test_min_score_filters_rows(self, mock_score, session): + """Rows below min_score are excluded from the stream.""" + set_id = uuid4() + config_id = uuid4() + cfg = _make_config("filter") + cfg.id = config_id + + pred = MagicMock() + pred.protein_accession = "P00001" + pred.distance = 0.5 + pred.ref_protein_accession = None + pred.evidence_code = "IEA" + pred.qualifier = None + pred.identity_nw = None + pred.identity_sw = None + pred.taxonomic_distance = None + + def get_side(model, id_): + from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet + from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig + if model is PredictionSet: + return MagicMock() + if model is ScoringConfig: + return cfg + return None + + session.get.side_effect = get_side + q_mock = MagicMock() + session.query.return_value.join.return_value.filter.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.yield_per.return_value = [(pred, "GO:0005575")] + + app = FastAPI() + factory = MagicMock() + 
app.state.session_factory = factory + app.include_router(router) + with patch("protea.api.routers.scoring.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app) as c: + resp = c.get( + f"/scoring/prediction-sets/{set_id}/score.tsv" + f"?scoring_config_id={config_id}&min_score=0.5" + ) + assert resp.status_code == 200 + lines = resp.text.strip().split("\n") + # Only header — score 0.3 < min_score 0.5 + assert len(lines) == 1 + assert lines[0].startswith("protein_accession") + + @patch("protea.api.routers.scoring.compute_score", return_value=0.9) + def test_accession_filter(self, mock_score, session): + """Accession query parameter is forwarded to the DB query.""" + set_id = uuid4() + config_id = uuid4() + cfg = _make_config("acc-filter") + cfg.id = config_id + + pred = MagicMock() + pred.protein_accession = "P99999" + pred.distance = 0.05 + pred.ref_protein_accession = "Q11111" + pred.evidence_code = "EXP" + pred.qualifier = "enables" + pred.identity_nw = 0.95 + pred.identity_sw = 0.92 + pred.taxonomic_distance = 0 + + def get_side(model, id_): + from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet + from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig + if model is PredictionSet: + return MagicMock() + if model is ScoringConfig: + return cfg + return None + + session.get.side_effect = get_side + q_mock = MagicMock() + session.query.return_value.join.return_value.filter.return_value = q_mock + q_mock.filter.return_value = q_mock + q_mock.yield_per.return_value = [(pred, "GO:0008150")] + + app = FastAPI() + factory = MagicMock() + app.state.session_factory = factory + app.include_router(router) + with patch("protea.api.routers.scoring.session_scope", side_effect=lambda _: _mock_scope(session)): + with TestClient(app) as c: + resp = c.get( + f"/scoring/prediction-sets/{set_id}/score.tsv" + f"?scoring_config_id={config_id}&accession=P99999" + ) + assert resp.status_code == 200 + lines = resp.text.strip().split("\n") + assert len(lines) == 2 + assert "P99999" in lines[1] + # --------------------------------------------------------------------------- # GET /prediction-sets/{set_id}/metrics — 404 preflight checks @@ -260,3 +408,116 @@ def test_scoring_config_not_found(self, client, session): session.get.side_effect = [MagicMock(), None] resp = client.get(self._url()) assert resp.status_code == 404 + + def test_invalid_category_returns_422(self, client, session): + resp = client.get( + f"/scoring/prediction-sets/{uuid4()}/metrics" + f"?scoring_config_id={uuid4()}" + f"&old_annotation_set_id={uuid4()}" + f"&new_annotation_set_id={uuid4()}" + f"&ontology_snapshot_id={uuid4()}" + f"&category=invalid" + ) + assert resp.status_code == 422 + + @patch("protea.api.routers.scoring.compute_cafa_metrics") + @patch("protea.api.routers.scoring.compute_evaluation_data") + @patch("protea.api.routers.scoring.compute_score", return_value=0.9) + def test_returns_metrics_with_curve(self, mock_score, mock_eval, mock_metrics, client, session): + set_id = uuid4() + config_id = uuid4() + cfg = _make_config("metrics-cfg") + cfg.id = config_id + pred_set = MagicMock() + + def get_side(model, id_): + from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet + from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig + if model is PredictionSet: + return pred_set + if model is ScoringConfig: + return cfg + return None + + session.get.side_effect = get_side + mock_eval.return_value = 
MagicMock()
+
+        pred = MagicMock()
+        pred.protein_accession = "P12345"
+        pred.distance = 0.1
+        pred.identity_nw = 0.9
+        pred.identity_sw = 0.8
+        pred.evidence_code = "IDA"
+        pred.taxonomic_distance = 2
+
+        session.query.return_value.join.return_value.filter.return_value.all.return_value = [
+            (pred, "GO:0003674"),
+        ]
+
+        point = MagicMock()
+        point.threshold = 0.5
+        point.precision = 0.9
+        point.recall = 0.8
+        point.f1 = 0.85
+        metrics_result = MagicMock()
+        metrics_result.summary.return_value = {"fmax": 0.85, "auc_pr": 0.78}
+        metrics_result.curve = [point]
+        mock_metrics.return_value = metrics_result
+
+        resp = client.get(
+            f"/scoring/prediction-sets/{set_id}/metrics"
+            f"?scoring_config_id={config_id}"
+            f"&old_annotation_set_id={uuid4()}"
+            f"&new_annotation_set_id={uuid4()}"
+            f"&ontology_snapshot_id={uuid4()}"
+            f"&category=nk"
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["prediction_set_id"] == str(set_id)
+        assert data["scoring_config_id"] == str(config_id)
+        assert data["scoring_config_name"] == "metrics-cfg"
+        assert "fmax" in data
+        assert "curve" in data
+        assert len(data["curve"]) == 1
+        assert data["curve"][0]["threshold"] == 0.5
+
+    @patch("protea.api.routers.scoring.compute_cafa_metrics")
+    @patch("protea.api.routers.scoring.compute_evaluation_data")
+    @patch("protea.api.routers.scoring.compute_score", return_value=0.5)
+    def test_lk_category(self, mock_score, mock_eval, mock_metrics, client, session):
+        set_id = uuid4()
+        config_id = uuid4()
+        cfg = _make_config("lk-cfg")
+        cfg.id = config_id
+
+        def get_side(model, id_):
+            from protea.infrastructure.orm.models.embedding.prediction_set import PredictionSet
+            from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig
+            if model is PredictionSet:
+                return MagicMock()
+            if model is ScoringConfig:
+                return cfg
+            return None
+
+        session.get.side_effect = get_side
+        mock_eval.return_value = MagicMock()
+        session.query.return_value.join.return_value.filter.return_value.all.return_value = []
+
+        metrics_result = MagicMock()
+        metrics_result.summary.return_value = {"fmax": 0.0, "auc_pr": 0.0}
+        metrics_result.curve = []
+        mock_metrics.return_value = metrics_result
+
+        resp = client.get(
+            f"/scoring/prediction-sets/{set_id}/metrics"
+            f"?scoring_config_id={config_id}"
+            f"&old_annotation_set_id={uuid4()}"
+            f"&new_annotation_set_id={uuid4()}"
+            f"&ontology_snapshot_id={uuid4()}"
+            f"&category=lk"
+        )
+        assert resp.status_code == 200
+        mock_metrics.assert_called_once()
+        # Accept category passed either positionally (third arg) or as a keyword.
+        args, kwargs = mock_metrics.call_args
+        category = kwargs.get("category", args[2] if len(args) > 2 else None)
+        assert category == "lk"

From 096823e43b2854465b6f308471fb77033c0884c0 Mon Sep 17 00:00:00 2001
From: frapercan
Date: Wed, 18 Mar 2026 13:37:01 +0100
Subject: [PATCH 3/7] docs: ADRs, operational runbook, and re-ranker design spec

Architecture Decision Records (6 ADRs):
- 001: KNN on CPU, not pgvector or GPU
- 002: Two-session worker pattern
- 003: QueueConsumer vs OperationConsumer
- 004: Dead letter queue and retry strategy
- 005: Thread-local RabbitMQ connections
- 006: Sequence deduplication by MD5

Operational runbook covering: start/stop, health checks, scaling,
stuck jobs, batch failures, CUDA OOM, DLQ inspection, DB maintenance

RERANKER.md: formal spec for temporal holdout re-ranker (cross-attention
architecture, LambdaRank loss, WebDataset pipeline, LightGBM baseline)
---
 RERANKER.md                                   | 188 +++++++++++++++
 docs/source/adr/001-knn-without-pgvector.rst  |  51 +++++
 .../adr/002-two-session-worker-pattern.rst    |  45 ++++
 ...3-queue-consumer-vs-operation-consumer.rst |  57 +++++
 ...4-dead-letter-queue-and-retry-strategy.rst |  51 +++++
 .../005-thread-local-rabbitmq-connections.rst |  46 ++++
 .../adr/006-sequence-deduplication-by-md5.rst |  47 ++++
 docs/source/adr/index.rst                     |  45 ++++
 docs/source/appendix/index.rst                |   1 +
 docs/source/appendix/runbook.rst              | 214 ++++++++++++++++++
 docs/source/architecture/index.rst            |   1 +
 11 files changed, 746 insertions(+)
 create mode 100644 RERANKER.md
 create mode 100644 docs/source/adr/001-knn-without-pgvector.rst
 create mode 100644 docs/source/adr/002-two-session-worker-pattern.rst
 create mode 100644 docs/source/adr/003-queue-consumer-vs-operation-consumer.rst
 create mode 100644 docs/source/adr/004-dead-letter-queue-and-retry-strategy.rst
 create mode 100644 docs/source/adr/005-thread-local-rabbitmq-connections.rst
 create mode 100644 docs/source/adr/006-sequence-deduplication-by-md5.rst
 create mode 100644 docs/source/adr/index.rst
 create mode 100644 docs/source/appendix/runbook.rst

diff --git a/RERANKER.md b/RERANKER.md
new file mode 100644
index 0000000..2301546
--- /dev/null
+++ b/RERANKER.md
@@ -0,0 +1,188 @@
+# Temporal Holdout Re-Ranker for GO Term Prediction
+
+## Motivation
+
+PROTEA's current pipeline transfers GO annotations via KNN over ESM embeddings,
+using a heuristic score that combines embedding distance and evidence weights.
+This score is optimised neither for the target metric (Fmax) nor for how GO
+annotations actually evolve over time.
+
+The central hypothesis is that there is a learnable signal: **given the context
+of a KNN prediction, will this GO term end up appearing in the next GOA release
+for this protein?** This signal can be extracted directly from the temporal
+holdout mechanism that PROTEA already implements.
+
+---
+
+## Problem Formulation
+
+Let $\mathcal{G}_N$ be the set of GO annotations in GOA release $N$ (Swiss-Prot
+reviewed). For each consecutive pair $(\mathcal{G}_N, \mathcal{G}_{N+1})$, the
+temporal delta is:
+
+$$\Delta_{N \to N+1} = \{(p, t) \mid (p, t) \in \mathcal{G}_{N+1} \setminus \mathcal{G}_N\}$$
+
+The re-ranker learns a function:
+
+$$f(q, t, \mathcal{N}_K(q)) \to \hat{y} \in [0, 1]$$
+
+where:
+- $q$ is the query protein (represented by its ESM embedding)
+- $t$ is the candidate GO term
+- $\mathcal{N}_K(q)$ is the set of $K$ nearest neighbours in embedding space, with $\mathcal{G}_N$ as reference
+- $\hat{y}$ is the probability that $(q, t) \in \Delta_{N \to N+1}$
+
+---
+
+## Training Protocol
+
+Temporal cross-validation over multiple historical GOA splits:
+
+```
+Training splits:
+  GOA_190 → GOA_195
+  GOA_195 → GOA_200
+  GOA_200 → GOA_205
+  GOA_205 → GOA_211
+  GOA_211 → GOA_215
+  GOA_215 → GOA_220
+
+Test split (strict holdout, never seen during training):
+  GOA_220 → GOA_229
+```
+
+Each split yields labelled examples: positive $(y=1)$ if the (protein, GO term)
+pair appears in the delta, negative $(y=0)$ otherwise, as sketched below. The
+expected class imbalance is roughly 1:10, manageable with standard techniques.
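+
+A minimal sketch of the labelling step (pure Python; annotations are assumed
+to arrive as sets of (accession, go_id) pairs, and the function name is
+illustrative, not a PROTEA API):
+
+```python
+def build_labels(
+    goa_old: set[tuple[str, str]],      # annotations in release N
+    goa_new: set[tuple[str, str]],      # annotations in release N+1
+    candidates: set[tuple[str, str]],   # (protein, GO term) pairs from KNN
+) -> dict[tuple[str, str], int]:
+    """y=1 if the candidate pair appears in the temporal delta, else y=0."""
+    delta = goa_new - goa_old
+    return {pair: int(pair in delta) for pair in candidates}
+```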
+
+---
+
+## Architecture: Cross-Attention Re-Ranker
+
+The model processes each (query, GO term) pair using the full context of the
+KNN neighbours that contributed to that prediction.
+
+```
+Inputs per prediction (query_protein, go_term):
+  query_embedding      float32[D]      query ESM embedding (D=480 for esmc_300m)
+  neighbor_embeddings  float32[K × D]  ESM embeddings of the K contributing neighbours
+  tabular_features     float32[K × F]  distance, evidence, alignment, taxonomy...
+  go_term_embedding    float32[G]      semantic embedding of the GO term (G=64)
+
+Architecture:
+  1. query_proj(query_embedding)       → q [H=256]
+  2. ref_proj(neighbor_embeddings)     → tokens [K × H]
+  3. feature_encoder(tabular_features) → (added to tokens)
+  4. CrossAttention(q, tokens, tokens) → context [H]
+  5. MLP([q ‖ context ‖ go_emb ‖ agg_features]) → score [1]
+```
+
+Cross-attention lets the model learn **which neighbours are most informative
+for this particular query**, instead of aggregating their scores heuristically.
+A minimal sketch of the attention core follows.
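+
+To make the shape bookkeeping concrete, a PyTorch sketch of steps 1-5
+(dimensions follow the spec above; layer names are illustrative and the
+`agg_features` term is omitted for brevity, so this is not the actual
+PROTEA implementation):
+
+```python
+import torch
+import torch.nn as nn
+
+class CrossAttentionReRanker(nn.Module):
+    """Sketch of the architecture above, not production code."""
+
+    def __init__(self, d_esm=480, d_go=64, hidden=256, n_feats=10):
+        super().__init__()
+        self.query_proj = nn.Linear(d_esm, hidden)
+        self.ref_proj = nn.Linear(d_esm, hidden)
+        self.feature_encoder = nn.Linear(n_feats, hidden)
+        self.attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden * 2 + d_go, hidden), nn.ReLU(), nn.Linear(hidden, 1)
+        )
+
+    def forward(self, query_emb, neighbor_embs, tab_feats, go_emb):
+        # query_emb: [B, D], neighbor_embs: [B, K, D],
+        # tab_feats: [B, K, F], go_emb: [B, G]
+        q = self.query_proj(query_emb).unsqueeze(1)            # [B, 1, H]
+        tokens = self.ref_proj(neighbor_embs) + self.feature_encoder(tab_feats)
+        context, _ = self.attn(q, tokens, tokens)              # [B, 1, H]
+        z = torch.cat([q.squeeze(1), context.squeeze(1), go_emb], dim=-1)
+        return self.mlp(z).squeeze(-1)                         # [B] logits
+```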
+
+### GO Term Embeddings
+
+GO term embeddings are learned from the structure of the GO DAG
+(`is_a` / `part_of` relations) via Node2Vec or TransE, so that semantically
+related terms (parent-child) end up with similar representations. The DAG is
+already available in PROTEA through the `GOTerm` and `GOTermRelationship`
+models.
+
+---
+
+## Feature Vector
+
+Each (query, GO term) prediction is characterised by the following tabular
+features, computed per neighbour that contributed to the prediction:
+
+| Feature | Description | Status |
+|---|---|---|
+| `distance` | Cosine distance in embedding space | Existing |
+| `evidence_weight` | Evidence code weight (IDA > IEA) | Existing |
+| `identity_nw / sw` | Sequence identity (NW/SW alignment) | Existing (optional) |
+| `similarity_nw / sw` | Sequence similarity | Existing (optional) |
+| `taxonomic_distance` | Taxonomic distance between query and reference | Existing (optional) |
+| `vote_count` | Number of neighbours that agree on this GO term | **New** |
+| `k_position` | Rank of the closest neighbour that predicted this term | **New** |
+| `go_term_frequency` | Frequency of the term in the reference annotation set | **New** |
+| `ref_annotation_density` | Number of GO terms on the reference protein | **New** |
+| `neighbor_distance_std` | Standard deviation of the distances to the K neighbours | **New** |
+
+---
+
+## Loss Function
+
+**LambdaRank** is used instead of binary cross-entropy, since it directly
+optimises the ordering of the predictions (a proxy for NDCG / Fmax) rather
+than probability calibration.
+
+For each query protein, the GO predictions are ranked jointly:
+- Positives: GO terms in $\Delta_{N \to N+1}$
+- Negatives: predicted GO terms that are not in the delta
+
+---
+
+## Data Pipeline: WebDataset
+
+The data volume (multiple splits × ~1.35M predictions per split × 480-dim
+embeddings) calls for an efficient data pipeline. Training examples are
+stored in **WebDataset** format (tar shards), one shard per GOA split:
+
+```
+reranker_data/
+  splits/
+    goa190_to_195.tar   # ~2GB per shard
+    goa195_to_200.tar
+    ...
+    goa220_to_229.tar   # test split, never touched during training
+  models/
+    reranker_v1.pt
+    reranker_v1_config.json
+```
+
+Each sample in the WebDataset is **one query protein** with all of its GO
+predictions for that split:
+
+```python
+{
+    "query_accession": "P12345",
+    "query_embedding": float32[480],
+    "go_term_ids": ["GO:0006915", "GO:0005737", ...],  # N_preds
+    "neighbor_embeddings": float32[N_preds, K, 480],
+    "tabular_features": float32[N_preds, K, F],
+    "labels": int8[N_preds],  # 1 if in the delta, 0 otherwise
+}
+```
+
+WebDataset streaming makes it possible to train without loading everything
+into RAM.
+
+---
+
+## Technology Stack
+
+| Component | Technology |
+|---|---|
+| Model | PyTorch |
+| Data pipeline | WebDataset + torch.utils.data |
+| Comparison baseline | LightGBM (binary + LambdaRank) |
+| GO embeddings | Node2Vec / PyTorch Geometric |
+| Experiment tracking | wandb |
+| Protein embeddings | ESM2 / ESMC (already in PROTEA) |
+
+---
+
+## Integration into PROTEA
+
+Once trained, the re-ranker plugs into the existing pipeline:
+
+1. New ORM model `RerankingModel`: stores serialised weights and training metadata
+2. Nullable `reranker_id` field on `PredictionSet`
+3. If `reranker_id` is present: `store_predictions` applies the model and overwrites `score` with $\hat{y}$
+4. The Fmax threshold is computed over the new scores, exactly as today
+5. UI: re-ranker selector on the prediction screen
+
+---
+
+## Experiments and Ablations
+
+The design allows direct comparisons:
+
+| Configuration | Description |
+|---|---|
+| **Baseline** | KNN + current heuristic scoring |
+| **LightGBM tabular** | Re-ranker on tabular features, without embeddings |
+| **LightGBM + derived** | Tabular features + embedding-derived features (density, std) |
+| **MLP cross-encoder** | Full architecture without cross-attention |
+| **Cross-attention (proposed)** | Full architecture |
+| **+ GO DAG embeddings** | Ablation: do the go_term_emb help? |
+| **+ temporal CV** | Ablation: does adding more historical splits help? |
+
+The primary metric is **mean Fmax over the 9 settings** (NK/LK/PK × BPO/MFO/CCO)
+on the GOA220→229 test split.
+
+---
+
+## Value for the Thesis
+
+1. **Scientifically honest**: the same temporal mechanism used for evaluation is used for training. There is no data leakage.
+2. **Verifiable and quantifiable**: Fmax(baseline KNN) vs Fmax(re-ranker) on an identical benchmark.
+3. **Interpretable**: feature importances (LightGBM) or attention weights (cross-attention) reveal which aspects of a KNN prediction are most predictive of future annotations.
+4. **Generalisable**: the re-ranker learns temporal distributions of GO annotations, not any single protein, so it should generalise to unseen proteins.
+5. **Extensible**: the architecture can take higher-quality sequence embeddings (ESM3, ProstT5) without changing the pipeline.
diff --git a/docs/source/adr/001-knn-without-pgvector.rst b/docs/source/adr/001-knn-without-pgvector.rst
new file mode 100644
index 0000000..f0f521d
--- /dev/null
+++ b/docs/source/adr/001-knn-without-pgvector.rst
@@ -0,0 +1,51 @@
+ADR-001: KNN on CPU, not pgvector or GPU
+========================================
+
+:Date: 2025-12-15
+:Author: frapercan
+
+The problem
+-----------
+
+GO term prediction requires K-nearest-neighbor search over 500K+ embeddings
+of 1280 dimensions.
The natural options were ``pgvector`` (we already store +vectors there) or PyTorch on GPU (we already have the GPU for inference). +Both failed: + +- **pgvector** with an IVFFlat index on 527K vectors: index build took + >20 minutes, and each individual query cost 100-500ms. For a job with + thousands of queries, unacceptable. +- **PyTorch on GPU**: the GPU is busy with ESM-2/ESM-3c/T5 inference. + Loading the distance matrix competes with model forward passes and + causes CUDA OOM. + +What we do +---------- + +KNN runs **on CPU**, entirely in Python: + +- **NumPy** (brute-force via matrix multiplication) for small datasets + (<100K). +- **FAISS** (Flat, IVFFlat, HNSW) for large datasets. Uses SIMD and + multithreading on CPU without touching the GPU. + +Reference embeddings are loaded once from PostgreSQL into a process-level +cache (``_REF_CACHE``, float16, ~4 GB for 500K vectors). ``pgvector`` +remains as storage only — the ``VECTOR`` type is there, but we never +search with ``<=>``. + +Trade-offs +---------- + +- The cache consumes worker RAM (~4 GB). If the worker restarts, the + first prediction takes ~15s extra to reload from DB. +- KNN and inference run in parallel without contention: CPU computes + distances while GPU computes embeddings. + +Rejected +-------- + +- **Dedicated vector database** (Milvus, Qdrant): one more infra + dependency for something NumPy/FAISS solves in-process. +- **Persistent FAISS index on disk**: IVFFlat training takes a few + seconds; not worth the complexity of serialising/deserialising for now. diff --git a/docs/source/adr/002-two-session-worker-pattern.rst b/docs/source/adr/002-two-session-worker-pattern.rst new file mode 100644 index 0000000..14c2b51 --- /dev/null +++ b/docs/source/adr/002-two-session-worker-pattern.rst @@ -0,0 +1,45 @@ +ADR-002: Two-session worker pattern +==================================== + +:Date: 2025-12-20 +:Author: frapercan + +The problem +----------- + +A worker executes operations that can run for hours (compute_embeddings, +load_goa_annotations). If the operation fails mid-way, we need the job +to remain marked as ``RUNNING`` in the database so monitoring can detect it. + +With a single database session, a rollback on error also reverts the +``QUEUED -> RUNNING`` transition. The job silently goes back to ``QUEUED`` +and nobody notices the failure until the reaper catches it an hour later. + +What we do +---------- + +``BaseWorker.handle_job(job_id)`` opens **two independent sessions**: + +1. **Claim session** — changes the job to ``RUNNING``, records + ``started_at`` and the ``job.started`` event, and **commits immediately**. + From this point the job is visible as running. + +2. **Execute session** — runs the operation. On success: ``SUCCEEDED``. + On failure: ``FAILED`` with ``error_code`` and ``error_message``. + A rollback here does not affect the claim. + +Trade-offs +---------- + +- Two round-trips to DB per job — irrelevant when the operation takes + minutes. +- RabbitMQ delivers each message to a single consumer (``prefetch=1``), + so there is no real race condition between workers for the same job. + +Rejected +-------- + +- **Savepoints** inside a long transaction: hold locks and bloat the + PostgreSQL WAL. +- **Optimistic locking** with a version column: does not solve the + requirement that the claim must be visible before execution starts. 
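+
+The shape of the pattern, sketched (``session_scope``, ``utcnow`` and the
+``mark`` helper are assumed names for illustration; the real entry point is
+``BaseWorker.handle_job`` in ``protea/workers/base_worker.py``):
+
+.. code-block:: python
+
+   def handle_job(self, job_id):
+       # Session 1: claim. Commit at once so monitoring sees RUNNING
+       # even if everything after this block crashes.
+       with session_scope(self.session_factory) as claim:
+           job = claim.get(Job, job_id)
+           job.status = "RUNNING"
+           job.started_at = utcnow()
+
+       # Session 2: execute. A rollback here cannot undo the claim.
+       with session_scope(self.session_factory) as work:
+           try:
+               result = self.run_operation(work, job_id)
+               self.mark(work, job_id, "SUCCEEDED", result=result)
+           except Exception as exc:
+               work.rollback()
+               self.mark(work, job_id, "FAILED", error=exc)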
diff --git a/docs/source/adr/003-queue-consumer-vs-operation-consumer.rst b/docs/source/adr/003-queue-consumer-vs-operation-consumer.rst
new file mode 100644
index 0000000..8819873
--- /dev/null
+++ b/docs/source/adr/003-queue-consumer-vs-operation-consumer.rst
@@ -0,0 +1,57 @@
+ADR-003: Two types of consumer
+===============================
+
+:Date: 2026-01-10
+:Author: frapercan
+
+The problem
+-----------
+
+Distributed pipelines (``compute_embeddings``, ``predict_go_terms``) split
+work into hundreds of batches. If each batch had its own ``Job`` row in
+the DB:
+
+- The ``jobs`` table fills with thousands of rows per prediction run,
+  making it impossible to see real user-facing jobs.
+- Each batch pays the cost of the two-session pattern (2 round-trips),
+  which for 2-8s batches is more overhead than useful work.
+
+What we do
+----------
+
+Two consumers coexist:
+
+**QueueConsumer** — for user-facing jobs with full lifecycle tracking:
+
+- Receives ``{"job_id": "<uuid>"}`` and delegates to
+  ``BaseWorker.handle_job()``.
+- Used by: ``protea.ping``, ``protea.jobs``, ``protea.embeddings``.
+
+**OperationConsumer** — for ephemeral batches with no individual DB row:
+
+- Receives ``{"operation": "...", "job_id": "<uuid>", "payload": {...}}``.
+- Executes the operation in a single session, ack/nack, done.
+- Progress is reported by incrementing ``progress_current`` on the
+  **parent job**.
+- Events are written to the parent's log with the ``child.`` prefix.
+- Used by: ``protea.embeddings.batch``, ``protea.embeddings.write``,
+  ``protea.predictions.batch``, ``protea.predictions.write``.
+
+From the outside, the user sees a single job (the coordinator) with a
+progress bar that advances. Batches are invisible.
+
+Trade-offs
+----------
+
+- Two code paths for consuming messages, but both are short (~100 lines)
+  and share infrastructure (DLQ, registry, emit).
+- If a batch fails and goes to the DLQ, there is no individual retry
+  counter — just the dead message for inspection.
+
+Rejected
+--------
+
+- **Job with** ``is_batch=True`` **flag**: still creates thousands of DB
+  rows.
+- **Fire-and-forget** without tracking: operators lose visibility into
+  progress and failures.

diff --git a/docs/source/adr/004-dead-letter-queue-and-retry-strategy.rst b/docs/source/adr/004-dead-letter-queue-and-retry-strategy.rst
new file mode 100644
index 0000000..d151592
--- /dev/null
+++ b/docs/source/adr/004-dead-letter-queue-and-retry-strategy.rst
@@ -0,0 +1,51 @@
+ADR-004: Dead letter queue and retries
+======================================
+
+:Date: 2026-03-18
+:Author: frapercan
+
+The problem
+-----------
+
+Two related messaging problems:
+
+1. **Lost messages**: when a message failed permanently (invalid JSON,
+   unknown operation), it was discarded with ``basic_nack``. The payload
+   disappeared and there was no way to do post-mortem.
+
+2. **Aggressive retries**: transient failures (broker down, GPU busy)
+   were retried immediately, amplifying load on the service that was
+   already struggling.
+
+What we do
+----------
+
+**Dead letter queue** — all queues are declared with
+``x-dead-letter-exchange: protea.dlx``. Rejected messages
+(``nack`` without ``requeue``) end up in ``protea.dead-letter``, a durable
+queue where they can be inspected, fixed, and republished.
+
+**Publisher retries** — exponential backoff: 5 attempts with delays of
+1, 2, 4, 8, 16s (capped at 30s). If the connection is broken, it is
+discarded and a new one is created.
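+
+The retry loop, sketched (illustrative only; the real implementation lives
+in ``protea/infrastructure/queue/publisher.py``):
+
+.. code-block:: python
+
+   import time
+
+   def publish_with_retry(publish, attempts: int = 5, cap: float = 30.0):
+       """Retry ``publish()`` with exponential backoff: 1, 2, 4, 8, 16, ... s."""
+       for attempt in range(attempts):
+           try:
+               return publish()
+           except Exception:
+               if attempt == attempts - 1:
+                   raise  # give up after the last attempt
+               time.sleep(min(2.0 ** attempt, cap))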
+
+**Worker retries** — operations can raise
+``RetryLaterError("GPU busy", delay_seconds=60)``. The worker calculates
+adaptive backoff based on how many previous retries have occurred:
+``delay = min(base * 2^retries, 600s)``. The job goes back to ``QUEUED``
+and is republished after the wait.
+
+Trade-offs
+----------
+
+- The DLQ grows if nobody inspects it — it must be monitored (see runbook).
+- Adaptive backoff makes one DB query per retry to count previous
+  ``job.retry_later`` events. Negligible cost.
+
+Rejected
+--------
+
+- **TTL + delay queue in RabbitMQ**: more complex to set up and debug than
+  an application-level ``sleep()``.
+- **Celery retries**: PROTEA does not use Celery; reimplementing its
+  countdown over raw pika adds no value.

diff --git a/docs/source/adr/005-thread-local-rabbitmq-connections.rst b/docs/source/adr/005-thread-local-rabbitmq-connections.rst
new file mode 100644
index 0000000..a5732e1
--- /dev/null
+++ b/docs/source/adr/005-thread-local-rabbitmq-connections.rst
@@ -0,0 +1,46 @@
+ADR-005: Reusable RabbitMQ connections
+======================================
+
+:Date: 2026-03-18
+:Author: frapercan
+
+The problem
+-----------
+
+When a coordinator (``compute_embeddings``) dispatches 500 batches, the
+publisher opened and closed a TCP connection for each ``publish_operation()``
+call. This caused:
+
+- 500 TCP+AMQP handshakes in a burst.
+- ``EMFILE`` (too many open files) errors on the worker.
+- Broker-side resource exhaustion (each connection costs RabbitMQ memory).
+
+What we do
+----------
+
+Each thread keeps **a single connection** stored in ``threading.local()``.
+``_get_connection()`` returns the existing connection if it is open, or
+creates a new one. If a publish fails, ``_close_cached_connection()``
+discards the broken connection so the next attempt reconnects.
+
+Result: from O(messages) connections down to O(threads) — in practice,
+1-4 connections total.
+
+Trade-offs
+----------
+
+- ``pika.BlockingConnection`` is not thread-safe, which is why
+  ``threading.local()`` isolation is mandatory.
+- Connections are never proactively closed — they live until the thread
+  dies or a publish fails. If RabbitMQ restarts, the first publish after
+  restart always fails once (and reconnects automatically).
+
+Rejected
+--------
+
+- **Connection pool** (``pika_pool``): external dependency for something
+  ``threading.local()`` solves in 15 lines.
+- **Global connection with a lock**: serialises all publishes, creating a
+  bottleneck when dispatching hundreds of messages.
+- **``aio-pika`` async**: workers are synchronous; adding an event loop
+  just for the publisher is disproportionate.
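+
+A sketch of the caching scheme (``pika`` is the only assumption beyond the
+ADR itself; the real helpers live in
+``protea/infrastructure/queue/publisher.py``):
+
+.. code-block:: python
+
+   import threading
+
+   import pika
+
+   _local = threading.local()
+
+   def get_connection(params: pika.ConnectionParameters) -> pika.BlockingConnection:
+       """Return this thread's cached connection, reconnecting if needed."""
+       conn = getattr(_local, "conn", None)
+       if conn is None or conn.is_closed:
+           conn = pika.BlockingConnection(params)
+           _local.conn = conn
+       return conn
+
+   def close_cached_connection():
+       """Drop a broken connection so the next publish reconnects."""
+       conn = getattr(_local, "conn", None)
+       _local.conn = None
+       if conn is not None and not conn.is_closed:
+           conn.close()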
diff --git a/docs/source/adr/006-sequence-deduplication-by-md5.rst b/docs/source/adr/006-sequence-deduplication-by-md5.rst
new file mode 100644
index 0000000..cd62a79
--- /dev/null
+++ b/docs/source/adr/006-sequence-deduplication-by-md5.rst
@@ -0,0 +1,47 @@
+ADR-006: Sequence deduplication by MD5
+======================================
+
+:Date: 2025-12-10
+:Author: frapercan
+
+The problem
+-----------
+
+UniProt has ~570K accessions in Swiss-Prot, but only ~540K unique sequences.
+The remaining 30K are isoforms or cross-references sharing the same amino
+acid chain.
+
+Computing the embedding for a sequence costs ~0.5s on GPU. Processing 30K
+duplicates wastes **4+ hours** per full run.
+
+What we do
+----------
+
+When inserting proteins, we compute the MD5 hash of the amino acid string.
+The ``Sequence`` table has a **unique constraint on ``sequence_hash``**:
+
+1. If the hash already exists -> reuse the existing ``Sequence.id``.
+2. If it does not exist -> insert a new row.
+
+Multiple ``Protein`` rows (one per UniProt accession) point to the same
+``Sequence``. The FK ``Protein.sequence_id`` is intentionally non-unique.
+
+When the embedding pipeline runs, it only processes ``Sequence`` rows
+without an embedding — duplicates are skipped automatically.
+
+Trade-offs
+----------
+
+- MD5 is not cryptographically secure, but that does not matter here:
+  there is no adversarial input, only biological sequences.
+- Sequences with a single mutation produce different hashes and are stored
+  separately. This is correct — a mutation changes the embedding.
+
+Rejected
+--------
+
+- **SHA-256**: digest twice as long, zero practical benefit.
+- **UNIQUE on the sequence text column**: indexing multi-kilobyte text
+  columns is expensive; the 32-char hex digest is far more efficient.
+- **CD-HIT clustering** (90-95% identity): useful for reducing redundancy
+  in evolutionary analysis, but here we need exact deduplication (100%).
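+
+The insert path, sketched (SQLAlchemy assumed; ``Sequence`` is the ORM model
+described above, and the helper name is illustrative):
+
+.. code-block:: python
+
+   import hashlib
+
+   def get_or_create_sequence(session, amino_acids: str):
+       """Reuse an existing Sequence row when the MD5 digest matches."""
+       digest = hashlib.md5(amino_acids.encode("ascii")).hexdigest()
+       seq = (
+           session.query(Sequence)
+           .filter(Sequence.sequence_hash == digest)
+           .one_or_none()
+       )
+       if seq is None:
+           seq = Sequence(sequence=amino_acids, sequence_hash=digest)
+           session.add(seq)
+       return seq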
diff --git a/docs/source/adr/index.rst b/docs/source/adr/index.rst
new file mode 100644
index 0000000..b4f5046
--- /dev/null
+++ b/docs/source/adr/index.rst
@@ -0,0 +1,45 @@
+Architecture Decision Records
+=============================
+
+Design decisions that are not obvious from reading the code. Each ADR
+documents **why** a decision was made, not just what — the code already
+shows the what.
+
+Decisions are grouped by system layer:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 10 50 40
+
+   * - ADR
+     - Decision
+     - Problem it solves
+   * - 001
+     - :doc:`KNN on CPU, not pgvector or GPU <001-knn-without-pgvector>`
+     - pgvector does not scale to 500K+ vectors; GPU must be reserved for inference
+   * - 006
+     - :doc:`Sequence deduplication by MD5 <006-sequence-deduplication-by-md5>`
+     - 30K duplicate sequences in Swiss-Prot waste hours of GPU time
+   * - 002
+     - :doc:`Two-session worker pattern <002-two-session-worker-pattern>`
+     - A mid-operation crash left the job invisible to monitoring
+   * - 003
+     - :doc:`Two types of consumer <003-queue-consumer-vs-operation-consumer>`
+     - Thousands of batch jobs per pipeline flooded the jobs table
+   * - 004
+     - :doc:`Dead letter queue and retries <004-dead-letter-queue-and-retry-strategy>`
+     - Failed messages were lost; retries without backoff amplified failures
+   * - 005
+     - :doc:`Reusable RabbitMQ connections <005-thread-local-rabbitmq-connections>`
+     - A coordinator dispatching 500 batches opened 500 TCP connections
+
+.. toctree::
+   :maxdepth: 1
+   :hidden:
+
+   001-knn-without-pgvector
+   002-two-session-worker-pattern
+   003-queue-consumer-vs-operation-consumer
+   004-dead-letter-queue-and-retry-strategy
+   005-thread-local-rabbitmq-connections
+   006-sequence-deduplication-by-md5

diff --git a/docs/source/appendix/index.rst b/docs/source/appendix/index.rst
index 1384103..dbeb4d2 100644
--- a/docs/source/appendix/index.rst
+++ b/docs/source/appendix/index.rst
@@ -7,3 +7,4 @@ Appendix
    installation_and_quickstart
    configuration
    howto_guides
+   runbook

diff --git a/docs/source/appendix/runbook.rst b/docs/source/appendix/runbook.rst
new file mode 100644
index 0000000..4769822
--- /dev/null
+++ b/docs/source/appendix/runbook.rst
@@ -0,0 +1,214 @@
+Operational Runbook
+===================
+
+Practical guide for operating PROTEA: starting the system, diagnosing
+problems, and maintaining infrastructure.
+
+.. contents:: Contents
+   :local:
+   :depth: 2
+
+
+Day-to-day operations
+---------------------
+
+Starting and stopping
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+   # Prerequisite: PostgreSQL and RabbitMQ must be running
+   docker start pgvectorsql rabbitmq
+
+   # Start everything (API + workers + frontend)
+   bash scripts/manage.sh start
+
+   # Start with 3 batch workers per GPU pipeline
+   bash scripts/manage.sh start 3
+
+   # Check what is running
+   bash scripts/manage.sh status
+
+   # Stop everything
+   bash scripts/manage.sh stop
+
+Checking that everything works
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+   # Liveness: is the API process alive?
+   curl http://127.0.0.1:8000/health
+   # -> {"status": "ok"}
+
+   # Readiness: can it connect to DB and RabbitMQ?
+   curl http://127.0.0.1:8000/health/ready
+   # -> {"status": "ready"} or 503 if something is down
+
+If ``/health/ready`` returns 503, check that Docker containers are running
+and that the URLs in ``protea/config/system.yaml`` are correct.
+
+Scaling workers
+~~~~~~~~~~~~~~~
+
+Batch workers are stateless — they can be added on the fly:
+
+.. code-block:: bash
+
+   bash scripts/manage.sh scale protea.predictions.batch 2
+   bash scripts/manage.sh scale protea.embeddings.batch 3
+
+Scaling is linear for batch queues.
+
+.. warning::
+
+   The ``protea.embeddings`` queue must have **exactly one** consumer.
+   The coordinator serialises GPU access; multiple coordinators step on
+   each other and cause ``RetryLaterError`` storms.
+
+Remote access
+~~~~~~~~~~~~~
+
+For demos or access from outside the local network:
+
+.. code-block:: bash
+
+   bash scripts/expose.sh
+
+Opens an ngrok tunnel to the frontend (port 3000) with a static domain
+(``protea.ngrok.app``). API calls are proxied through Next.js rewrites,
+so only one tunnel is needed. Requires ngrok installed and authenticated.
+Closes with Ctrl+C.
+
+
+Troubleshooting
+---------------
+
+Jobs stuck in RUNNING
+~~~~~~~~~~~~~~~~~~~~~
+
+A job in ``RUNNING`` that is not progressing usually means the worker died.
+
+**Automatic detection**: the ``worker-reaper`` process checks every 60s
+and marks as ``FAILED`` (error code ``JobTimeout``) any job that has been
+in ``RUNNING`` for more than 1 hour.
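+
+The reaper's core check, sketched (the real implementation is
+``protea/workers/stale_job_reaper.py``; the ``Job`` model, status literals
+and column names here are assumptions for illustration):
+
+.. code-block:: python
+
+   from datetime import datetime, timedelta, timezone
+
+   def reap_once(session, timeout=timedelta(hours=1)) -> int:
+       """Mark RUNNING jobs older than ``timeout`` as FAILED."""
+       cutoff = datetime.now(timezone.utc) - timeout
+       stale = (
+           session.query(Job)
+           .filter(Job.status == "RUNNING", Job.started_at < cutoff)
+           .all()
+       )
+       for job in stale:
+           job.status = "FAILED"
+           job.error_code = "JobTimeout"
+       session.commit()
+       return len(stale)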
+
+**Manual intervention**:
+
+.. code-block:: bash
+
+   # Check job status and events
+   curl -s http://127.0.0.1:8000/jobs/<job_id> | python -m json.tool
+   curl -s http://127.0.0.1:8000/jobs/<job_id>/events | python -m json.tool
+
+   # Cancel (also cancels child sub-jobs)
+   curl -s -X POST http://127.0.0.1:8000/jobs/<job_id>/cancel
+
+   # Delete a terminal job
+   curl -s -X DELETE http://127.0.0.1:8000/jobs/<job_id>
+
+To re-run, create a new job with the same operation and payload.
+There is no "retry" button — jobs are immutable once finished.
+
+Batch failures
+~~~~~~~~~~~~~~
+
+Batches (``compute_embeddings_batch``, ``predict_go_terms_batch``) do not
+have their own row in ``jobs``. To diagnose:
+
+1. **Parent job events** — failures are recorded as ``child.failed``:
+
+   .. code-block:: bash
+
+      curl -s http://127.0.0.1:8000/jobs/<job_id>/events?limit=50 | python -m json.tool
+
+2. **Worker logs** — each worker writes structured JSON:
+
+   .. code-block:: bash
+
+      bash scripts/manage.sh logs embeddings-batch
+
+      # Filter errors only with jq
+      cat logs/worker-embeddings-batch-1.log | jq 'select(.level == "ERROR")'
+
+      # Search for a specific job
+      cat logs/worker-jobs.log | jq 'select(.message | contains("<job_id>"))'
+
+3. **Dead letter queue** — permanently failed messages:
+
+   .. code-block:: bash
+
+      # Check how many dead messages there are
+      rabbitmqctl list_queues name messages | grep dead-letter
+
+   Also accessible from the RabbitMQ UI: http://localhost:15672
+   (guest/guest) -> Queues -> ``protea.dead-letter`` -> Get Message(s).
+
+   To republish a corrected message, use "Move" in the UI.
+
+CUDA out of memory
+~~~~~~~~~~~~~~~~~~
+
+When a batch worker runs out of GPU memory:
+
+1. The worker automatically calls ``torch.cuda.empty_cache()`` and
+   requeues the message for retry.
+2. If it keeps failing, reduce ``batch_size`` in the job payload.
+3. Check that no other process is using the GPU:
+
+   .. code-block:: bash
+
+      nvidia-smi
+
+4. If another embedding job is using the GPU, the coordinator detects
+   contention via ``RetryLaterError`` and waits with exponential backoff
+   (up to 10 minutes between retries).
+
+
+Maintenance
+-----------
+
+Database
+~~~~~~~~
+
+.. code-block:: bash
+
+   # Total DB size
+   psql postgresql://protea:protea@localhost:5432/protea \
+     -c "SELECT pg_size_pretty(pg_database_size('protea'));"
+
+   # Top 10 tables by size
+   psql postgresql://protea:protea@localhost:5432/protea \
+     -c "SELECT relname, pg_size_pretty(pg_total_relation_size(oid))
+         FROM pg_class WHERE relkind='r'
+         ORDER BY pg_total_relation_size(oid) DESC LIMIT 10;"
+
+   # Clean up jobs and events older than 30 days
+   psql postgresql://protea:protea@localhost:5432/protea \
+     -c "DELETE FROM job_events WHERE ts < now() - interval '30 days';"
+   psql postgresql://protea:protea@localhost:5432/protea \
+     -c "DELETE FROM jobs WHERE finished_at < now() - interval '30 days'
+         AND status IN ('succeeded', 'failed', 'cancelled');"
+
+   # Full reset (destructive — deletes EVERYTHING)
+   curl -s -X POST http://127.0.0.1:8000/admin/reset-db
+
+Dead letter queue
+~~~~~~~~~~~~~~~~~
+
+Messages in ``protea.dead-letter`` accumulate and are not purged
+automatically. Review periodically:
+
+.. code-block:: bash
+
+   # Purge the DLQ when messages are no longer needed
+   rabbitmqctl purge_queue protea.dead-letter
+
+Logs
+~~~~
+
+Logs grow without limit. To truncate without restarting workers:
+
+.. code-block:: bash
+
+   for f in logs/*.log; do : > "$f"; done

diff --git a/docs/source/architecture/index.rst b/docs/source/architecture/index.rst
index 548db7a..9110d40 100644
--- a/docs/source/architecture/index.rst
+++ b/docs/source/architecture/index.rst
@@ -12,3 +12,4 @@ job lifecycle, and extension points.
data_model operations evaluation + /adr/index From 092f11013793d9325555866910d73a46e5752399 Mon Sep 17 00:00:00 2001 From: frapercan Date: Wed, 25 Mar 2026 13:35:27 +0100 Subject: [PATCH 4/7] =?UTF-8?q?release:=20v0.3.0=20=E2=80=94=20re-ranker,?= =?UTF-8?q?=20evaluation=20pipeline,=20annotate=20workflow,=20UI=20overhau?= =?UTF-8?q?l?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major features: - Neural re-ranker: train_reranker operation, ReRankerModel ORM, reranker UI page - Expanded CAFA evaluation pipeline with scoring router and detailed metrics - Annotate router and showcase router for streamlined user workflows - Floating jobs widget, breadcrumbs, context banner, tooltip components - Frontend overhaul: redesigned pages, improved navigation, i18n updates - Thesis PDF served from frontend Infrastructure: - 4 new Alembic migrations for re-ranker schema - API deps module, extended scoring endpoints - Experiment and evaluation helper scripts - Updated documentation (results, evaluation architecture) - Version bump to 0.3.0 Tests: - New test suites: reranker, train_reranker, annotate router, showcase router, integration - Expanded: predict_go_terms, compute_embeddings, scoring router, embeddings router --- ...b9_add_reranker_model_id_to_evaluation_.py | 36 + ...74df6_add_aspect_to_reranker_model_and_.py | 34 + ...946_add_reranker_feature_columns_to_go_.py | 40 + .../ba9966bd453e_add_reranker_model_table.py | 50 + apps/web/app/[locale]/annotations/page.tsx | 117 +- apps/web/app/[locale]/embeddings/page.tsx | 141 +- apps/web/app/[locale]/evaluation/page.tsx | 255 ++- .../functional-annotation/[id]/page.tsx | 561 ++++--- .../[locale]/functional-annotation/page.tsx | 26 +- apps/web/app/[locale]/jobs/[id]/page.tsx | 3 +- apps/web/app/[locale]/layout.tsx | 23 +- apps/web/app/[locale]/page.tsx | 346 +++- .../[locale]/proteins/[accession]/page.tsx | 3 +- apps/web/app/[locale]/proteins/page.tsx | 41 +- apps/web/app/[locale]/query-sets/page.tsx | 4 +- apps/web/app/[locale]/reranker/page.tsx | 574 +++++++ apps/web/components/AnnotateForm.tsx | 302 ++++ apps/web/components/Breadcrumbs.tsx | 59 + apps/web/components/ContextBanner.tsx | 70 + apps/web/components/FloatingJobsWidget.tsx | 97 ++ apps/web/components/LanguageSwitcher.tsx | 51 +- apps/web/components/NavLinks.tsx | 185 ++- apps/web/components/SupportButton.tsx | 2 +- apps/web/components/Tooltip.tsx | 22 + .../e2e/screenshots/mobile-annotations.png | Bin 137479 -> 356915 bytes .../web/e2e/screenshots/mobile-embeddings.png | Bin 100383 -> 350559 bytes .../mobile-functional-annotation.png | Bin 286720 -> 538175 bytes apps/web/e2e/screenshots/mobile-jobs.png | Bin 92118 -> 353322 bytes apps/web/e2e/screenshots/mobile-proteins.png | Bin 133589 -> 353317 bytes .../web/e2e/screenshots/mobile-query-sets.png | Bin 92828 -> 348494 bytes .../e2e/screenshots/tablet-annotations.png | Bin 108080 -> 270171 bytes .../web/e2e/screenshots/tablet-embeddings.png | Bin 82652 -> 289941 bytes .../tablet-functional-annotation.png | Bin 196370 -> 388457 bytes apps/web/e2e/screenshots/tablet-jobs.png | Bin 72319 -> 256053 bytes apps/web/e2e/screenshots/tablet-proteins.png | Bin 97931 -> 261150 bytes .../web/e2e/screenshots/tablet-query-sets.png | Bin 70127 -> 255494 bytes apps/web/lib/api.ts | 174 +- apps/web/messages/de.json | 27 +- apps/web/messages/en.json | 55 +- apps/web/messages/es.json | 61 +- apps/web/messages/pt.json | 29 +- apps/web/messages/zh.json | 3 +- apps/web/public/thesis.pdf | Bin 0 -> 464324 bytes 
 docs/source/abstract.rst | 7 +
 docs/source/appendix/configuration.rst | 29 +-
 docs/source/appendix/howto_guides.rst | 116 +-
 docs/source/appendix/runbook.rst | 4 +-
 docs/source/architecture/data_model.rst | 63 +-
 docs/source/architecture/evaluation.rst | 137 ++
 docs/source/architecture/operations.rst | 50 +-
 docs/source/architecture/system_overview.rst | 21 +-
 docs/source/index.rst | 19 +-
 docs/source/introduction.rst | 20 +
 docs/source/reference/api.rst | 314 +++-
 docs/source/reference/core.rst | 82 +-
 docs/source/reference/index.rst | 17 +-
 docs/source/reference/infrastructure.rst | 68 +-
 docs/source/reference/workers.rst | 24 +-
 docs/source/results.rst | 412 +++++
 poetry.lock | 88 +-
 protea/api/app.py | 15 +-
 protea/api/deps.py | 28 +
 protea/api/routers/admin.py | 21 +-
 protea/api/routers/annotate.py | 243 +++
 protea/api/routers/annotations.py | 26 +-
 protea/api/routers/embeddings.py | 97 +-
 protea/api/routers/jobs.py | 19 +-
 protea/api/routers/maintenance.py | 9 +-
 protea/api/routers/proteins.py | 13 +-
 protea/api/routers/query_sets.py | 12 +-
 protea/api/routers/scoring.py | 641 ++++++-
 protea/api/routers/showcase.py | 159 ++
 protea/api/routers/support.py | 6 +-
 protea/config/system.yaml | 3 +
 protea/core/metrics.py | 14 +-
 protea/core/operations/compute_embeddings.py | 21 +-
 .../core/operations/fetch_uniprot_metadata.py | 4 +
 protea/core/operations/insert_proteins.py | 4 +
 protea/core/operations/predict_go_terms.py | 95 +-
 protea/core/operations/run_cafa_evaluation.py | 299 +++-
 protea/core/operations/train_reranker.py | 1479 +++++++++++++++++
 protea/core/reranker.py | 302 ++++
 protea/infrastructure/orm/models/__init__.py | 1 +
 .../models/annotation/evaluation_result.py | 9 +
 .../orm/models/embedding/embedding_config.py | 2 +-
 .../orm/models/embedding/go_prediction.py | 7 +
 .../orm/models/embedding/reranker_model.py | 45 +
 protea/infrastructure/orm/models/job.py | 5 +-
 protea/infrastructure/queue/publisher.py | 6 +-
 protea/infrastructure/settings.py | 9 +-
 pyproject.toml | 4 +-
 scripts/evaluate_external_tool.py | 414 +++++
 scripts/query_eval_results.py | 114 ++
 scripts/queue_evals_when_ready.py | 57 +
 scripts/run_experiments.py | 616 +++++++
 scripts/worker.py | 7 +-
 tests/conftest.py | 6 +
 tests/test_admin_router.py | 33 +-
 tests/test_annotate_router.py | 376 +++++
 tests/test_compute_embeddings.py | 187 ++-
 tests/test_core.py | 13 -
 tests/test_embeddings_router.py | 69 +-
 tests/test_generate_evaluation_set.py | 4 +-
 tests/test_integration.py | 631 +++++++
 tests/test_load_goa_annotations.py | 4 +-
 tests/test_load_ontology_snapshot.py | 6 +-
 tests/test_load_quickgo_annotations.py | 4 +-
 tests/test_metrics.py | 2 +-
 tests/test_predict_go_terms.py | 446 ++++-
 tests/test_queue.py | 4 +-
 tests/test_reranker.py | 253 +++
 tests/test_scoring_router.py | 605 +++++++
 tests/test_showcase_router.py | 291 ++++
 tests/test_train_reranker.py | 474 ++++++
 114 files changed, 12382 insertions(+), 694 deletions(-)
 create mode 100644 alembic/versions/110a5b8cfbb9_add_reranker_model_id_to_evaluation_.py
 create mode 100644 alembic/versions/3505bfa74df6_add_aspect_to_reranker_model_and_.py
 create mode 100644 alembic/versions/3884c47fe946_add_reranker_feature_columns_to_go_.py
 create mode 100644 alembic/versions/ba9966bd453e_add_reranker_model_table.py
 create mode 100644 apps/web/app/[locale]/reranker/page.tsx
 create mode 100644 apps/web/components/AnnotateForm.tsx
 create mode 100644 apps/web/components/Breadcrumbs.tsx
 create mode 100644 apps/web/components/ContextBanner.tsx
 create mode 100644 apps/web/components/FloatingJobsWidget.tsx
 create mode 100644 apps/web/components/Tooltip.tsx
 create mode 100644 apps/web/public/thesis.pdf
 create mode 100644 docs/source/results.rst
 create mode 100644 protea/api/deps.py
 create mode 100644 protea/api/routers/annotate.py
 create mode 100644 protea/api/routers/showcase.py
 create mode 100644 protea/core/operations/train_reranker.py
 create mode 100644 protea/core/reranker.py
 create mode 100644 protea/infrastructure/orm/models/embedding/reranker_model.py
 create mode 100644 scripts/evaluate_external_tool.py
 create mode 100644 scripts/query_eval_results.py
 create mode 100644 scripts/queue_evals_when_ready.py
 create mode 100644 scripts/run_experiments.py
 create mode 100644 tests/test_annotate_router.py
 create mode 100644 tests/test_integration.py
 create mode 100644 tests/test_reranker.py
 create mode 100644 tests/test_showcase_router.py
 create mode 100644 tests/test_train_reranker.py

diff --git a/alembic/versions/110a5b8cfbb9_add_reranker_model_id_to_evaluation_.py b/alembic/versions/110a5b8cfbb9_add_reranker_model_id_to_evaluation_.py
new file mode 100644
index 0000000..bce4436
--- /dev/null
+++ b/alembic/versions/110a5b8cfbb9_add_reranker_model_id_to_evaluation_.py
@@ -0,0 +1,36 @@
+"""add reranker_model_id to evaluation_result
+
+Revision ID: 110a5b8cfbb9
+Revises: ba9966bd453e
+Create Date: 2026-03-19 10:52:11.951459
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '110a5b8cfbb9'
+down_revision: Union[str, Sequence[str], None] = 'ba9966bd453e'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('evaluation_result', sa.Column('reranker_model_id', sa.UUID(), nullable=True))
+    op.create_index(op.f('ix_evaluation_result_reranker_model_id'), 'evaluation_result', ['reranker_model_id'], unique=False)
+    op.create_foreign_key('fk_evaluation_result_reranker_model_id', 'evaluation_result', 'reranker_model', ['reranker_model_id'], ['id'], ondelete='SET NULL')
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint('fk_evaluation_result_reranker_model_id', 'evaluation_result', type_='foreignkey')
+    op.drop_index(op.f('ix_evaluation_result_reranker_model_id'), table_name='evaluation_result')
+    op.drop_column('evaluation_result', 'reranker_model_id')
+    # ### end Alembic commands ###
diff --git a/alembic/versions/3505bfa74df6_add_aspect_to_reranker_model_and_.py b/alembic/versions/3505bfa74df6_add_aspect_to_reranker_model_and_.py
new file mode 100644
index 0000000..8ba8c5c
--- /dev/null
+++ b/alembic/versions/3505bfa74df6_add_aspect_to_reranker_model_and_.py
@@ -0,0 +1,34 @@
+"""add aspect to reranker_model and reranker_config to evaluation_result
+
+Revision ID: 3505bfa74df6
+Revises: 110a5b8cfbb9
+Create Date: 2026-03-19 15:16:18.474851
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = '3505bfa74df6'
+down_revision: Union[str, Sequence[str], None] = '110a5b8cfbb9'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('evaluation_result', sa.Column('reranker_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+    op.add_column('reranker_model', sa.Column('aspect', sa.String(length=3), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('reranker_model', 'aspect')
+    op.drop_column('evaluation_result', 'reranker_config')
+    # ### end Alembic commands ###
diff --git a/alembic/versions/3884c47fe946_add_reranker_feature_columns_to_go_.py b/alembic/versions/3884c47fe946_add_reranker_feature_columns_to_go_.py
new file mode 100644
index 0000000..a980419
--- /dev/null
+++ b/alembic/versions/3884c47fe946_add_reranker_feature_columns_to_go_.py
@@ -0,0 +1,40 @@
+"""add reranker feature columns to go_prediction
+
+Revision ID: 3884c47fe946
+Revises: 5fc2eb0f986d
+Create Date: 2026-03-18 13:40:17.716092
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '3884c47fe946'
+down_revision: Union[str, Sequence[str], None] = '5fc2eb0f986d'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('go_prediction', sa.Column('vote_count', sa.Integer(), nullable=True))
+    op.add_column('go_prediction', sa.Column('k_position', sa.Integer(), nullable=True))
+    op.add_column('go_prediction', sa.Column('go_term_frequency', sa.Integer(), nullable=True))
+    op.add_column('go_prediction', sa.Column('ref_annotation_density', sa.Integer(), nullable=True))
+    op.add_column('go_prediction', sa.Column('neighbor_distance_std', sa.Float(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('go_prediction', 'neighbor_distance_std')
+    op.drop_column('go_prediction', 'ref_annotation_density')
+    op.drop_column('go_prediction', 'go_term_frequency')
+    op.drop_column('go_prediction', 'k_position')
+    op.drop_column('go_prediction', 'vote_count')
+    # ### end Alembic commands ###
diff --git a/alembic/versions/ba9966bd453e_add_reranker_model_table.py b/alembic/versions/ba9966bd453e_add_reranker_model_table.py
new file mode 100644
index 0000000..4a1516e
--- /dev/null
+++ b/alembic/versions/ba9966bd453e_add_reranker_model_table.py
@@ -0,0 +1,50 @@
+"""add reranker_model table
+
+Revision ID: ba9966bd453e
+Revises: 3884c47fe946
+Create Date: 2026-03-18 13:57:29.263810
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = 'ba9966bd453e'
+down_revision: Union[str, Sequence[str], None] = '3884c47fe946'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('reranker_model',
+    sa.Column('id', sa.UUID(), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('prediction_set_id', sa.UUID(), nullable=True),
+    sa.Column('evaluation_set_id', sa.UUID(), nullable=True),
+    sa.Column('category', sa.String(length=10), nullable=False),
+    sa.Column('model_data', sa.Text(), nullable=False),
+    sa.Column('metrics', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+    sa.Column('feature_importance', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
+    sa.ForeignKeyConstraint(['evaluation_set_id'], ['evaluation_set.id'], ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['prediction_set_id'], ['prediction_set.id'], ondelete='SET NULL'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('name')
+    )
+    op.create_index(op.f('ix_reranker_model_evaluation_set_id'), 'reranker_model', ['evaluation_set_id'], unique=False)
+    op.create_index(op.f('ix_reranker_model_prediction_set_id'), 'reranker_model', ['prediction_set_id'], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index(op.f('ix_reranker_model_prediction_set_id'), table_name='reranker_model')
+    op.drop_index(op.f('ix_reranker_model_evaluation_set_id'), table_name='reranker_model')
+    op.drop_table('reranker_model')
+    # ### end Alembic commands ###
diff --git a/apps/web/app/[locale]/annotations/page.tsx b/apps/web/app/[locale]/annotations/page.tsx
index 73be48b..0affb9e 100644
--- a/apps/web/app/[locale]/annotations/page.tsx
+++ b/apps/web/app/[locale]/annotations/page.tsx
@@ -201,12 +201,12 @@ export default function AnnotationsPage() {

{t("title")}

-
+
{tabs.map((tab) => ( ))} -
+
{/* ── Annotation Sets ── */} {activeTab === "sets" && ( @@ -226,7 +226,51 @@ export default function AnnotationsPage() { {t("setsTab.refresh")} -
+ {/* Mobile card list */} +
+ {loadingSets && Array.from({ length: 3 }).map((_, i) => ( +
+
+
+
+ ))} + {!loadingSets && sets.length === 0 && ( +
+ {t("setsTab.noSetsFound")} +
+ )} + {sets.map((a) => ( +
+
+ {a.source} + +
+

{a.source_version ?? "—"} · {(a.annotation_count ?? 0).toLocaleString()} annotations

+
+ {a.meta && Object.entries(a.meta).map(([k, v]) => ( + + {k}: {Array.isArray(v) ? v.join(", ") : String(v)} + + ))} +
+
+ {shortId(a.id)} + {formatDate(a.created_at)} + {a.job_id && ( + ↗ + )} +
+
+ ))} +
+ + {/* Desktop table */} +
{t("setsTab.tableHeaders.id")}
{t("setsTab.tableHeaders.source")}
{t("setsTab.tableHeaders.version")}
{t("setsTab.tableHeaders.annotations")}
{t("setsTab.tableHeaders.meta")}
{t("setsTab.tableHeaders.created")}
@@ -278,7 +322,70 @@ export default function AnnotationsPage() { {t("snapshotsTab.refresh")}
-
+ {/* Mobile card list */} +
+ {loadingSnaps && Array.from({ length: 2 }).map((_, i) => ( +
+
+
+
+ ))} + {!loadingSnaps && snapshots.length === 0 && ( +
+ {t("snapshotsTab.noSnapshotsFound")} +
+ )} + {snapshots.map((s) => ( +
+
+ {s.obo_version} + {(s.go_term_count ?? 0).toLocaleString()} terms +
+
+ {iaEditId === s.id ? ( +
+ setIaEditValue(e.target.value)} + placeholder="https://…/IA_cafa6.tsv or file path" + className="w-full rounded border px-2 py-1.5 text-xs focus:outline-none focus:ring-1 focus:ring-blue-500" + onKeyDown={(e) => { + if (e.key === "Enter") handleSaveIa(s.id); + if (e.key === "Escape") setIaEditId(null); + }} + /> +
+ + +
+
+ ) : ( + + )} +
+
+ {shortId(s.id)} + {formatDate(s.loaded_at)} +
+
+ ))} +
+ + {/* Desktop table */} +
{t("snapshotsTab.tableHeaders.id")}
{t("snapshotsTab.tableHeaders.version")}
{t("snapshotsTab.tableHeaders.goTerms")}
{t("snapshotsTab.tableHeaders.iaUrl")}
{t("snapshotsTab.tableHeaders.loaded")}
diff --git a/apps/web/app/[locale]/embeddings/page.tsx b/apps/web/app/[locale]/embeddings/page.tsx index 7110cb8..d3411d0 100644 --- a/apps/web/app/[locale]/embeddings/page.tsx +++ b/apps/web/app/[locale]/embeddings/page.tsx @@ -5,12 +5,14 @@ import Link from "next/link"; import { useTranslations } from "next-intl"; import { useToast } from "@/components/Toast"; import { SkeletonTableRow } from "@/components/Skeleton"; +import { ContextBanner } from "@/components/ContextBanner"; import { listEmbeddingConfigs, createEmbeddingConfig, deleteEmbeddingConfig, createJob, listQuerySets, + getProteinStats, EmbeddingConfig, QuerySet, } from "@/lib/api"; @@ -95,6 +97,7 @@ export default function EmbeddingsPage() { const [cmpResult, setCmpResult] = useState<{ id: string; status: string } | null>(null); const [cmpError, setCmpError] = useState(""); const [cmpSubmitting, setCmpSubmitting] = useState(false); + const [proteinCount, setProteinCount] = useState(null); async function loadAll() { setLoading(true); @@ -106,6 +109,7 @@ export default function EmbeddingsPage() { ]); setConfigs(cfgs); setQuerySets(qsets); + getProteinStats().then((s) => setProteinCount(s.total ?? 0)).catch(() => {}); if (cfgs.length > 0 && !cmpConfigId) setCmpConfigId(cfgs[0].id); } catch (e: any) { setError(String(e)); @@ -224,6 +228,16 @@ export default function EmbeddingsPage() {

{t("title")}

+ 0, href: "/proteins" }, + { label: `${configs.length} embedding config(s)`, met: configs.length > 0 }, + ] : undefined} + nextStep={{ label: "Functional Annotation", href: "/functional-annotation" }} + /> + {error && (
           {error}
@@ -475,53 +489,92 @@ export default function EmbeddingsPage() {
               {Array.from({ length: 3 }).map((_, i) => )}
             
) : ( -
-
-
{t("configsTab.tableHeaders.description")}
-
{t("configsTab.tableHeaders.model")}
-
{t("configsTab.tableHeaders.backend")}
-
{t("configsTab.tableHeaders.layers")}
-
{t("configsTab.tableHeaders.agg")}
-
{t("configsTab.tableHeaders.pool")}
-
{t("configsTab.tableHeaders.norm")}
-
{t("configsTab.tableHeaders.created")}
-
-
- {configs.map((c) => ( -
-
- {c.description || } + <> + {/* Mobile card list */} +
+ {configs.map((c) => ( +
+
+ + {c.description || } + + +
+

{c.model_name}

+
+ {c.model_backend} + layers [{c.layer_indices.join(", ")}] + {c.layer_agg}/{c.pooling} + {c.normalize ? "norm" : "no norm"} +
+

{formatDate(c.created_at)}

-
{c.model_name}
-
{c.model_backend}
-
[{c.layer_indices.join(", ")}]
-
{c.layer_agg}
-
{c.pooling}
-
{c.normalize ? "yes" : "no"}
-
{formatDate(c.created_at)}
-
-
+ )} +
+ + {/* Desktop table */} +
+
+
{t("configsTab.tableHeaders.description")}
+
{t("configsTab.tableHeaders.model")}
+
{t("configsTab.tableHeaders.backend")}
+
{t("configsTab.tableHeaders.layers")}
+
{t("configsTab.tableHeaders.agg")}
+
{t("configsTab.tableHeaders.pool")}
+
{t("configsTab.tableHeaders.norm")}
+
{t("configsTab.tableHeaders.created")}
+
- ))} - {configs.length === 0 && ( -
- {t("configsTab.noConfigs")}{" "} - -
- )} -
+ {configs.map((c) => ( +
+
+ {c.description || } +
+
{c.model_name}
+
{c.model_backend}
+
[{c.layer_indices.join(", ")}]
+
{c.layer_agg}
+
{c.pooling}
+
{c.normalize ? "yes" : "no"}
+
{formatDate(c.created_at)}
+
+ +
+
+ ))} + {configs.length === 0 && ( +
+ {t("configsTab.noConfigs")}{" "} + +
+ )} +
+ )}
)} @@ -572,7 +625,7 @@ export default function EmbeddingsPage() {
-
+
-
+
{["BPO", "MFO", "CCO"].map((ns) => { const m = results[setting]?.[ns]; if (!m) return null; @@ -178,7 +201,10 @@ function ResultsTable({ results }: { results: Record }) {m.recall.toFixed(3)}
- {t("resultMetrics.coverage")} + + {t("resultMetrics.coverage")} + + {(m.coverage * 100).toFixed(1)}%
@@ -201,6 +227,7 @@ function EvaluationSetCard({ annotationSets, predictionSets, scoringConfigs, + rerankers: initialRerankers, isSelected, onSelect, onDeleted, @@ -209,6 +236,7 @@ function EvaluationSetCard({ annotationSets: AnnotationSet[]; predictionSets: PredictionSet[]; scoringConfigs: ScoringConfig[]; + rerankers: RerankerModel[]; isSelected: boolean; onSelect: () => void; onDeleted: () => void; @@ -219,6 +247,14 @@ function EvaluationSetCard({ const [predSetId, setPredSetId] = useState(""); const [maxDistance, setMaxDistance] = useState(""); const [scoringConfigId, setScoringConfigId] = useState(""); + // 3x3 reranker grid: category × aspect + const [rrGrid, setRrGrid] = useState>>({ + nk: { bpo: "", mfo: "", cco: "" }, + lk: { bpo: "", mfo: "", cco: "" }, + pk: { bpo: "", mfo: "", cco: "" }, + }); + const setRrCell = (cat: string, asp: string, val: string) => + setRrGrid((prev) => ({ ...prev, [cat]: { ...prev[cat], [asp]: val } })); const [running, setRunning] = useState(false); const [runError, setRunError] = useState(""); const [pendingJobId, setPendingJobId] = useState(null); @@ -277,7 +313,24 @@ function EvaluationSetCard({ try { const body: Record = { prediction_set_id: predSetId }; if (maxDistance) body.max_distance = parseFloat(maxDistance); - if (scoringConfigId) body.scoring_config_id = scoringConfigId; + // Build nested rerankers mapping from the 3×3 grid + const rerankers: Record> = {}; + let hasAnyReranker = false; + for (const cat of ["nk", "lk", "pk"]) { + const catMap: Record = {}; + for (const asp of ["bpo", "mfo", "cco"]) { + if (rrGrid[cat]?.[asp]) { + catMap[asp] = rrGrid[cat][asp]; + hasAnyReranker = true; + } + } + if (Object.keys(catMap).length > 0) rerankers[cat] = catMap; + } + if (hasAnyReranker) { + body.rerankers = rerankers; + } else if (scoringConfigId) { + body.scoring_config_id = scoringConfigId; + } const res = await apiFetch<{ id: string; status: string }>( `/annotations/evaluation-sets/${e.id}/run`, { @@ -305,7 +358,7 @@ function EvaluationSetCard({ className="cursor-pointer p-4 hover:bg-gray-50 rounded-t-lg" onClick={onSelect} > -
+
{evalLabel(e, annotationSets)}
-
+

{t("evaluationSetCard.runCafaEvaluator")}

-
+
setScoringConfigId(ev.target.value)} - className={selectClass} - > - - {scoringConfigs.map((c) => ( - - ))} - -
+ + {/* Scoring method — 3×3 grid (category × aspect) */} +
+ + {initialRerankers.length > 0 && ( +
+ + + + + + + + + + + {(["nk", "lk", "pk"] as const).map((cat) => ( + + + {(["bpo", "mfo", "cco"] as const).map((asp) => { + // Show models matching this category+aspect, or category+null (all-aspect models) + const candidates = initialRerankers.filter( + (r) => r.category === cat && (r.aspect === asp || r.aspect === null) + ); + return ( + + ); + })} + + ))} + +
BPOMFOCCO
{cat} + +
+
+ )} + {(() => { + const hasAnyRr = Object.values(rrGrid).some((catMap) => Object.values(catMap).some(Boolean)); + return scoringConfigs.length > 0 && !hasAnyRr ? ( +
+ + +
+ ) : null; + })()} +
{runError && (

{runError} @@ -512,22 +618,87 @@ function EvaluationSetCard({ {results.map((r) => { const pred = predictionSets.find((p) => p.id === r.prediction_set_id); const sc = scoringConfigs.find((c) => c.id === r.scoring_config_id); + const hasReranker = !!r.reranker_model_id; + const rr = initialRerankers.find((m) => m.id === r.reranker_model_id); return (

{/* Meta header */} -
+
-
+
{t("evaluationSetCard.predictionSet")} {pred ? {r.prediction_set_id.slice(0, 8)}… · {new Date(pred.created_at).toLocaleDateString()}{pred.prediction_count != null ? ` · ${pred.prediction_count.toLocaleString()} preds.` : ""} : {r.prediction_set_id.slice(0, 8)}… } + {pred && ( + +
+
Prediction Set
+
+ Config + {pred.embedding_config_name ?? pred.embedding_config_id.slice(0, 8) + "…"} +
+
+ Annotations + {pred.annotation_set_label ?? pred.annotation_set_id.slice(0, 8) + "…"} +
+
+ Ontology + {pred.ontology_snapshot_version ?? pred.ontology_snapshot_id.slice(0, 8) + "…"} +
+
+ Max dist. + {pred.distance_threshold ?? "—"} +
+
+ Limit/entry + {pred.limit_per_entry} +
+
+
+ )}
-
+
{t("evaluationSetCard.scoring")} - {sc ? sc.name : {t("evaluationSetCard.fallbackFormula")}} - {sc?.description && } + {r.reranker_config ? ( + + Re-ranker + {Object.entries(r.reranker_config).map(([cat, aspMap]) => ( + + {cat.toUpperCase()}({Object.keys(aspMap).map(a => a.toUpperCase()).join(",")}) + + ))} + + ) : hasReranker ? ( + + Re-ranker + {rr ? rr.name : "model"} + + ) : sc ? sc.name : {t("evaluationSetCard.fallbackFormula")}} + {sc && !hasReranker && ( + +
+
{sc.name}
+ {sc.description &&
{sc.description}
} +
+ Formula + {sc.formula} +
+ {Object.keys(sc.weights).length > 0 && ( +
+
Weights
+ {Object.entries(sc.weights).map(([k, v]) => ( +
+ {k} + {v} +
+ ))} +
+ )} +
+
+ )}
{new Date(r.created_at).toLocaleString()}
@@ -559,6 +730,7 @@ function EvaluationSetCard({
)}
+
)}
@@ -571,6 +743,7 @@ export default function EvaluationPage() { const [predictionSets, setPredictionSets] = useState([]); const [evaluationSets, setEvaluationSets] = useState([]); const [scoringConfigs, setScoringConfigs] = useState([]); + const [rerankers, setRerankers] = useState([]); const [loading, setLoading] = useState(true); const [oldSetId, setOldSetId] = useState(""); @@ -580,12 +753,13 @@ export default function EvaluationPage() { const [selectedEvalId, setSelectedEvalId] = useState(""); const reload = () => - Promise.all([listAnnotationSets(), listPredictionSets(), listEvaluationSets(), listScoringConfigs()]) - .then(([ann, pred, ev, sc]) => { + Promise.all([listAnnotationSets(), listPredictionSets(), listEvaluationSets(), listScoringConfigs(), listRerankers()]) + .then(([ann, pred, ev, sc, rr]) => { setAnnotationSets(ann); setPredictionSets(pred); setEvaluationSets(ev); setScoringConfigs(sc); + setRerankers(rr); }) .finally(() => setLoading(false)); @@ -616,9 +790,19 @@ export default function EvaluationPage() { if (loading) return
Loading…
; return ( -
+

{t("title")}

+ = 2, href: "/annotations" }, + { label: `${predictionSets.length} prediction set(s)`, met: predictionSets.length > 0, href: "/functional-annotation" }, + ]} + nextStep={{ label: "Scoring configs", href: "/scoring" }} + /> + {/* ── Generate Evaluation Set ───────────────────────────────── */}
@@ -627,7 +811,7 @@ export default function EvaluationPage() { {t("generateSection.description")}

-
+
setSelectedConfigId(e.target.value)} - className="rounded-md border bg-white px-2 py-1.5 text-sm text-gray-700 shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500" - > - - {scoringConfigs.map((c) => ( - - ))} - - -
- +
+
+ +
+ +
+
{selectedConfigId && ( + {/* ── Executive summary ── */} + {activeTab === "proteins" && distribution && ( +
+
+
{proteinTotal.toLocaleString()}
+
Proteins
+
+ {(["P", "F", "C"] as const).map((aspect) => ( +
+
+ {(distribution.aspect_totals[aspect] ?? 0).toLocaleString()} +
+
{ASPECT_LABELS[aspect]}
+
+ ))} +
+ )} + {/* ── Proteins ── */} {activeTab === "proteins" && (
-
-
+
+ setProteinSearchInput(e.target.value)} placeholder="Filter by accession…" - className="rounded-md border px-3 py-1.5 text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 w-56" + className="rounded-md border px-3 py-1.5 text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 w-full sm:w-56" /> )} - {proteinTotal.toLocaleString()} proteins + {proteinTotal.toLocaleString()} proteins +
+ + {/* Mobile card list */} +
+ {loadingProteins && Array.from({ length: 4 }).map((_, i) => ( +
+
+
+
+ ))} + {!loadingProteins && proteins.length === 0 && ( +
No proteins found.
+ )} + {!loadingProteins && proteins.map((p) => ( +
+
selectProtein(p.accession, p.in_db)} + > +
+
+ + {p.in_db ? ( + e.stopPropagation()}> + {p.accession} + + ) : ( + {p.accession} + )} +
+ {p.go_count} predicted +
+
+ dist: {p.min_distance?.toFixed(4) ?? "—"} + known/pred: {p.annotation_count}/{p.go_count} +
+
+ {selectedAccession === p.accession && ( +
+ setSelectedAccession(null)} + ontologySnapshotId={ontologySnapshotId} + scoringConfig={selectedConfig} + /> +
+ )} +
+ ))}
-
-
+ {/* Desktop table */} +
+
Accession
Predicted
Min Distance
-
Known
-
Matches
+
Known / Pred.
{loadingProteins && Array.from({ length: 8 }).map((_, i) => )} @@ -911,12 +1081,18 @@ export default function PredictionSetDetailPage({ params }: { params: Promise<{ {!loadingProteins && proteins.map((p) => (
selectProtein(p.accession, p.in_db)} >
+ {p.in_db ? (
{p.go_count}
{p.min_distance?.toFixed(4) ?? "—"}
-
0 ? "text-gray-700" : "text-gray-300"}`}> - {p.annotation_count > 0 ? p.annotation_count : "—"} -
-
0 ? "text-green-700" : "text-gray-300"}`}> - {p.match_count > 0 ? p.match_count : "—"} +
+ {p.annotation_count > 0 + ? {p.annotation_count} + : 0} + / + {p.go_count}
diff --git a/apps/web/app/[locale]/functional-annotation/page.tsx b/apps/web/app/[locale]/functional-annotation/page.tsx index 6aca0b0..92626b4 100644 --- a/apps/web/app/[locale]/functional-annotation/page.tsx +++ b/apps/web/app/[locale]/functional-annotation/page.tsx @@ -5,6 +5,7 @@ import Link from "next/link"; import { useTranslations } from "next-intl"; import { useToast } from "@/components/Toast"; import { SkeletonTableRow } from "@/components/Skeleton"; +import { ContextBanner } from "@/components/ContextBanner"; import { listEmbeddingConfigs, launchPredictGoTerms, @@ -163,6 +164,17 @@ export default function FunctionalAnnotationPage() {

{t("title")}

+ 0, href: "/embeddings" }, + { label: `${annotationSets.length} annotation set(s)`, met: annotationSets.length > 0, href: "/annotations" }, + { label: `${ontologySnapshots.length} ontology snapshot(s)`, met: ontologySnapshots.length > 0, href: "/annotations" }, + ] : undefined} + nextStep={{ label: "Evaluation", href: "/evaluation" }} + /> +
{tabs.map((tab) => (
-
+

{t("predictTab.searchBackend")}

-
+
{predFaissIndex === "IVFFlat" && ( -
+
setPredFaissNlist(parseInt(e.target.value, 10))} min={1} className={inputClass} /> @@ -364,7 +376,7 @@ export default function FunctionalAnnotationPage() {
)} {predFaissIndex === "HNSW" && ( -
+
setPredFaissHnswM(parseInt(e.target.value, 10))} min={2} className={inputClass} /> @@ -422,13 +434,14 @@ export default function FunctionalAnnotationPage() {
-
+
{t("resultsTab.tableHeaders.id")}
{t("resultsTab.tableHeaders.config")}
{t("resultsTab.tableHeaders.annotationSet")}
{t("resultsTab.tableHeaders.snapshot")}
{t("resultsTab.tableHeaders.goTerms")}
{t("resultsTab.tableHeaders.distanceThreshold")}
+
{t("resultsTab.tableHeaders.k")}
{t("resultsTab.tableHeaders.created")}
@@ -436,7 +449,7 @@ export default function FunctionalAnnotationPage() { {predictionSets.map((ps) => (
@@ -450,6 +463,7 @@ export default function FunctionalAnnotationPage() {
{ps.distance_threshold != null ? ps.distance_threshold : }
+
{ps.limit_per_entry}
{formatDate(ps.created_at)}
+
+
+ ); + } + + if (!data) { + return ( +
+
+
+ {[0, 1, 2].map((i) => ( +
+ ))} +
+
+
+ ); + } + + const hasFmax = data.best_fmax && Object.keys(data.best_fmax).length > 0; + const hasComparison = data.method_comparison && Object.keys(data.method_comparison).length > 0; + + // Available categories (only those with data) + const availableCategories = CATEGORIES.filter( + (cat) => data.best_fmax?.[cat] || data.method_comparison?.[cat] + ); + + // Current category data + const catFmax = data.best_fmax?.[activeCategory] ?? {}; + const catMethods = data.method_comparison?.[activeCategory] ?? []; + const baseline = catMethods.find((m) => m.method === "knn_baseline"); + + return ( +
+ {/* ── Hero ──────────────────────────────────────────────────── */} +
+

+ PROTEA +

+

+ {t("subtitle")} +

+
+ + {/* ── Annotate form ─────────────────────────────────────────── */} + + + {/* ── Category tabs ─────────────────────────────────────────── */} + {hasFmax ? ( + <> +
+
+

+ {t("bestResults")} +

+
+ {availableCategories.map((cat) => ( + + ))} +
+ + {CATEGORY_LABELS[activeCategory]} + +
+ + {/* ── Fmax cards ────────────────────────────────────────── */} +
+ {ASPECTS.map((aspect) => { + const d = catFmax[aspect]; + if (!d) return null; + const color = ASPECT_COLORS[aspect]; + return ( +
+
+ {d.fmax.toFixed(2)} +
+
+ {t("fmax")} {aspect} +
+
+ {ASPECT_LABELS[aspect]} +
+
+ {d.method_label} +
+
+ ); + })} +
+
+ + {/* ── Method comparison table ───────────────────────────── */} + {catMethods.length > 0 && ( +
+

+ {t("methodComparison")} + + ({activeCategory}) + +

+
+ + + + + {ASPECTS.map((a) => ( + + ))} + + + + {catMethods.map((row, i) => { + const isBest = ASPECTS.some( + (a) => catFmax[a]?.method === row.method + ); + return ( + + + {ASPECTS.map((aspect) => { + const val = (row as any)[aspect]?.fmax; + const baseVal = baseline ? (baseline as any)[aspect]?.fmax : null; + const delta = val != null && baseVal != null && row.method !== "knn_baseline" + ? val - baseVal + : null; + return ( + + ); + })} + + ); + })} + +
{t("method")} + {a} +
+ {t(METHOD_KEYS[row.method] ?? row.method)} + {isBest && ( + best + )} + + {val != null ? ( + + {val.toFixed(3)} + {delta != null && ( + 0 ? "text-green-600" : delta < 0 ? "text-red-600" : "text-gray-400"}`}> + {delta > 0 ? "+" : ""}{delta.toFixed(3)} + + )} + + ) : ( + + )} +
+
+
+ )} + + ) : ( +
+

{t("noDataYet")}

+ + {t("getStarted")} + +
+ )} + + {/* ── Pipeline diagram ──────────────────────────────────────── */} +
+

+ {t("pipeline")} +

+
+ {data.pipeline_stages.map((stage, i) => ( +
+ {i > 0 && ( +
+ → +
+ )} + +
+ ))} + {/* LLM stage (future) */} +
+
+ → +
+
+ LLM + {t("stageLlm")} + soon +
+
+
+
+ + {/* ── Stats bar ─────────────────────────────────────────────── */} +
+

+ {t("stats")} +

+
+ {([ + ["proteins", data.counts.proteins], + ["sequences", data.counts.sequences], + ["embeddings", data.counts.embeddings], + ["predictions", data.counts.predictions], + ] as [string, number][]).map(([key, count]) => ( +
+
+ {count.toLocaleString()} +
+
{t(key as any)}
+
+ ))} +
+
+ + {/* ── CTAs ──────────────────────────────────────────────────── */} +
+ + {t("exploreResults")} + + + {t("annotateProteins")} + +
+
+ ); } diff --git a/apps/web/app/[locale]/proteins/[accession]/page.tsx b/apps/web/app/[locale]/proteins/[accession]/page.tsx index 2c5b4c0..5ca2158 100644 --- a/apps/web/app/[locale]/proteins/[accession]/page.tsx +++ b/apps/web/app/[locale]/proteins/[accession]/page.tsx @@ -4,6 +4,7 @@ import { use, useEffect, useState } from "react"; import Link from "next/link"; import { useToast } from "@/components/Toast"; import { useTranslations } from "next-intl"; +import { Breadcrumbs } from "@/components/Breadcrumbs"; import { getProtein, getProteinAnnotations, getGoSubgraph, listOntologySnapshots, ProteinDetail, ProteinAnnotation, GoSubgraph } from "@/lib/api"; import dynamic from "next/dynamic"; const GoGraph = dynamic(() => import("@/components/GoGraph"), { ssr: false }); @@ -88,7 +89,7 @@ export default function ProteinDetailPage({ params }: { params: Promise<{ access <> {/* Header */}
- {t("backToProteins")} +

{protein.accession}

diff --git a/apps/web/app/[locale]/proteins/page.tsx b/apps/web/app/[locale]/proteins/page.tsx index 4360894..6334e8e 100644 --- a/apps/web/app/[locale]/proteins/page.tsx +++ b/apps/web/app/[locale]/proteins/page.tsx @@ -236,8 +236,41 @@ export default function ProteinsPage() { {t("browseTab.totalProteins", { count: total.toLocaleString() })}
- {/* Table */} -
+ {/* Mobile card list */} +
+ {loadingBrowse && Array.from({ length: 4 }).map((_, i) => ( +
+
+
+
+ ))} + {!loadingBrowse && proteins.length === 0 && ( +
+ {t("browseTab.noProteinsCta")} +
+ )} + {!loadingBrowse && proteins.map((p) => ( + +
+ {p.accession} + +
+

{p.gene_name ?? "—"}

+

{p.organism ?? "—"}

+
+ {p.entry_name ?? "—"} + {p.length != null && {p.length.toLocaleString()} aa} +
+ + ))} +
+ + {/* Desktop table */} +
{t("browseTab.tableHeaders.accession")}
{t("browseTab.tableHeaders.entryName")}
@@ -353,7 +386,7 @@ export default function ProteinsPage() { setSearchCriteria(e.target.value)} required className={inputClass} placeholder="organism_id:9606 AND reviewed:true" />

{t("insertTab.searchCriteriaHelper")}

-
+
setPageSize(parseInt(e.target.value, 10))} min={1} className={inputClass} /> @@ -395,7 +428,7 @@ export default function ProteinsPage() { setMetaCriteria(e.target.value)} required className={inputClass} placeholder="organism_id:9606 AND reviewed:true" />

{t("metadataTab.searchCriteriaHelper")}

-
+
setMetaPageSize(parseInt(e.target.value, 10))} min={1} className={inputClass} /> diff --git a/apps/web/app/[locale]/query-sets/page.tsx b/apps/web/app/[locale]/query-sets/page.tsx index fc1a63b..49b6c00 100644 --- a/apps/web/app/[locale]/query-sets/page.tsx +++ b/apps/web/app/[locale]/query-sets/page.tsx @@ -117,7 +117,7 @@ export default function QuerySetsPage() { {/* List */}
-
+
{t("tableHeaders.name")}
{t("tableHeaders.sequences")}
{t("tableHeaders.created")}
@@ -143,7 +143,7 @@ export default function QuerySetsPage() { {sets.map((qs) => (
setExpandedId(expandedId === qs.id ? null : qs.id)} >
diff --git a/apps/web/app/[locale]/reranker/page.tsx b/apps/web/app/[locale]/reranker/page.tsx new file mode 100644 index 0000000..edf0751 --- /dev/null +++ b/apps/web/app/[locale]/reranker/page.tsx @@ -0,0 +1,574 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { ContextBanner } from "@/components/ContextBanner"; +import { + baseUrl, + listPredictionSets, + listAnnotationSets, + listRerankers, + trainReranker, + deleteReranker, + getRerankedTsvUrl, + getRerankerMetrics, + getTrainingDataTsvUrl, +} from "@/lib/api"; +import type { PredictionSet, AnnotationSet, RerankerModel } from "@/lib/api"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +async function apiFetch(path: string, init?: RequestInit): Promise { + const res = await fetch(`${baseUrl()}${path}`, { cache: "no-store", ...init }); + if (!res.ok) throw new Error(await res.text()); + return res.json(); +} + +type EvaluationSet = { + id: string; + old_annotation_set_id: string; + new_annotation_set_id: string; + created_at: string; + stats: Record; +}; + +const listEvaluationSets = () => apiFetch("/annotations/evaluation-sets"); + +function shortId(id: string) { return id.slice(0, 8); } + +function predLabel(p: PredictionSet) { + const parts: string[] = []; + if (p.embedding_config_name) parts.push(p.embedding_config_name); + if (p.annotation_set_label) parts.push(p.annotation_set_label); + parts.push(`k=${p.limit_per_entry}`); + if (p.prediction_count != null) parts.push(`${p.prediction_count.toLocaleString()} preds`); + return `${parts.join(" · ")} (${shortId(p.id)}…)`; +} + +function evalLabel(es: EvaluationSet, annotationSets: AnnotationSet[]) { + const oldSet = annotationSets.find((a) => a.id === es.old_annotation_set_id); + const newSet = annotationSets.find((a) => a.id === es.new_annotation_set_id); + const oldVer = oldSet ? `[${oldSet.source.toUpperCase()}] ${oldSet.source_version ?? "?"}` : shortId(es.old_annotation_set_id); + const newVer = newSet ? `[${newSet.source.toUpperCase()}] ${newSet.source_version ?? "?"}` : shortId(es.new_annotation_set_id); + const delta = es.stats.delta_proteins ?? "?"; + return `${oldVer} → ${newVer} · ${delta} delta proteins (${shortId(es.id)}…)`; +} + +const labelClass = "block text-sm font-medium text-gray-700 mb-1"; +const selectClass = + "w-full rounded-md border border-gray-300 px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-blue-500"; +const btnPrimary = + "rounded-md bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700 disabled:opacity-50 transition-colors"; +const btnDanger = + "rounded-md bg-red-50 border border-red-200 px-3 py-1.5 text-xs font-medium text-red-600 hover:bg-red-100 transition-colors"; + +const CATEGORY_HINTS: Record = { + nk: "No Knowledge: proteins with zero GO annotations at t0. Hardest setting — measures pure prediction ability.", + lk: "Limited Knowledge: proteins annotated in some GO namespaces but not all at t0. 
New annotations in previously empty namespaces.", + pk: "Partial Knowledge: proteins that already had annotations in a namespace at t0 and gained new ones at t1.", +}; + +const ASPECT_LABELS: Record = { + bpo: "BPO (Biological Process)", + mfo: "MFO (Molecular Function)", + cco: "CCO (Cellular Component)", +}; + +// --------------------------------------------------------------------------- +// Feature importance bar chart +// --------------------------------------------------------------------------- + +function FeatureImportanceChart({ importance }: { importance: Record }) { + const entries = Object.entries(importance) + .sort(([, a], [, b]) => b - a) + .filter(([, v]) => v > 0); + if (entries.length === 0) return

No feature importance data

; + const maxVal = entries[0][1]; + + return ( +
+ {entries.map(([name, val]) => ( +
+ {name} +
+
+
+ + {val >= 1000 ? `${(val / 1000).toFixed(1)}k` : val.toFixed(0)} + +
+ ))} +
+ ); +} + +// --------------------------------------------------------------------------- +// Metrics display +// --------------------------------------------------------------------------- + +function MetricsBadge({ label, value, suffix }: { label: string; value: number | string | undefined; suffix?: string }) { + if (value === undefined) return null; + const formatted = typeof value === "number" ? value.toFixed(4) : value; + return ( +
+

{label}

+

{formatted}{suffix}

+
+ ); +} + +// --------------------------------------------------------------------------- +// Reranker card +// --------------------------------------------------------------------------- + +function RerankerCard({ + model, + predictionSets, + evaluationSets, + annotationSets, + onDelete, +}: { + model: RerankerModel; + predictionSets: PredictionSet[]; + evaluationSets: EvaluationSet[]; + annotationSets: AnnotationSet[]; + onDelete: () => void; +}) { + const [expanded, setExpanded] = useState(false); + const [metricsLoading, setMetricsLoading] = useState(false); + const [metrics, setMetrics] = useState | null>(null); + const [metricsError, setMetricsError] = useState(null); + const [deleting, setDeleting] = useState(false); + + // For computing metrics on a different prediction set + const [metricsPsId, setMetricsPsId] = useState(model.prediction_set_id ?? ""); + const [metricsEsId, setMetricsEsId] = useState(model.evaluation_set_id ?? ""); + const [metricsCategory, setMetricsCategory] = useState(model.category); + + async function handleComputeMetrics() { + if (!metricsPsId || !metricsEsId) return; + setMetricsLoading(true); + setMetricsError(null); + setMetrics(null); + try { + const result = await getRerankerMetrics(metricsPsId, model.id, metricsEsId, metricsCategory); + setMetrics(result); + } catch (e: any) { + setMetricsError(e.message ?? "Failed to compute metrics"); + } finally { + setMetricsLoading(false); + } + } + + async function handleDelete() { + if (!confirm(`Delete reranker "${model.name}"?`)) return; + setDeleting(true); + try { + await deleteReranker(model.id); + onDelete(); + } catch { + setDeleting(false); + } + } + + const m = model.metrics; + + return ( +
+
setExpanded(!expanded)} + > +
+
+ {model.name} + + {model.category} + + {model.aspect && ( + + {model.aspect} + + )} +
+
+ {new Date(model.created_at).toLocaleDateString()} + {expanded ? "▲" : "▼"} +
+
+
+ AUC: {m.val_auc?.toFixed(4) ?? "—"} + F1: {m.val_f1?.toFixed(4) ?? "—"} + Precision: {m.val_precision?.toFixed(4) ?? "—"} + Recall: {m.val_recall?.toFixed(4) ?? "—"} + Positive rate: {m.positive_rate != null ? `${(m.positive_rate * 100).toFixed(2)}%` : "—"} +
+
+ + {expanded && ( +
+ {/* Validation metrics */} +
+

Validation metrics

+
+ + + + +
+
+ Train samples: {m.train_samples?.toLocaleString()} + Val samples: {m.val_samples?.toLocaleString()} +
+
+ + {/* Feature importance */} +
+

Feature importance (gain)

+ +
+ + {/* Download reranked TSV */} + {model.prediction_set_id && ( +
+

Download re-ranked predictions

+ + ↓ Download reranked TSV + +
+ )} + + {/* Compute CAFA metrics */} +
+

Compute CAFA metrics

+
+
+ + +
+
+ + +
+
+ + +

{CATEGORY_HINTS[metricsCategory]}

+
+
+ + {metricsError &&

{metricsError}

} + {metrics && ( +
+
+ + + + + + +
+ {metrics.curve && metrics.curve.length > 0 && ( +

{metrics.curve.length} PR curve points computed

+ )} +
+ )} +
+ + {/* Source info */} +
+ Prediction set: {model.prediction_set_id ? shortId(model.prediction_set_id) : "—"} + Evaluation set: {model.evaluation_set_id ? shortId(model.evaluation_set_id) : "—"} + ID: {shortId(model.id)} +
+ + {/* Delete */} +
+ +
+
+ )} +
+ ); +} + +// --------------------------------------------------------------------------- +// Main page +// --------------------------------------------------------------------------- + +export default function RerankerPage() { + const [rerankers, setRerankers] = useState([]); + const [predictionSets, setPredictionSets] = useState([]); + const [evaluationSets, setEvaluationSets] = useState([]); + const [annotationSets, setAnnotationSets] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + // Train form + const [trainName, setTrainName] = useState(""); + const [trainPsId, setTrainPsId] = useState(""); + const [trainEsId, setTrainEsId] = useState(""); + const [trainCategory, setTrainCategory] = useState("nk"); + const [trainAspect, setTrainAspect] = useState(""); + const [trainNegPosRatio, setTrainNegPosRatio] = useState(""); + const [extraPairs, setExtraPairs] = useState<{ psId: string; esId: string }[]>([]); + const [training, setTraining] = useState(false); + const [trainError, setTrainError] = useState(null); + + async function loadAll() { + setLoading(true); + setError(null); + try { + const [r, ps, es, as_] = await Promise.all([ + listRerankers(), + listPredictionSets(), + listEvaluationSets(), + listAnnotationSets(), + ]); + setRerankers(r); + setPredictionSets(ps); + setEvaluationSets(es); + setAnnotationSets(as_); + } catch (e: any) { + setError(e.message ?? "Failed to load data"); + } finally { + setLoading(false); + } + } + + useEffect(() => { loadAll(); }, []); + + async function handleTrain() { + if (!trainName.trim() || !trainPsId || !trainEsId) return; + setTraining(true); + setTrainError(null); + try { + const validExtraPairs = extraPairs + .filter((p) => p.psId && p.esId) + .map((p) => ({ prediction_set_id: p.psId, evaluation_set_id: p.esId })); + const model = await trainReranker({ + name: trainName.trim(), + prediction_set_id: trainPsId, + evaluation_set_id: trainEsId, + category: trainCategory, + aspect: trainAspect || null, + neg_pos_ratio: trainNegPosRatio ? parseFloat(trainNegPosRatio) : null, + extra_pairs: validExtraPairs.length > 0 ? validExtraPairs : undefined, + }); + setRerankers((prev) => [...prev, model]); + setTrainName(""); + } catch (e: any) { + setTrainError(e.message ?? "Training failed"); + } finally { + setTraining(false); + } + } + + return ( + <> +

Re-ranker Models

+ + 0, href: "/functional-annotation" }, + { label: `${evaluationSets.length} evaluation set(s)`, met: evaluationSets.length > 0, href: "/evaluation" }, + ]} + nextStep={{ label: "Evaluation", href: "/evaluation" }} + /> +

+ LightGBM binary classifiers trained on temporal holdout data (CAFA protocol). + A re-ranker uses alignment, taxonomy, and aggregate features to re-score GO predictions + with calibrated probabilities, replacing the raw embedding distance ranking. +

+ + {/* Train new reranker */} +
+

Train new re-ranker

+
+
+ + setTrainName(e.target.value)} + placeholder="e.g. reranker-nk-bpo-v1" + className="w-full rounded-md border border-gray-300 px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-blue-500" + /> +
+
+ + +
+
+ + +
+
+ + +

{CATEGORY_HINTS[trainCategory]}

+
+
+ + +
+
+ + setTrainNegPosRatio(e.target.value)} + className="w-full rounded-md border border-gray-300 px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-blue-500" + /> +
+
+ + {/* Extra training pairs */} +
+
+ + +
+ {extraPairs.map((pair, i) => ( +
+ + + +
+ ))} + {extraPairs.length > 0 && ( +

+ Data from all pairs will be concatenated before training a single model. + {extraPairs.filter((p) => p.psId && p.esId).length > 0 && + ` (${1 + extraPairs.filter((p) => p.psId && p.esId).length} pairs total)`} +

+ )} +
+ +
+ + {trainPsId && trainEsId && ( + + ↓ Preview training data TSV + + )} +
+ {trainError &&

{trainError}

} +
+ + {/* List of rerankers */} + {loading &&

Loading...

} + {error &&

{error}

} + + {!loading && rerankers.length === 0 && ( +
+ No re-ranker models trained yet. Use the form above to train one. +
+ )} + +
+ {rerankers.map((model) => ( + setRerankers((prev) => prev.filter((r) => r.id !== model.id))} + /> + ))} +
+ + ); +} diff --git a/apps/web/components/AnnotateForm.tsx b/apps/web/components/AnnotateForm.tsx new file mode 100644 index 0000000..e28e1cf --- /dev/null +++ b/apps/web/components/AnnotateForm.tsx @@ -0,0 +1,302 @@ +"use client"; + +import { useState, useRef, useCallback, useEffect } from "react"; +import { useRouter } from "next/navigation"; +import { useTranslations } from "next-intl"; +import { + annotateProteins, + getJob, + launchPredictGoTerms, + listPredictionSets, + type AnnotateResult, +} from "@/lib/api"; + +type Stage = "idle" | "uploading" | "embedding" | "predicting" | "done" | "error"; + +const POLL_MS = 3_000; + +const EXAMPLE_FASTA = `>sp|P04637|P53_HUMAN Cellular tumor antigen p53 +MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP +DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYPQGLNGTVNLPGRNSFEV +RVCACPGRDRRTEEENLHKTTGIDSFLHPEVEYFTPETDPAGPMCSRHFYQLAKTCPVQLW +VDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHERCTCGGNHGISTTTGICLICQFFLVHKP +>sp|P38398|BRCA1_HUMAN Breast cancer type 1 susceptibility protein +MDLSALRVEEVQNVINAMQKILECPICLELIKEPVSTKCDHIFCKFCMLKLLNQKKGPSQC +PLCKNDITKRSLQESTRFSQLVEELLKIICAFQLDTGLEYANSYNFAKKENNSPEHLKDEV +SIIQSMGYRNRAKRLLQSEPENPSLQETSLSVQLSNLGTVRTLRTKQRIQPQKTSVYIELG`; + +export function AnnotateForm() { + const t = useTranslations("home"); + const router = useRouter(); + + const [fasta, setFasta] = useState(""); + const [stage, setStage] = useState("idle"); + const [error, setError] = useState(null); + const [progress, setProgress] = useState(""); + const [predictionSetId, setPredictionSetId] = useState(null); + const [rerankerId, setRerankerId] = useState(null); + const fileRef = useRef(null); + const abortRef = useRef(false); + + // Drag-and-drop state + const [dragOver, setDragOver] = useState(false); + + const handleFile = (file: File) => { + const reader = new FileReader(); + reader.onload = (e) => { + const text = e.target?.result; + if (typeof text === "string") setFasta(text); + }; + reader.readAsText(file); + }; + + const handleDrop = (e: React.DragEvent) => { + e.preventDefault(); + setDragOver(false); + const file = e.dataTransfer.files?.[0]; + if (file) handleFile(file); + }; + + const pollJob = useCallback( + async (jobId: string): Promise<"succeeded" | "failed"> => { + while (!abortRef.current) { + try { + const job = await getJob(jobId); + if (job.progress_total && job.progress_current) { + const pct = Math.round((job.progress_current / job.progress_total) * 100); + setProgress(`${pct}%`); + } + if (job.status === "succeeded") return "succeeded"; + if (job.status === "failed" || job.status === "cancelled") return "failed"; + } catch { + // transient error, keep polling + } + await new Promise((r) => setTimeout(r, POLL_MS)); + } + return "failed"; + }, + [], + ); + + const handleSubmit = async () => { + if (!fasta.trim()) return; + abortRef.current = false; + setError(null); + setStage("uploading"); + setProgress(""); + + try { + // Step 1: Upload FASTA + create embedding job + setProgress(t("annotateUploading" as any)); + const result: AnnotateResult = await annotateProteins({ + fastaText: fasta, + name: `Annotation ${new Date().toISOString().slice(0, 16)}`, + }); + + // Step 2: Poll embedding job + setStage("embedding"); + setProgress("0%"); + const embedResult = await pollJob(result.embedding_job_id); + if (embedResult === "failed") { + throw new Error("Embedding computation failed"); + } + + // Step 3: Launch prediction + setStage("predicting"); + setProgress("0%"); + const predictJob = await launchPredictGoTerms(result.predict_payload as 
Parameters<typeof launchPredictGoTerms>[0]);
+
+      // Step 4: Poll prediction job
+      const predictResult = await pollJob(predictJob.id);
+      if (predictResult === "failed") {
+        throw new Error("Prediction failed");
+      }
+
+      // Step 5: Find the prediction set created for this query_set
+      const sets = await listPredictionSets();
+      const match = sets.find(
+        (s) =>
+          (s as any).query_set_id === result.query_set_id &&
+          s.embedding_config_id === result.embedding_config_id,
+      );
+      if (match) {
+        setPredictionSetId(match.id);
+      }
+      if (result.reranker_id) {
+        setRerankerId(result.reranker_id);
+      }
+
+      setStage("done");
+      setProgress("");
+    } catch (err: any) {
+      setStage("error");
+      setError(err?.message ?? "Unknown error");
+    }
+  };
+
+  // Auto-redirect when done
+  useEffect(() => {
+    if (stage === "done" && predictionSetId) {
+      const timer = setTimeout(() => {
+        const qs = rerankerId ? `?reranker_id=${rerankerId}` : "";
+        router.push(`/functional-annotation/${predictionSetId}${qs}`);
+      }, 1500);
+      return () => clearTimeout(timer);
+    }
+  }, [stage, predictionSetId, rerankerId, router]);
+
+  // Cleanup on unmount
+  useEffect(() => {
+    return () => {
+      abortRef.current = true;
+    };
+  }, []);
+
+  const isRunning = stage === "uploading" || stage === "embedding" || stage === "predicting";
+
+  return (
+

+ {t("annotateTitle" as any)} +

+

+ {t("annotateDescription" as any)} +

+ + {/* FASTA input */} +
{ + e.preventDefault(); + setDragOver(true); + }} + onDragLeave={() => setDragOver(false)} + onDrop={handleDrop} + > +