diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 441791d..fcd42cd 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -68,12 +68,13 @@ ChemPropLightningModule.refit() ← or NoisyOracleModel (fast=True)
model.predict_smiles(unlabeled + ps_labeled) → pEC50 point estimates
│
▼
-CostAwareGreedyAcquisition.select(unlabeled, predictions, k,
+CostAwareGreedyAcquisition.select(unlabeled, predictions,
+ plate_size, wells_per_ps, wells_per_drc,
ps_labeled_smiles, ps_labeled_predictions)
│ DRC score = p_active(ŷ) / cost_DRC [exploitation]
│ PS score = H_binary(p_cross(ŷ, T)) / cost_PS [exploration]
│ PS-labeled compounds: DRC-upgrade candidates only (no PS re-query)
- │ Greedy top-k, one query per compound
+ │ Greedy by score; stops when next candidate would overflow plate_size wells
▼
ActiveLearningLoop.run() → LoopResults
│ 3 Rich progress steps per iteration: query → refit → select
@@ -158,7 +159,7 @@ All campaign parameters live in `moal/config.py` as frozen dataclasses. The YAML
| `data.simulate:` | `SimulationDataConfig` | `input_csv`, `smiles_column`, `pec50_column`, `is_canonical`, `test_set_size`; nested `pretrain:` → `PretrainDataConfig` |
| `data.simulate.pretrain:` | `PretrainDataConfig` | `input_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical`; same fields as `PlanDataConfig` minus `output_csv` |
| `data.plan:` | `PlanDataConfig` | `input_csv`, `output_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical` |
-| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `k_per_iteration` |
+| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `plate_size`, `wells_per_ps`, `wells_per_drc` |
| *(top-level)* | `PipelineConfig` | `seed` |
### Fast mode (NoisyOracleModel)
diff --git a/README.md b/README.md
index 3671577..f2a7d50 100644
--- a/README.md
+++ b/README.md
@@ -162,7 +162,7 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:
```
⠹ Iter 3/20 Querying oracle — 5 DRC (2 upgrades), 3 PS ████░░ 15% 0:00:12
⠹ Iter 3/20 Retraining model — 28 oracle + 60 pretrain records (DRC / PS) ████░░ 17% 0:00:45
- ⠹ Iter 3/20 Selecting next 10 — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46
+ ⠹ Iter 3/20 Selecting (plate=1536) — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46
```
## Key Design Notes
@@ -178,10 +178,12 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:
- Primary screen hit threshold T (`ps_threshold`): pEC50 ≈ 5.0
- Optimization target (`activity_threshold`): pEC50 = 7.0
-**Acquisition strategy (greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale:
+**Acquisition strategy (plate-budget greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale:
- `score(x, DRC) = sigmoid((ŷ - 7.0) / τ) / cost_DRC` — exploits likely actives; applies equally to first-pass DRC and upgrade-DRC candidates
- `score(x, PS) = H_binary(sigmoid((ŷ - T) / τ)) / cost_PS` — cheaply resolves threshold ambiguity; only generated for unqueried compounds
+Candidates are selected in score order until adding the next one would exceed `active_learning_loop.plate_size` wells (each PS query costs `wells_per_ps` wells, each DRC query costs `wells_per_drc` wells). When the next candidate would overflow the plate, selection hard-stops: the leftover wells are not back-filled with cheaper, lower-ranked candidates and are not carried over to later plates. The remaining candidates are deferred to the next iteration, where they are rescored by the refit model. To replicate a flat query count of k, set `plate_size=k`, `wells_per_ps=1`, `wells_per_drc=1`.
+
**Per-fidelity loss monitoring:** `training_step` and `validation_step` log `train_drc_loss`, `train_ps_loss`, `val_drc_loss`, and `val_ps_loss` separately (in addition to the aggregate `train_loss` / `val_loss`), making it possible to detect if DRC regression degrades while PS labels keep the total loss deceptively low.
## Output Files
diff --git a/examples/default_config.yaml b/examples/default_config.yaml
index 69f8147..2377d6a 100644
--- a/examples/default_config.yaml
+++ b/examples/default_config.yaml
@@ -80,7 +80,11 @@ acquisition:
# --------------------------------------------------------------------------
active_learning_loop:
n_iterations: 20 # Number of active learning iterations (m)
- k_per_iteration: 100 # Number of queries selected per iteration (k)
+ plate_size: 1536 # Maximum wells available per plate (per iteration).
+ # Selection stops when the next candidate would exceed this limit;
+ # leftover wells stay empty; unselected candidates wait for the next iteration.
+ wells_per_ps: 1 # Wells consumed by a single Primary Screen query (typically 1)
+ wells_per_drc: 13 # Wells consumed by a single DRC query (e.g., 13-point singlet DRC)
# --------------------------------------------------------------------------
# ChemProp / CheMeleon model
diff --git a/moal/acquisition.py b/moal/acquisition.py
index e762ce1..02437e6 100644
--- a/moal/acquisition.py
+++ b/moal/acquisition.py
@@ -86,7 +86,7 @@ def _binary_entropy(p: np.ndarray) -> np.ndarray:
class CostAwareGreedyAcquisition:
- """Select k (compound, fidelity) query pairs per active-learning iteration.
+ """Select (compound, fidelity) query pairs that fit within a plate budget.
Parameters
----------
@@ -181,20 +181,30 @@ def select(
self,
unlabeled_smiles: list[str],
predictions: np.ndarray,
- k: int,
+ plate_size: int,
+ wells_per_ps: int,
+ wells_per_drc: int,
ps_labeled_smiles: list[str] | None = None,
ps_labeled_predictions: np.ndarray | None = None,
) -> list[tuple[str, QueryType]]:
- """Greedily select k (compound, fidelity) pairs.
+ """Greedily select queries that fit within a plate well budget.
- Two pools are considered:
+ Candidates are ranked by acquisition score (highest first) across two
+ pools:
- - *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS or
- DRC. Both candidates enter the unified ranked list.
+ - *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS
+ or DRC. Both candidates enter the unified ranked list.
- *PS-labeled* compounds (``ps_labeled_smiles``): already screened with
PS; eligible for a DRC upgrade only. Only a DRC candidate is
generated for each.
+ The loop walks the ranked list from highest to lowest score. When the
+ next candidate's well cost would push the running total above
+ ``plate_size``, the loop stops and returns whatever has been selected so
+ far. No attempt is made to fill the remaining capacity with lower-ranked
+ candidates — unused wells stay empty, and the unselected candidates are
+ deferred to the next iteration, where they are rescored on fresh predictions.
+
Parameters
----------
unlabeled_smiles : list[str]
@@ -202,8 +212,13 @@ def select(
predictions : np.ndarray
Model pEC50 point estimates, shape ``(N,)``, aligned with
``unlabeled_smiles``.
- k : int
- Number of queries to select.
+ plate_size : int
+ Maximum total wells available on the plate. Selection stops as
+ soon as adding the next candidate would exceed this limit.
+ wells_per_ps : int
+ Number of wells consumed by a single PS query.
+ wells_per_drc : int
+ Number of wells consumed by a single DRC query.
ps_labeled_smiles : list[str], optional
Ground-truth keys for compounds that have a PS label but no DRC
label (i.e., INTERVAL-censored hits eligible for a full
@@ -216,7 +231,8 @@ def select(
Returns
-------
list[tuple[str, QueryType]]
- Ordered list of (smiles, QueryType) pairs, highest-scoring first.
+ Ordered list of (smiles, QueryType) pairs, highest-scoring first,
+ whose cumulative well cost does not exceed ``plate_size``.
Raises
------
@@ -229,7 +245,7 @@ def select(
logger.warning("No unlabeled compounds available for acquisition.")
return []
- if k == 0:
+ if plate_size == 0:
return []
predictions = np.asarray(predictions, dtype=np.float32)
@@ -271,22 +287,25 @@ def select(
selected: list[tuple[str, QueryType]] = []
selected_smiles: set[str] = set()
+ wells_used = 0
for _score, smi, qt in candidates:
if smi in selected_smiles:
continue
+ cost = wells_per_drc if qt == QueryType.DOSE_RESPONSE else wells_per_ps
+ if wells_used + cost > plate_size:
+ break
selected.append((smi, qt))
selected_smiles.add(smi)
- if len(selected) >= k:
- break
+ wells_used += cost
- if len(selected) < k:
+ if wells_used == 0 and (unlabeled_smiles or ps_labeled_smiles):
logger.warning(
- "Could only select %d queries (requested %d); "
- "%d unqueried and %d PS-labeled compounds available.",
- len(selected),
- k,
- len(unlabeled_smiles),
- len(ps_labeled_smiles),
+ "No candidates fit within plate_size=%d "
+ "(wells_per_ps=%d, wells_per_drc=%d). "
+ "No queries selected for this iteration.",
+ plate_size,
+ wells_per_ps,
+ wells_per_drc,
)
return selected
diff --git a/moal/cli.py b/moal/cli.py
index c46ea37..1da7d9e 100644
--- a/moal/cli.py
+++ b/moal/cli.py
@@ -235,7 +235,9 @@ def simulate(config: Path, output_dir: Path | None, verbose: bool) -> None:
try:
results = loop.run(
n_iterations=cfg.active_learning_loop.n_iterations,
- k_per_iteration=cfg.active_learning_loop.k_per_iteration,
+ plate_size=cfg.active_learning_loop.plate_size,
+ wells_per_ps=cfg.active_learning_loop.wells_per_ps,
+ wells_per_drc=cfg.active_learning_loop.wells_per_drc,
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc
diff --git a/moal/config.py b/moal/config.py
index 31ac564..f471454 100644
--- a/moal/config.py
+++ b/moal/config.py
@@ -379,12 +379,26 @@ class ActiveLearningLoopConfig:
----------
n_iterations : int
Number of active learning iterations (m).
- k_per_iteration : int
- Number of oracle queries issued per iteration (k).
+ plate_size : int
+ Maximum number of wells available per plate (i.e., per iteration).
+ The acquisition greedily selects ranked candidates in score order,
+ stopping as soon as the next candidate would push the total well
+ count over this limit. Any remaining plate capacity is left unused
+ and the unselected candidates are deferred to the next iteration, where
+ they are re-scored by the model refit on the updated labeled pool.
+ wells_per_ps : int
+ Number of wells consumed by a single Primary Screen (PS) query.
+ Typically 1 for a singlet primary screen.
+ wells_per_drc : int
+ Number of wells consumed by a single Dose-Response Curve (DRC) query.
+ For example, a 13-point DRC in duplicate consumes 26 wells; a compact
+ 8-point singlet DRC consumes 8.
"""
n_iterations: int = 20
- k_per_iteration: int = 10
+ plate_size: int = 1536
+ wells_per_ps: int = 1
+ wells_per_drc: int = 13
@dataclass(frozen=True)
diff --git a/moal/dashboard.py b/moal/dashboard.py
index 99cdc30..27bd40c 100644
--- a/moal/dashboard.py
+++ b/moal/dashboard.py
@@ -347,6 +347,9 @@ def update(
iter_drc_cost: float,
iter_ps_cost: float,
iter_upgrade_cost: float = 0.0,
+ iter_n_drc_new: int = 0,
+ iter_n_upgrades: int = 0,
+ iter_n_ps: int = 0,
model_metric_value: float | None = None,
) -> None:
"""Append iteration data for live display and deferred GIF/HTML export.
@@ -363,6 +366,12 @@ def update(
Total PS cost incurred in the current iteration.
iter_upgrade_cost : float, optional
Portion of ``iter_drc_cost`` attributable to PS→DRC upgrades.
+ iter_n_drc_new : int, optional
+ Number of new (first-pass) DRC queries issued this iteration.
+ iter_n_upgrades : int, optional
+ Number of PS→DRC upgrade queries issued this iteration.
+ iter_n_ps : int, optional
+ Number of PS queries issued this iteration.
model_metric_value : float, optional
Held-out test-set metric for this iteration, or None if unavailable.
"""
@@ -390,6 +399,9 @@ def update(
"iter_drc_cost": iter_drc_cost,
"iter_ps_cost": iter_ps_cost,
"iter_upgrade_cost": iter_upgrade_cost,
+ "iter_n_drc_new": iter_n_drc_new,
+ "iter_n_upgrades": iter_n_upgrades,
+ "iter_n_ps": iter_n_ps,
"model_metric_value": model_metric_value,
"n_ps_only": n_ps_only,
"n_drc_new": n_drc_new,
@@ -841,6 +853,9 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
iter_drc_new = [it["iter_drc_cost"] - it["iter_upgrade_cost"] for it in iterations]
iter_upgrades = [it["iter_upgrade_cost"] for it in iterations]
iter_ps = [it["iter_ps_cost"] for it in iterations]
+ iter_n_drc_new = [it.get("iter_n_drc_new", 0) for it in iterations]
+ iter_n_upgrades = [it.get("iter_n_upgrades", 0) for it in iterations]
+ iter_n_ps = [it.get("iter_n_ps", 0) for it in iterations]
# Scale to thousands so secondary y-axis ticks stay compact whole-number integers
cum_total_costs = [
c / 1000
@@ -880,6 +895,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_drc_new,
name="DRC",
marker_color=_COLOUR_DRC,
+ customdata=iter_n_drc_new,
+ hovertemplate="DRC<br>%{customdata} queries<br>$%{y:.0f}",
showlegend=False,
),
row=1,
@@ -891,6 +908,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_upgrades,
name="PS→DRC",
marker_color=_COLOUR_UPGRADE,
+ customdata=iter_n_upgrades,
+ hovertemplate="PS→DRC<br>%{customdata} upgrades<br>$%{y:.0f}",
showlegend=False,
),
row=1,
@@ -902,6 +921,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure:
y=iter_ps,
name="PS",
marker_color=_COLOUR_PS,
+ customdata=iter_n_ps,
+ hovertemplate="PS<br>%{customdata} queries<br>$%{y:.0f}",
showlegend=False,
),
row=1,
diff --git a/moal/loop.py b/moal/loop.py
index 1419fa0..552214b 100644
--- a/moal/loop.py
+++ b/moal/loop.py
@@ -201,14 +201,20 @@ def __init__(
# Main entry point
# ------------------------------------------------------------------
- def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
+ def run(
+ self,
+ n_iterations: int,
+ plate_size: int,
+ wells_per_ps: int,
+ wells_per_drc: int,
+ ) -> LoopResults:
"""Execute the full active learning campaign.
Each iteration completes three sequential steps tracked in the Rich
progress bar:
- 1. **Query oracle** — issue ``k`` queries from the pre-computed
- candidate list assembled at the end of the previous iteration.
+ 1. **Query oracle** — issue queries from the pre-computed candidate
+ list assembled at the end of the previous iteration.
2. **Refit model** — fine-tune the model on the growing labeled pool.
3. **Select compounds** — run model inference and acquisition scoring
over the remaining pool to prepare the next iteration's queries.
@@ -217,8 +223,15 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
----------
n_iterations : int
Total number of active learning iterations to run.
- k_per_iteration : int
- Number of oracle queries to issue per iteration.
+ plate_size : int
+ Maximum total wells available per plate (i.e., per iteration).
+ The acquisition greedily selects ranked candidates in score order,
+ stopping as soon as the next candidate would push the total well
+ count over this limit.
+ wells_per_ps : int
+ Number of wells consumed by a single PS query.
+ wells_per_drc : int
+ Number of wells consumed by a single DRC query.
Returns
-------
@@ -255,8 +268,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
_console.print(
f"[bold]moal[/bold] campaign starting — "
- f"[cyan]{n_iterations}[/cyan] iterations × "
- f"[cyan]{k_per_iteration}[/cyan] queries | "
+ f"[cyan]{n_iterations}[/cyan] iterations | "
+ f"plate: [cyan]{plate_size}[/cyan] wells "
+ f"([cyan]{wells_per_ps}[/cyan] PS / [cyan]{wells_per_drc}[/cyan] DRC) | "
f"[cyan]{n_total}[/cyan] compounds | "
f"[cyan]{n_true_actives}[/cyan] true actives"
)
@@ -283,7 +297,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
self.acquisition.select(
unlabeled,
unlabeled_preds,
- k_per_iteration,
+ plate_size,
+ wells_per_ps,
+ wells_per_drc,
ps_labeled_smiles=ps_labeled,
ps_labeled_predictions=ps_labeled_preds if ps_labeled else None,
)
@@ -359,6 +375,21 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
if r.fidelity == QueryType.DOSE_RESPONSE
and r.canonical_smiles in ps_labeled_before
)
+ iter_n_ps = sum(
+ 1 for r in new_records if r.fidelity == QueryType.PRIMARY_SCREEN
+ )
+ iter_n_drc_upgrade = sum(
+ 1
+ for r in new_records
+ if r.fidelity == QueryType.DOSE_RESPONSE
+ and r.canonical_smiles in ps_labeled_before
+ )
+ iter_n_drc_new = sum(
+ 1
+ for r in new_records
+ if r.fidelity == QueryType.DOSE_RESPONSE
+ and r.canonical_smiles not in ps_labeled_before
+ )
progress.advance(task)
# --- Refit model --------------------------------------
@@ -434,7 +465,7 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
task,
description=(
f"[green]Iter {iteration + 1}/{n_iterations}[/green] "
- f"Selecting next {k_per_iteration} — "
+ f"Selecting (plate={plate_size}) — "
f"[white]{len(remaining_unlabeled)} unqueried[/white], "
f"[magenta]{len(remaining_ps_labeled)} PS hits[/magenta]"
" eligible for upgrade"
@@ -450,7 +481,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
pending_queries = self.acquisition.select(
remaining_unlabeled,
unlabeled_preds,
- k_per_iteration,
+ plate_size,
+ wells_per_ps,
+ wells_per_drc,
ps_labeled_smiles=remaining_ps_labeled,
ps_labeled_predictions=ps_labeled_preds
if remaining_ps_labeled
@@ -490,6 +523,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults:
iter_drc_cost=iter_drc_cost,
iter_ps_cost=iter_ps_cost,
iter_upgrade_cost=iter_upgrade_cost,
+ iter_n_drc_new=iter_n_drc_new,
+ iter_n_upgrades=iter_n_drc_upgrade,
+ iter_n_ps=iter_n_ps,
model_metric_value=model_metric_value,
)
diff --git a/tests/test_acquisition.py b/tests/test_acquisition.py
index da8523a..68e61bd 100644
--- a/tests/test_acquisition.py
+++ b/tests/test_acquisition.py
@@ -82,10 +82,10 @@ class TestSelect:
"""Integration tests for CostAwareGreedyAcquisition.select()."""
def test_returns_k_unique_queries(self, acq):
- """Select must return exactly k pairs with no repeated SMILES, since a compound should only be queried once."""
+ """Select must return no repeated SMILES, since a compound should only be queried once."""
smiles = [f"C{i}" for i in range(20)]
preds = np.random.default_rng(0).normal(6.0, 1.5, 20).astype(np.float32)
- selected = acq.select(smiles, preds, k=5)
+ selected = acq.select(smiles, preds, plate_size=5, wells_per_ps=1, wells_per_drc=1)
assert len(selected) == 5
selected_smiles = [s for s, _ in selected]
assert len(selected_smiles) == len(set(selected_smiles))
@@ -94,7 +94,7 @@ def test_fidelity_types_are_valid(self, acq):
"""Every selection must be either PS or DRC — no other query type should ever be emitted."""
smiles = [f"C{i}" for i in range(10)]
preds = np.ones(10, dtype=np.float32) * 6.0
- selected = acq.select(smiles, preds, k=8)
+ selected = acq.select(smiles, preds, plate_size=8, wells_per_ps=1, wells_per_drc=1)
for _, qt in selected:
assert qt in (QueryType.PRIMARY_SCREEN, QueryType.DOSE_RESPONSE)
@@ -102,32 +102,34 @@ def test_high_pec50_prefers_drc(self, acq):
"""Compounds with very high predicted pEC50 should prefer DRC."""
smiles = ["high", "low"]
preds = np.array([9.5, 3.0], dtype=np.float32)
- selected = acq.select(smiles, preds, k=1)
+ selected = acq.select(smiles, preds, plate_size=1, wells_per_ps=1, wells_per_drc=1)
assert selected[0] == ("high", QueryType.DOSE_RESPONSE)
def test_at_threshold_prefers_ps(self, acq):
"""Compounds at the PS threshold (max entropy) should prefer PS when cheap."""
smiles = ["at_threshold"]
preds = np.array([5.0], dtype=np.float32)
- selected = acq.select(smiles, preds, k=1)
+ selected = acq.select(smiles, preds, plate_size=1, wells_per_ps=1, wells_per_drc=1)
assert selected[0][1] == QueryType.PRIMARY_SCREEN
@pytest.mark.parametrize(
- "smiles,preds,k",
+ "smiles,preds,plate_size",
[
([], np.array([]), 5),
([f"C{i}" for i in range(10)], np.ones(10, dtype=np.float32) * 6.0, 0),
],
)
- def test_empty_selection_returns_empty(self, acq, smiles, preds, k):
- """An empty pool or k=0 must return [] without error, as there is nothing to select."""
- assert acq.select(smiles, preds, k=k) == []
+ def test_empty_selection_returns_empty(self, acq, smiles, preds, plate_size):
+ """An empty pool or plate_size=0 must return [] without error, as there is nothing to select."""
+ assert (
+ acq.select(smiles, preds, plate_size=plate_size, wells_per_ps=1, wells_per_drc=1) == []
+ )
- def test_k_larger_than_pool(self, acq):
- """When k exceeds the pool size, all available compounds must be returned rather than raising."""
+ def test_plate_larger_than_pool(self, acq):
+ """When plate_size exceeds the pool's total well cost, all available compounds must be returned."""
smiles = ["A", "B"]
preds = np.array([5.0, 6.0], dtype=np.float32)
- selected = acq.select(smiles, preds, k=100)
+ selected = acq.select(smiles, preds, plate_size=100, wells_per_ps=1, wells_per_drc=1)
assert len(selected) == 2 # limited by pool size
def test_invalid_cost_raises(self):
@@ -148,7 +150,7 @@ def test_degenerate_thresholds_still_selects(self, acq):
)
smiles = [f"C{i}" for i in range(5)]
preds = np.array([5.0, 6.0, 7.0, 8.0, 9.0], dtype=np.float32)
- selected = degenerate.select(smiles, preds, k=3)
+ selected = degenerate.select(smiles, preds, plate_size=3, wells_per_ps=1, wells_per_drc=1)
assert len(selected) == 3
for _, qt in selected:
assert qt in (QueryType.PRIMARY_SCREEN, QueryType.DOSE_RESPONSE)
@@ -162,7 +164,48 @@ def test_select_with_nan_predictions(self, acq):
smiles = ["A", "B", "C"]
preds = np.array([float("nan"), 6.0, 7.0], dtype=np.float32)
with pytest.raises(ValueError, match="finite"):
- acq.select(smiles, preds, k=2)
+ acq.select(smiles, preds, plate_size=2, wells_per_ps=1, wells_per_drc=1)
+
+ def test_drc_cost_stops_at_plate_boundary(self, acq):
+ """Selection must stop when the next DRC candidate would overflow the plate.
+
+ With plate_size=14, wells_per_drc=13, and wells_per_ps=1, the top
+ candidate (DRC, 13 wells) fills most of the plate. The second candidate
+ is also a DRC (13 wells), which would push total wells to 26 > 14, so
+ the loop must hard-stop and return only the first compound.
+ """
+ acq_asymmetric = CostAwareGreedyAcquisition(
+ cost_ps=1.0,
+ cost_drc=10.0,
+ ps_threshold=5.0,
+ target_threshold=7.0,
+ tau=0.5,
+ )
+ smiles = ["high1", "high2", "low"]
+ # Both high compounds score heavily for DRC; low scores for PS
+ preds = np.array([9.5, 9.0, 3.0], dtype=np.float32)
+ selected = acq_asymmetric.select(
+ smiles, preds, plate_size=14, wells_per_ps=1, wells_per_drc=13
+ )
+ # First candidate: DRC for "high1" (13 wells used; 1 well left)
+ # Second candidate: DRC for "high2" would use 13 more → 26 > 14 → stop
+ assert len(selected) == 1
+ assert selected[0] == ("high1", QueryType.DOSE_RESPONSE)
+
+ def test_wells_used_never_exceeds_plate_size(self):
+ """Total wells consumed by the selection must never exceed plate_size."""
+ acq = CostAwareGreedyAcquisition(cost_ps=1.0, cost_drc=10.0)
+ rng = np.random.default_rng(7)
+ smiles = [f"C{i}" for i in range(50)]
+ preds = rng.normal(6.0, 1.5, 50).astype(np.float32)
+ plate_size, wells_ps, wells_drc = 100, 1, 13
+ selected = acq.select(
+ smiles, preds, plate_size=plate_size, wells_per_ps=wells_ps, wells_per_drc=wells_drc
+ )
+ total_wells = sum(
+ wells_drc if qt == QueryType.DOSE_RESPONSE else wells_ps for _, qt in selected
+ )
+ assert total_wells <= plate_size
class TestPSUpgradeCandidates:
@@ -186,7 +229,9 @@ def test_ps_labeled_compounds_only_generate_drc_candidates(self, acq):
selected = acq.select(
[],
np.array([]),
- 1,
+ plate_size=1,
+ wells_per_ps=1,
+ wells_per_drc=1,
ps_labeled_smiles=ps_smiles,
ps_labeled_predictions=ps_preds,
)
@@ -202,7 +247,9 @@ def test_ps_labeled_and_unlabeled_compete_correctly(self, acq):
selected = acq.select(
unlabeled,
unlabeled_preds,
- 1,
+ plate_size=1,
+ wells_per_ps=1,
+ wells_per_drc=1,
ps_labeled_smiles=ps_labeled,
ps_labeled_predictions=ps_labeled_preds,
)
@@ -212,9 +259,15 @@ def test_empty_ps_labeled_behaves_like_no_pool(self, acq):
"""Passing ps_labeled_smiles=[] must not change behaviour vs omitting the argument."""
smiles = [f"C{i}" for i in range(5)]
preds = np.ones(5, dtype=np.float32) * 6.0
- without_pool = acq.select(smiles, preds, 3)
+ without_pool = acq.select(smiles, preds, plate_size=3, wells_per_ps=1, wells_per_drc=1)
with_empty_pool = acq.select(
- smiles, preds, 3, ps_labeled_smiles=[], ps_labeled_predictions=None
+ smiles,
+ preds,
+ plate_size=3,
+ wells_per_ps=1,
+ wells_per_drc=1,
+ ps_labeled_smiles=[],
+ ps_labeled_predictions=None,
)
assert without_pool == with_empty_pool
@@ -224,7 +277,9 @@ def test_no_smiles_length_mismatch_assertion(self, acq):
acq.select(
[],
np.array([]),
- 1,
+ plate_size=1,
+ wells_per_ps=1,
+ wells_per_drc=1,
ps_labeled_smiles=["A", "B"],
ps_labeled_predictions=np.array([1.0]),
)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index eb91756..5f5dea0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -227,7 +227,9 @@ def test_custom_column_names_accepted(self, tmp_path):
" enabled: false\n"
"active_learning_loop:\n"
" n_iterations: 1\n"
- " k_per_iteration: 1\n"
+ " plate_size: 1\n"
+ " wells_per_ps: 1\n"
+ " wells_per_drc: 1\n"
)
runner = CliRunner()
result = runner.invoke(
@@ -643,7 +645,9 @@ def test_example_config_loads_as_pipeline_config(self):
assert cfg.oracle.cost_ps == 1.0
assert cfg.oracle.cost_drc == 10.0
assert cfg.active_learning_loop.n_iterations == 20
- assert cfg.active_learning_loop.k_per_iteration == 100
+ assert cfg.active_learning_loop.plate_size == 1536
+ assert cfg.active_learning_loop.wells_per_ps == 1
+ assert cfg.active_learning_loop.wells_per_drc == 13
assert cfg.model.fast is True
assert cfg.model.reset_weights_on_refit is False
assert cfg.data.simulate.input_csv == ""
@@ -672,7 +676,9 @@ def _full_config(self, gt_path, pretrain_path=None, *, ps_threshold=5.0, test_se
" enabled: false\n"
"active_learning_loop:\n"
" n_iterations: 1\n"
- " k_per_iteration: 1\n"
+ " plate_size: 1\n"
+ " wells_per_ps: 1\n"
+ " wells_per_drc: 1\n"
"oracle:\n"
f" ps_threshold: {ps_threshold}\n"
)
diff --git a/tests/test_loop.py b/tests/test_loop.py
index 7f37d63..a715130 100644
--- a/tests/test_loop.py
+++ b/tests/test_loop.py
@@ -142,7 +142,9 @@ def loop(oracle, mock_model, acquisition, evaluator):
# ---------------------------------------------------------------------------
N_ITERATIONS = 3
-K = 5
+PLATE_SIZE = 5
+WELLS_PS = 1
+WELLS_DRC = 1
class TestLoopExecution:
@@ -150,21 +152,36 @@ class TestLoopExecution:
def test_correct_number_of_iterations(self, loop):
"""The results list must contain exactly n_iterations entries, confirming the loop ran the requested number of times."""
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
assert len(results.iterations) == N_ITERATIONS
def test_labeled_pool_grows(self, loop, oracle):
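- """The cumulative labeled count must be non-decreasing and total k × n_iterations compounds at the end."""
+ """The cumulative labeled count must be non-decreasing and total plate_size × n_iterations compounds at the end (every query costs one well in this fixture)."""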
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
prev = 0
for iter_result in results.iterations:
assert iter_result.cumulative_labeled >= prev
prev = iter_result.cumulative_labeled
- assert results.total_labeled == N_ITERATIONS * K
+ assert results.total_labeled == N_ITERATIONS * PLATE_SIZE
def test_cost_is_monotonically_increasing(self, loop):
"""Cumulative cost must be non-decreasing, since assays can only add cost, not remove it."""
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
costs = results.costs()
assert all(costs[i] <= costs[i + 1] for i in range(len(costs) - 1))
assert results.total_cost == pytest.approx(costs[-1])
@@ -175,13 +192,23 @@ def test_no_compound_labeled_twice_same_fidelity(self, loop, oracle):
A compound may have two records if it was upgraded from PS to DRC, but
must never have two records with the same fidelity.
"""
- loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
pairs = [(r.canonical_smiles, r.fidelity) for r in oracle.labeled_records]
assert len(pairs) == len(set(pairs))
def test_model_refit_called_each_iteration(self, loop, mock_model):
"""model.refit must be called exactly once per iteration to ensure the model is updated with new labels."""
- loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
assert mock_model.refit.call_count == N_ITERATIONS
def test_reset_weights_flag_forwarded_to_refit(
@@ -196,7 +223,12 @@ def test_reset_weights_flag_forwarded_to_refit(
reset_weights_on_refit=True,
)
- loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
assert mock_model.refit.call_count == N_ITERATIONS
assert all(call.kwargs["reset_weights"] is True for call in mock_model.refit.call_args_list)
@@ -210,7 +242,12 @@ def test_predict_smiles_pool_never_grows(self, loop, mock_model):
It can only strictly shrink when a DRC query is made (compound leaves
both pools entirely). It must never grow.
"""
- loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
call_sizes = [len(c.args[0]) for c in mock_model.predict_smiles.call_args_list]
assert all(s1 >= s2 for s1, s2 in zip(call_sizes, call_sizes[1:], strict=False)), (
f"Scorable pool grew between iterations; sizes: {call_sizes}"
@@ -222,14 +259,24 @@ class TestMetrics:
def test_metrics_are_finite(self, loop):
"""All numeric metrics must be finite after every iteration; nan or inf would indicate a data pipeline bug."""
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
for iter_result in results.iterations:
for key, value in iter_result.metrics.items():
assert np.isfinite(value), f"Metric {key} is not finite: {value}"
def test_total_cost_in_final_metrics(self, loop):
"""final_metrics must contain total_cost matching oracle.total_cost so downstream reporting is consistent."""
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
assert "total_cost" in results.final_metrics
assert results.final_metrics["total_cost"] == pytest.approx(results.total_cost)
@@ -242,7 +289,12 @@ def test_total_cost_in_final_metrics(self, loop):
)
def test_metric_in_bounds(self, loop, key, lo, hi):
"""Recall must lie in [0, 1] and actives_per_dollar must be non-negative across all iterations."""
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
for iter_result in results.iterations:
assert key in iter_result.metrics, (
f"Expected metric '{key}' in iter_result.metrics; "
@@ -277,7 +329,7 @@ def test_stops_when_all_labeled(self, ground_truth_df):
ev = PipelineEvaluator()
loop = ActiveLearningLoop(oracle=oracle, model=mock, acquisition=acq, evaluator=ev)
- results = loop.run(n_iterations=100, k_per_iteration=5)
+ results = loop.run(n_iterations=100, plate_size=5, wells_per_ps=1, wells_per_drc=1)
# At most 10 unique compounds could be queried; each may contribute 2 records
# (PS + DRC upgrade), so total_labeled can exceed pool_size. Check unique compounds.
unique_labeled = len(oracle._labeled)
@@ -303,7 +355,12 @@ def test_dashboard_update_called_and_costs_correct(
evaluator=evaluator,
dashboard=mock_db,
)
- loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
assert mock_db.update.call_count == N_ITERATIONS
for call in mock_db.update.call_args_list:
@@ -337,7 +394,12 @@ def test_model_metric_value_stored(self, oracle, mock_model, acquisition, evalua
test_set=(test_smiles, test_pec50),
model_metric=ModelMetric.MAE,
)
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
for ir in results.iterations:
assert ir.model_metric_value is not None
assert np.isfinite(ir.model_metric_value)
@@ -350,7 +412,12 @@ def test_no_test_set_metric_value_none(self, oracle, mock_model, acquisition, ev
acquisition=acquisition,
evaluator=evaluator,
)
- results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K)
+ results = loop.run(
+ n_iterations=N_ITERATIONS,
+ plate_size=PLATE_SIZE,
+ wells_per_ps=WELLS_PS,
+ wells_per_drc=WELLS_DRC,
+ )
for ir in results.iterations:
assert ir.model_metric_value is None
@@ -411,7 +478,7 @@ def test_queries_succeed_with_kekule_smiles_and_is_canonical_true(self):
final_error=0.0,
)
- results = loop.run(n_iterations=2, k_per_iteration=2)
+ results = loop.run(n_iterations=2, plate_size=2, wells_per_ps=1, wells_per_drc=1)
# The oracle must have labeled compounds — if the bug were present,
# every query would have been silently skipped and the pool would be empty.
@@ -478,7 +545,7 @@ def test_upgrade_produces_both_records(self, upgrade_loop, upgrade_oracle):
"""Running enough iterations must produce at least one compound with
both a PS and a DRC record (confirming the upgrade path fires).
"""
- upgrade_loop.run(n_iterations=6, k_per_iteration=2)
+ upgrade_loop.run(n_iterations=6, plate_size=2, wells_per_ps=1, wells_per_drc=1)
records = upgrade_oracle.labeled_records
# Group by canonical SMILES
by_smiles: dict = defaultdict(list)
@@ -493,13 +560,13 @@ def test_upgrade_produces_both_records(self, upgrade_loop, upgrade_oracle):
def test_no_duplicate_fidelity_pairs(self, upgrade_loop, upgrade_oracle):
"""Each (smiles, fidelity) pair must appear at most once in labeled_records."""
- upgrade_loop.run(n_iterations=4, k_per_iteration=2)
+ upgrade_loop.run(n_iterations=4, plate_size=2, wells_per_ps=1, wells_per_drc=1)
pairs = [(r.canonical_smiles, r.fidelity) for r in upgrade_oracle.labeled_records]
assert len(pairs) == len(set(pairs))
def test_cost_includes_both_ps_and_drc(self, upgrade_loop, upgrade_oracle):
"""Total cost must reflect both PS and DRC assays when upgrades occur."""
- upgrade_loop.run(n_iterations=6, k_per_iteration=2)
+ upgrade_loop.run(n_iterations=6, plate_size=2, wells_per_ps=1, wells_per_drc=1)
manual_cost = sum(r.cost for r in upgrade_oracle.labeled_records)
assert upgrade_oracle.total_cost == pytest.approx(manual_cost)
@@ -508,7 +575,7 @@ def test_ps_labeled_pool_shrinks_as_upgrades_happen(self, upgrade_loop, upgrade_
as all INTERVAL-censored hits are upgraded to DRC.
"""
# Run enough iterations to exhaust the whole pool
- upgrade_loop.run(n_iterations=10, k_per_iteration=2)
+ upgrade_loop.run(n_iterations=10, plate_size=2, wells_per_ps=1, wells_per_drc=1)
# All compounds labeled; none remain eligible for PS→DRC upgrade
assert upgrade_oracle.get_ps_labeled_smiles() == []
@@ -551,7 +618,9 @@ class TestNoisyOracleErrorRamp:
"""Verify that the per-iteration noise ramp is correctly computed and dispatched."""
N_ITER = 4
- K = 2
+ PLATE_SIZE = 2
+ WELLS_PS = 1
+ WELLS_DRC = 1
def _make_loop(self, oracle, initial_error, final_error):
model = NoisyOracleModel(oracle._ground_truth, seed=0)
@@ -591,7 +660,12 @@ def _spy(smiles_list, noise_scale, batch_size=256):
return real_predict(smiles_list, noise_scale, batch_size)
with patch.object(model, "predict_smiles", side_effect=_spy):
- loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
+ loop.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
# The first call is the pre-loop seed; calls 1..N_ITER are the per-iteration
# Step 3 selections. Slice off the seed call and check the iteration calls.
@@ -618,7 +692,12 @@ def _spy(smiles_list, noise_scale, batch_size=256):
return real_predict(smiles_list, noise_scale, batch_size)
with patch.object(model, "predict_smiles", side_effect=_spy):
- loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
+ loop.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
assert len(captured) >= self.N_ITER + 1
for i, ns in enumerate(captured):
@@ -638,7 +717,12 @@ def _spy(smiles_list, noise_scale, batch_size=256):
return real_predict(smiles_list, noise_scale, batch_size)
with patch.object(model, "predict_smiles", side_effect=_spy):
- loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
+ loop.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
assert captured, "predict_smiles was never called"
assert captured[0] == pytest.approx(initial, abs=1e-7), (
@@ -650,7 +734,9 @@ class TestPretrainRecords:
"""Tests for ActiveLearningLoop behaviour when pretrain_records are provided."""
N_ITER = 2
- K = 3
+ PLATE_SIZE = 3
+ WELLS_PS = 1
+ WELLS_DRC = 1
@pytest.fixture
def pretrain_loop(self, oracle, mock_model, acquisition, evaluator):
@@ -702,7 +788,12 @@ def pretrain_loop(self, oracle, mock_model, acquisition, evaluator):
def test_pretrain_records_included_in_refit_call(self, pretrain_loop, mock_model):
"""model.refit must be called with the exact pretrain SMILES in the records list."""
- pretrain_loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
+ pretrain_loop.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
assert mock_model.refit.called
# Every pretrain SMILES must appear among the records passed to the first refit
first_call_records = mock_model.refit.call_args_list[0][1]["records"]
@@ -743,8 +834,18 @@ def test_empty_pretrain_reproduces_no_pretrain_behaviour(
loop_no_pretrain.model = m1
loop_empty.model = m2
- loop_no_pretrain.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
- loop_empty.run(n_iterations=self.N_ITER, k_per_iteration=self.K)
+ loop_no_pretrain.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
+ loop_empty.run(
+ n_iterations=self.N_ITER,
+ plate_size=self.PLATE_SIZE,
+ wells_per_ps=self.WELLS_PS,
+ wells_per_drc=self.WELLS_DRC,
+ )
assert m1.refit.call_count == m2.refit.call_count