Merged
7 changes: 4 additions & 3 deletions .github/copilot-instructions.md
@@ -68,12 +68,13 @@ ChemPropLightningModule.refit() ← or NoisyOracleModel (fast=True)
model.predict_smiles(unlabeled + ps_labeled) → pEC50 point estimates
CostAwareGreedyAcquisition.select(unlabeled, predictions, k,
CostAwareGreedyAcquisition.select(unlabeled, predictions,
plate_size, wells_per_ps, wells_per_drc,
ps_labeled_smiles, ps_labeled_predictions)
│ DRC score = p_active(ŷ) / cost_DRC [exploitation]
│ PS score = H_binary(p_cross(ŷ, T)) / cost_PS [exploration]
│ PS-labeled compounds: DRC-upgrade candidates only (no PS re-query)
│ Greedy top-k, one query per compound
│ Greedy by score; stops when next candidate would overflow plate_size wells
ActiveLearningLoop.run() → LoopResults
│ 3 Rich progress steps per iteration: query → refit → select
@@ -158,7 +159,7 @@ All campaign parameters live in `moal/config.py` as frozen dataclasses. The YAML
| `data.simulate:` | `SimulationDataConfig` | `input_csv`, `smiles_column`, `pec50_column`, `is_canonical`, `test_set_size`; nested `pretrain:` → `PretrainDataConfig` |
| `data.simulate.pretrain:` | `PretrainDataConfig` | `input_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical`; same fields as `PlanDataConfig` minus `output_csv` |
| `data.plan:` | `PlanDataConfig` | `input_csv`, `output_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical` |
| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `k_per_iteration` |
| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `plate_size`, `wells_per_ps`, `wells_per_drc` |
| *(top-level)* | `PipelineConfig` | `seed` |
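
As an illustration of the `active_learning_loop:` row, here is a minimal sketch of a frozen dataclass with the same four fields and the default values used in `moal/config.py`. The class name `LoopConfigSketch`, the `__post_init__` validation, and the `plate_capacity` helper are hypothetical additions for illustration only; they are not part of the `moal` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoopConfigSketch:
    """Hypothetical stand-in for ActiveLearningLoopConfig (same fields and defaults)."""

    n_iterations: int = 20
    plate_size: int = 1536
    wells_per_ps: int = 1
    wells_per_drc: int = 13

    def __post_init__(self) -> None:
        # Assumed validation, not taken from the source: reject impossible budgets early.
        if self.plate_size < 0:
            raise ValueError("plate_size must be non-negative")
        if self.wells_per_ps <= 0 or self.wells_per_drc <= 0:
            raise ValueError("wells_per_ps and wells_per_drc must be positive")


def plate_capacity(cfg: LoopConfigSketch) -> dict[str, int]:
    """Upper bounds on queries per plate if the plate held only one query type."""
    return {
        "max_ps": cfg.plate_size // cfg.wells_per_ps,
        "max_drc": cfg.plate_size // cfg.wells_per_drc,
    }


cfg = LoopConfigSketch()
capacity = plate_capacity(cfg)
```

With the defaults, a 1536-well plate holds at most 118 DRC queries (118 × 13 = 1534 wells, leaving 2 wells unused) or 1536 PS queries.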

### Fast mode (NoisyOracleModel)
6 changes: 4 additions & 2 deletions README.md
@@ -162,7 +162,7 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:
```
⠹ Iter 3/20 Querying oracle — 5 DRC (2 upgrades), 3 PS ████░░ 15% 0:00:12
⠹ Iter 3/20 Retraining model — 28 oracle + 60 pretrain records (DRC / PS) ████░░ 17% 0:00:45
⠹ Iter 3/20 Selecting next 10 — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46
⠹ Iter 3/20 Selecting (plate=1536) — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46
```

## Key Design Notes
@@ -178,10 +178,12 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:
- Primary screen hit threshold T (`ps_threshold`): pEC50 ≈ 5.0
- Optimization target (`activity_threshold`): pEC50 = 7.0

**Acquisition strategy (greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale:
**Acquisition strategy (plate-budget greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale:
- `score(x, DRC) = sigmoid((ŷ - 7.0) / τ) / cost_DRC` — exploits likely actives; applies equally to first-pass DRC and upgrade-DRC candidates
- `score(x, PS) = H_binary(sigmoid((ŷ - T) / τ)) / cost_PS` — cheaply resolves threshold ambiguity; only generated for unqueried compounds
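
The two score formulas can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in `moal/acquisition.py`: the temperature `tau`, the cost values, and the use of base-2 logarithms for `H_binary` are placeholders chosen for the example, and the function names are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_entropy(p: float) -> float:
    """Entropy of a Bernoulli(p) variable in bits (log base is an assumption)."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log2(p) + (1.0 - p) * math.log2(1.0 - p))


def drc_score(y_hat: float, activity_threshold: float = 7.0,
              tau: float = 0.5, cost_drc: float = 10.0) -> float:
    """Exploitation: probability of being active, per unit DRC cost."""
    return sigmoid((y_hat - activity_threshold) / tau) / cost_drc


def ps_score(y_hat: float, ps_threshold: float = 5.0,
             tau: float = 0.5, cost_ps: float = 1.0) -> float:
    """Exploration: uncertainty about crossing the PS threshold, per unit PS cost."""
    return binary_entropy(sigmoid((y_hat - ps_threshold) / tau)) / cost_ps
```

In this sketch the PS score peaks when the prediction sits exactly on the hit threshold (`ŷ = T`) and falls towards zero as the prediction moves away from it in either direction. The DRC score increases monotonically with the predicted pEC50.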

Candidates are selected in score order until adding the next one would exceed `active_learning_loop.plate_size` wells (each PS query costs `wells_per_ps` wells, each DRC query costs `wells_per_drc`). The loop stops at the first candidate that would overflow the plate and does not backfill with cheaper, lower-ranked candidates; the remaining candidates are deferred to the next iteration and rescored with the refit model. To replicate a flat query count of k, set `plate_size=k`, `wells_per_ps=1`, `wells_per_drc=1`.
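
The plate-budget rule can be sketched as a standalone function. This mirrors the loop in `CostAwareGreedyAcquisition.select` but is a simplified illustration: `select_for_plate` is a hypothetical name, the candidate tuples are made up, and the `PRIMARY_SCREEN` enum member is an assumed name (only `QueryType.DOSE_RESPONSE` appears in this change).

```python
from enum import Enum


class QueryType(Enum):
    PRIMARY_SCREEN = "ps"    # assumed member name
    DOSE_RESPONSE = "drc"


def select_for_plate(candidates, plate_size, wells_per_ps, wells_per_drc):
    """Pick (smiles, QueryType) pairs in score order until the plate would overflow.

    ``candidates`` is an iterable of (score, smiles, QueryType) tuples.
    Returns the selected pairs and the number of wells they consume.
    """
    ranked = sorted(candidates, key=lambda c: c[0], reverse=True)
    selected, seen, wells_used = [], set(), 0
    for _score, smi, qt in ranked:
        if smi in seen:
            continue  # one query per compound
        cost = wells_per_drc if qt is QueryType.DOSE_RESPONSE else wells_per_ps
        if wells_used + cost > plate_size:
            break  # hard stop: no backfilling with cheaper candidates
        selected.append((smi, qt))
        seen.add(smi)
        wells_used += cost
    return selected, wells_used


candidates = [
    (0.9, "A", QueryType.DOSE_RESPONSE),
    (0.8, "B", QueryType.PRIMARY_SCREEN),
    (0.7, "A", QueryType.PRIMARY_SCREEN),  # skipped: A is already selected
    (0.6, "C", QueryType.DOSE_RESPONSE),   # would need 13 wells, only 1 left
    (0.5, "D", QueryType.PRIMARY_SCREEN),  # never reached after the stop
]
picked, wells_used = select_for_plate(candidates, plate_size=15,
                                      wells_per_ps=1, wells_per_drc=13)
```

Here `picked` contains the DRC query for A and the PS query for B, using 14 of the 15 wells. The PS query for D would fit in the last well, but the loop has already stopped at C, so D waits for the next iteration.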

**Per-fidelity loss monitoring:** `training_step` and `validation_step` log `train_drc_loss`, `train_ps_loss`, `val_drc_loss`, and `val_ps_loss` separately (in addition to the aggregate `train_loss` / `val_loss`), making it possible to detect if DRC regression degrades while PS labels keep the total loss deceptively low.

## Output Files
6 changes: 5 additions & 1 deletion examples/default_config.yaml
@@ -80,7 +80,11 @@ acquisition:
# --------------------------------------------------------------------------
active_learning_loop:
n_iterations: 20 # Number of active learning iterations (m)
k_per_iteration: 100 # Number of queries selected per iteration (k)
plate_size: 1536 # Maximum wells available per plate (one plate per iteration).
# Selection stops when the next candidate would exceed this limit;
# unselected candidates are deferred to the next iteration.
wells_per_ps: 1 # Wells consumed by a single Primary Screen query (typically 1)
wells_per_drc: 13 # Wells consumed by a single DRC query (e.g., 13-point singlet DRC)

# --------------------------------------------------------------------------
# ChemProp / CheMeleon model
57 changes: 38 additions & 19 deletions moal/acquisition.py
@@ -86,7 +86,7 @@ def _binary_entropy(p: np.ndarray) -> np.ndarray:


class CostAwareGreedyAcquisition:
"""Select k (compound, fidelity) query pairs per active-learning iteration.
"""Select (compound, fidelity) query pairs that fit within a plate budget.

Parameters
----------
@@ -181,29 +181,44 @@ def select(
self,
unlabeled_smiles: list[str],
predictions: np.ndarray,
k: int,
plate_size: int,
wells_per_ps: int,
wells_per_drc: int,
ps_labeled_smiles: list[str] | None = None,
ps_labeled_predictions: np.ndarray | None = None,
) -> list[tuple[str, QueryType]]:
"""Greedily select k (compound, fidelity) pairs.
"""Greedily select queries that fit within a plate well budget.

Two pools are considered:
Candidates are ranked by acquisition score (highest first) across two
pools:

- *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS or
DRC. Both candidates enter the unified ranked list.
- *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS
or DRC. Both candidates enter the unified ranked list.
- *PS-labeled* compounds (``ps_labeled_smiles``): already screened with
PS; eligible for a DRC upgrade only. Only a DRC candidate is
generated for each.

The loop walks the ranked list from highest to lowest score. When the
next candidate's well cost would push the running total above
``plate_size``, the loop stops and returns whatever has been selected so
far. No attempt is made to fill the remaining capacity with lower-ranked
candidates: unused wells stay empty, and the unselected candidates are
rescored in the next iteration, after the model has been refit on the
updated labeled pool.

Parameters
----------
unlabeled_smiles : list[str]
Ground-truth keys for all unqueried compounds.
predictions : np.ndarray
Model pEC50 point estimates, shape ``(N,)``, aligned with
``unlabeled_smiles``.
k : int
Number of queries to select.
plate_size : int
Maximum total wells available on the plate. Selection stops as
soon as adding the next candidate would exceed this limit.
wells_per_ps : int
Number of wells consumed by a single PS query.
wells_per_drc : int
Number of wells consumed by a single DRC query.
ps_labeled_smiles : list[str], optional
Ground-truth keys for compounds that have a PS label but no DRC
label (i.e., INTERVAL-censored hits eligible for a full
@@ -216,7 +231,8 @@ def select(
Returns
-------
list[tuple[str, QueryType]]
Ordered list of (smiles, QueryType) pairs, highest-scoring first.
Ordered list of (smiles, QueryType) pairs, highest-scoring first,
whose cumulative well cost does not exceed ``plate_size``.

Raises
------
@@ -229,7 +245,7 @@ def select(
logger.warning("No unlabeled compounds available for acquisition.")
return []

if k == 0:
if plate_size == 0:
return []

predictions = np.asarray(predictions, dtype=np.float32)
@@ -271,22 +287,25 @@ def select(

selected: list[tuple[str, QueryType]] = []
selected_smiles: set[str] = set()
wells_used = 0
for _score, smi, qt in candidates:
if smi in selected_smiles:
continue
cost = wells_per_drc if qt == QueryType.DOSE_RESPONSE else wells_per_ps
if wells_used + cost > plate_size:
break
selected.append((smi, qt))
selected_smiles.add(smi)
if len(selected) >= k:
break
wells_used += cost

if len(selected) < k:
if wells_used == 0 and (unlabeled_smiles or ps_labeled_smiles):
logger.warning(
"Could only select %d queries (requested %d); "
"%d unqueried and %d PS-labeled compounds available.",
len(selected),
k,
len(unlabeled_smiles),
len(ps_labeled_smiles),
"No candidates fit within plate_size=%d "
"(wells_per_ps=%d, wells_per_drc=%d). "
"No queries selected for this iteration.",
plate_size,
wells_per_ps,
wells_per_drc,
)
return selected

4 changes: 3 additions & 1 deletion moal/cli.py
@@ -235,7 +235,9 @@ def simulate(config: Path, output_dir: Path | None, verbose: bool) -> None:
try:
results = loop.run(
n_iterations=cfg.active_learning_loop.n_iterations,
k_per_iteration=cfg.active_learning_loop.k_per_iteration,
plate_size=cfg.active_learning_loop.plate_size,
wells_per_ps=cfg.active_learning_loop.wells_per_ps,
wells_per_drc=cfg.active_learning_loop.wells_per_drc,
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc
20 changes: 17 additions & 3 deletions moal/config.py
@@ -379,12 +379,26 @@ class ActiveLearningLoopConfig:
----------
n_iterations : int
Number of active learning iterations (m).
k_per_iteration : int
Number of oracle queries issued per iteration (k).
plate_size : int
Maximum number of wells available per plate (i.e., per iteration).
The acquisition step selects ranked candidates greedily in score order,
stopping as soon as the next candidate would push the total well
count over this limit. Any remaining plate capacity is left unused,
and the unselected candidates are deferred to the next iteration, where
they are rescored by the model refit on the updated labeled pool.
wells_per_ps : int
Number of wells consumed by a single Primary Screen (PS) query.
Typically 1 for a singlet primary screen.
wells_per_drc : int
Number of wells consumed by a single Dose-Response Curve (DRC) query.
For example, a 13-point DRC in duplicate consumes 26 wells; a compact
8-point singlet DRC consumes 8.
"""

n_iterations: int = 20
k_per_iteration: int = 10
plate_size: int = 1536
wells_per_ps: int = 1
wells_per_drc: int = 13


@dataclass(frozen=True)
21 changes: 21 additions & 0 deletions moal/dashboard.py
@@ -347,6 +347,9 @@ def update(
iter_drc_cost: float,
iter_ps_cost: float,
iter_upgrade_cost: float = 0.0,
iter_n_drc_new: int = 0,
iter_n_upgrades: int = 0,
iter_n_ps: int = 0,
model_metric_value: float | None = None,
) -> None:
"""Append iteration data for live display and deferred GIF/HTML export.
@@ -363,6 +366,12 @@ def update(
Total PS cost incurred in the current iteration.
iter_upgrade_cost : float, optional
Portion of ``iter_drc_cost`` attributable to PS→DRC upgrades.
iter_n_drc_new : int, optional
Number of new (first-pass) DRC queries issued this iteration.
iter_n_upgrades : int, optional
Number of PS→DRC upgrade queries issued this iteration.
iter_n_ps : int, optional
Number of PS queries issued this iteration.
model_metric_value : float, optional
Held-out test-set metric for this iteration, or None if unavailable.
"""
@@ -390,6 +399,9 @@ def update(
"iter_drc_cost": iter_drc_cost,
"iter_ps_cost": iter_ps_cost,
"iter_upgrade_cost": iter_upgrade_cost,
"iter_n_drc_new": iter_n_drc_new,
"iter_n_upgrades": iter_n_upgrades,
"iter_n_ps": iter_n_ps,
"model_metric_value": model_metric_value,
"n_ps_only": n_ps_only,
"n_drc_new": n_drc_new,
@@ -841,6 +853,9 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
iter_drc_new = [it["iter_drc_cost"] - it["iter_upgrade_cost"] for it in iterations]
iter_upgrades = [it["iter_upgrade_cost"] for it in iterations]
iter_ps = [it["iter_ps_cost"] for it in iterations]
iter_n_drc_new = [it.get("iter_n_drc_new", 0) for it in iterations]
iter_n_upgrades = [it.get("iter_n_upgrades", 0) for it in iterations]
iter_n_ps = [it.get("iter_n_ps", 0) for it in iterations]
# Scale to thousands so the secondary y-axis ticks stay compact whole numbers
cum_total_costs = [
c / 1000
@@ -880,6 +895,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_drc_new,
name="DRC",
marker_color=_COLOUR_DRC,
customdata=iter_n_drc_new,
hovertemplate="<b>DRC</b><br>%{customdata} queries<br>$%{y:.0f}<extra></extra>",
showlegend=False,
),
row=1,
@@ -891,6 +908,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_upgrades,
name="PS→DRC",
marker_color=_COLOUR_UPGRADE,
customdata=iter_n_upgrades,
hovertemplate="<b>PS→DRC</b><br>%{customdata} upgrades<br>$%{y:.0f}<extra></extra>",
showlegend=False,
),
row=1,
@@ -902,6 +921,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_ps,
name="PS",
marker_color=_COLOUR_PS,
customdata=iter_n_ps,
hovertemplate="<b>PS</b><br>%{customdata} queries<br>$%{y:.0f}<extra></extra>",
showlegend=False,
),
row=1,