Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ on:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

jobs:
build-backend:
Expand All @@ -20,7 +19,7 @@ jobs:
packages: write

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
Expand Down Expand Up @@ -59,7 +58,7 @@ jobs:
packages: write

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ on:
push:
pull_request:

env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

jobs:
docs:
runs-on: ubuntu-22.04
Expand All @@ -17,9 +14,9 @@ jobs:
poetry-version: ["2.1.0"]

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ on:
branches: [main]
pull_request:

env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

jobs:
integration:
runs-on: ubuntu-22.04
Expand All @@ -33,9 +30,9 @@ jobs:
--health-retries 5

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ on:
push:
pull_request:

env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

jobs:
lint:
runs-on: ubuntu-22.04
Expand All @@ -17,9 +14,9 @@ jobs:
poetry-version: ["2.1.0"]

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ on:
push:
pull_request:

env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

jobs:
test:
runs-on: ubuntu-22.04
Expand All @@ -17,9 +14,9 @@ jobs:
poetry-version: ["2.1.0"]

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6

- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand Down
26 changes: 22 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
FROM python:3.12-slim
# ── Stage 1: build dependencies ──────────────────────────────────────────────
FROM python:3.12-slim AS builder

RUN apt-get update && apt-get install -y \
build-essential \
Expand All @@ -8,24 +9,41 @@ RUN apt-get update && apt-get install -y \

WORKDIR /app

# Install Poetry and dependencies first (layer cache)
RUN pip install --no-cache-dir poetry==2.1.0

COPY pyproject.toml poetry.lock ./
RUN poetry config virtualenvs.create false \
&& poetry install --without dev --no-root --no-interaction --no-ansi

# Copy source
COPY protea/ ./protea/
RUN poetry install --without dev --no-interaction --no-ansi

# ── Stage 2: runtime ────────────────────────────────────────────────────────
FROM python:3.12-slim

RUN apt-get update && apt-get install -y \
libpq5 \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy installed packages from builder
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

# Copy application code
COPY protea/ ./protea/
COPY scripts/ ./scripts/
COPY alembic/ ./alembic/
COPY alembic.ini ./

ENV PYTHONUNBUFFERED=1
EXPOSE 8000

HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# Default: API server
# Override CMD to run a worker:
# docker run protea python scripts/worker.py --queue protea.jobs
CMD ["uvicorn", "protea.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "protea.api.app:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000"]
188 changes: 188 additions & 0 deletions RERANKER.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Temporal Holdout Re-Ranker for GO Term Prediction

## Motivación

El pipeline actual de PROTEA transfiere anotaciones GO mediante KNN sobre embeddings ESM, usando un scoring heurístico que combina distancia de embedding y pesos de evidencia. Este scoring no está optimizado para la métrica objetivo (Fmax) ni para el comportamiento real de las anotaciones GO a lo largo del tiempo.

La hipótesis central es que existe una señal aprendible: **dado el contexto de una predicción KNN, ¿acabará este GO term apareciendo en el siguiente release de GOA para esta proteína?** Esta señal puede extraerse directamente del mecanismo de holdout temporal que ya implementa PROTEA.

---

## Formulación del Problema

Sea $\mathcal{G}_N$ el conjunto de anotaciones GO en el release $N$ de GOA (Swiss-Prot reviewed). Para cada par consecutivo $(G_N, G_{N+1})$, el delta temporal es:

$$\Delta_{N \to N+1} = \{(p, t) \mid (p, t) \in \mathcal{G}_{N+1} \setminus \mathcal{G}_N\}$$

El re-ranker aprende una función:

$$f(q, t, \mathcal{N}_K(q)) \to \hat{y} \in [0, 1]$$

donde:
- $q$ es la proteína query (representada por su embedding ESM)
- $t$ es el GO term candidato
- $\mathcal{N}_K(q)$ es el conjunto de $K$ vecinos más cercanos en el espacio de embeddings con referencia $\mathcal{G}_N$
- $\hat{y}$ es la probabilidad de que $(q, t) \in \Delta_{N \to N+1}$

---

## Protocolo de Entrenamiento

Se utiliza validación cruzada temporal con múltiples splits históricos de GOA:

```
Training splits:
GOA_190 → GOA_195
GOA_195 → GOA_200
GOA_200 → GOA_205
GOA_205 → GOA_211
GOA_211 → GOA_215
GOA_215 → GOA_220

Test split (holdout estricto, nunca visto durante training):
GOA_220 → GOA_229
```

Para cada split se generan ejemplos etiquetados: positivos $(y=1)$ si el par (proteína, GO term) aparece en el delta, negativos $(y=0)$ en caso contrario. El desbalanceo esperado es aproximadamente 1:10, manejable con técnicas estándar.

---

## Arquitectura: Cross-Attention Re-Ranker

El modelo procesa cada par (query, GO term) usando el contexto completo de los vecinos KNN que contribuyeron a esa predicción.

```
Inputs por predicción (query_protein, go_term):
query_embedding float32[D] ESM embedding del query (D=480 para esmc_300m)
neighbor_embeddings float32[K × D] ESM embeddings de los K vecinos contribuyentes
tabular_features float32[K × F] distancia, evidencia, alineamiento, taxonomía...
go_term_embedding float32[G] embedding semántico del GO term (G=64)

Arquitectura:
1. query_proj(query_embedding) → q [H=256]
2. ref_proj(neighbor_embeddings) → tokens [K × H]
3. feature_encoder(tabular_features) → (sumado a tokens)
4. CrossAttention(q, tokens, tokens) → context [H]
5. MLP([q ‖ context ‖ go_emb ‖ agg_features]) → score [1]
```

La atención cruzada permite al modelo aprender **qué vecinos son más informativos para este query concreto**, en lugar de agregar los scores de forma heurística.

### GO Term Embeddings

Los embeddings de los GO terms se aprenden a partir de la estructura del DAG de GO (relaciones `is_a` / `part_of`) mediante Node2Vec o TransE, de forma que términos semánticamente relacionados (padre-hijo) tengan representaciones similares. El DAG ya está disponible en PROTEA a través de los modelos `GOTerm` y `GOTermRelationship`.

---

## Feature Vector

Cada predicción (query, GO term) se caracteriza por las siguientes features tabulares, computadas por vecino que contribuyó a la predicción:

| Feature | Descripción | Estado |
|---|---|---|
| `distance` | Distancia coseno en espacio de embeddings | Existente |
| `evidence_weight` | Peso del código de evidencia (IDA > IEA) | Existente |
| `identity_nw / sw` | Identidad de secuencia (alineamiento NW/SW) | Existente (opcional) |
| `similarity_nw / sw` | Similaridad de secuencia | Existente (opcional) |
| `taxonomic_distance` | Distancia taxonómica entre query y referencia | Existente (opcional) |
| `vote_count` | Número de vecinos que coinciden en este GO term | **Nuevo** |
| `k_position` | Posición del vecino más cercano que predijo este término | **Nuevo** |
| `go_term_frequency` | Frecuencia del término en el annotation set de referencia | **Nuevo** |
| `ref_annotation_density` | Número de GO terms de la proteína de referencia | **Nuevo** |
| `neighbor_distance_std` | Varianza de distancias a los K vecinos | **Nuevo** |

---

## Función de Pérdida

Se utiliza **LambdaRank** en lugar de binary cross-entropy, ya que optimiza directamente el orden de las predicciones (proxy de NDCG / Fmax) en lugar de la calibración de probabilidades.

Para cada proteína query, las predicciones GO se rankean conjuntamente:
- Positivos: GO terms en $\Delta_{N \to N+1}$
- Negativos: GO terms predichos pero no en el delta

---

## Pipeline de Datos: WebDataset

El volumen de datos (múltiples splits × ~1.35M predicciones por split × embeddings de 480 dim) requiere un pipeline de datos eficiente. Se propone almacenar los ejemplos de entrenamiento en formato **WebDataset** (shards tar), con un shard por split GOA:

```
reranker_data/
splits/
goa190_to_195.tar # ~2GB por shard
goa195_to_200.tar
...
goa220_to_229.tar # test split — no tocar durante training
models/
reranker_v1.pt
reranker_v1_config.json
```

Cada muestra en el WebDataset es **una proteína query** con todas sus predicciones GO para ese split:

```python
{
"query_accession": "P12345",
"query_embedding": float32[480],
"go_term_ids": ["GO:0006915", "GO:0005737", ...], # N_preds
"neighbor_embeddings": float32[N_preds, K, 480],
"tabular_features": float32[N_preds, K, F],
"labels": int8[N_preds], # 1 si en delta, 0 si no
}
```

El streaming de WebDataset permite entrenar sin cargar todo en RAM.

---

## Stack Tecnológico

| Componente | Tecnología |
|---|---|
| Modelo | PyTorch |
| Data pipeline | WebDataset + torch.utils.data |
| Baseline comparación | LightGBM (binary + LambdaRank) |
| GO embeddings | Node2Vec / PyTorch Geometric |
| Seguimiento experimentos | wandb |
| Embeddings proteína | ESM2 / ESMC (ya en PROTEA) |

---

## Integración en PROTEA

Una vez entrenado, el re-ranker se integra en el pipeline existente:

1. Nuevo modelo ORM `RerankingModel`: almacena pesos serializados y metadata de entrenamiento
2. Campo `reranker_id` (nullable) en `PredictionSet`
3. Si `reranker_id` presente: `store_predictions` aplica el modelo y sobreescribe `score` con $\hat{y}$
4. El threshold de Fmax se calcula igual que ahora sobre los nuevos scores
5. UI: selector de re-ranker en la pantalla de predicción

---

## Experimentos y Ablaciones

El diseño permite comparar directamente:

| Configuración | Descripción |
|---|---|
| **Baseline** | KNN + scoring heurístico actual |
| **LightGBM tabular** | Re-ranker con features tabulares sin embeddings |
| **LightGBM + derived** | Features tabulares + features derivadas del embedding (density, std) |
| **MLP cross-encoder** | Arquitectura completa sin cross-attention |
| **Cross-attention (propuesto)** | Arquitectura completa |
| **+ GO DAG embeddings** | Ablación: ¿aportan los go_term_emb? |
| **+ temporal CV** | Ablación: ¿mejora añadir más splits históricos? |

La métrica principal es **Fmax promedio sobre los 9 settings** (NK/LK/PK × BPO/MFO/CCO) en el test split GOA220→229.

---

## Valor para la Tesis

1. **Científicamente honesto**: el mismo mecanismo temporal que se usa para evaluar se usa para entrenar. No hay data leakage.
2. **Comprobable y cuantificable**: Fmax(baseline KNN) vs Fmax(re-ranker) en benchmark idéntico.
3. **Interpretable**: las feature importances (LightGBM) o los pesos de atención (cross-attention) revelan qué aspectos de una predicción KNN son más predictivos de anotaciones futuras.
4. **Generalizable**: el re-ranker aprende sobre distribuciones temporales de anotaciones GO, no sobre una proteína concreta — debería generalizar a proteínas no vistas.
5. **Extensible**: la arquitectura admite incorporar embeddings de secuencia de mayor calidad (ESM3, ProstT5) sin cambiar el pipeline.
9 changes: 4 additions & 5 deletions alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from logging.config import fileConfig
from pathlib import Path

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import engine_from_config, pool

from alembic import context

Expand All @@ -17,14 +16,14 @@

# Wire PROTEA's ORM metadata so autogenerate works.
# All model modules must be imported before Base.metadata is used.
from protea.infrastructure.orm.base import Base
import protea.infrastructure.orm.models # noqa: F401 — registers all mappers
import protea.infrastructure.orm.models # noqa: E402, F401 — registers all mappers
from protea.infrastructure.orm.base import Base # noqa: E402

target_metadata = Base.metadata

# Override the DB URL from PROTEA's settings rather than relying on the
# placeholder value in alembic.ini.
from protea.infrastructure.settings import load_settings
from protea.infrastructure.settings import load_settings # noqa: E402

_project_root = Path(__file__).resolve().parents[1]
_settings = load_settings(_project_root)
Expand Down
Loading
Loading