diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 441791d..fcd42cd 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -68,12 +68,13 @@ ChemPropLightningModule.refit() ← or NoisyOracleModel (fast=True) model.predict_smiles(unlabeled + ps_labeled) → pEC50 point estimates │ ▼ -CostAwareGreedyAcquisition.select(unlabeled, predictions, k, +CostAwareGreedyAcquisition.select(unlabeled, predictions, + plate_size, wells_per_ps, wells_per_drc, ps_labeled_smiles, ps_labeled_predictions) │ DRC score = p_active(ŷ) / cost_DRC [exploitation] │ PS score = H_binary(p_cross(ŷ, T)) / cost_PS [exploration] │ PS-labeled compounds: DRC-upgrade candidates only (no PS re-query) - │ Greedy top-k, one query per compound + │ Greedy by score; stops when next candidate would overflow plate_size wells ▼ ActiveLearningLoop.run() → LoopResults │ 3 Rich progress steps per iteration: query → refit → select @@ -158,7 +159,7 @@ All campaign parameters live in `moal/config.py` as frozen dataclasses. The YAML | `data.simulate:` | `SimulationDataConfig` | `input_csv`, `smiles_column`, `pec50_column`, `is_canonical`, `test_set_size`; nested `pretrain:` → `PretrainDataConfig` | | `data.simulate.pretrain:` | `PretrainDataConfig` | `input_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical`; same fields as `PlanDataConfig` minus `output_csv` | | `data.plan:` | `PlanDataConfig` | `input_csv`, `output_csv`, `smiles_column`, `relation_column`, `value_column`, `is_canonical` | -| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `k_per_iteration` | +| `active_learning_loop:` | `ActiveLearningLoopConfig` | `n_iterations`, `plate_size`, `wells_per_ps`, `wells_per_drc` | | *(top-level)* | `PipelineConfig` | `seed` | ### Fast mode (NoisyOracleModel) diff --git a/README.md b/README.md index 3671577..f2a7d50 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps: ``` ⠹ Iter 3/20 Querying oracle — 5 DRC (2 upgrades), 3 PS ████░░ 15% 0:00:12 ⠹ Iter 3/20 Retraining model — 28 oracle + 60 pretrain records (DRC / PS) ████░░ 17% 0:00:45 - ⠹ Iter 3/20 Selecting next 10 — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46 + ⠹ Iter 3/20 Selecting (plate=1536) — 18 unqueried, 3 PS hits eligible for upgrade ████░░ 18% 0:00:46 ``` ## Key Design Notes @@ -178,10 +178,12 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps: - Primary screen hit threshold T (`ps_threshold`): pEC50 ≈ 5.0 - Optimization target (`activity_threshold`): pEC50 = 7.0 -**Acquisition strategy (greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale: +**Acquisition strategy (plate-budget greedy):** Each iteration scores two pools — unqueried compounds (eligible for PS or DRC) and PS-INTERVAL-labeled hits (eligible for DRC upgrade only) — on the same cost-normalised scale: - `score(x, DRC) = sigmoid((ŷ - 7.0) / τ) / cost_DRC` — exploits likely actives; applies equally to first-pass DRC and upgrade-DRC candidates - `score(x, PS) = H_binary(sigmoid((ŷ - T) / τ)) / cost_PS` — cheaply resolves threshold ambiguity; only generated for unqueried compounds +Candidates are selected in score order until adding the next would exceed `active_learning_loop.plate_size` wells (each PS query costs `wells_per_ps`, each DRC costs `wells_per_drc`). When the next candidate overflows the plate the loop hard-stops; remaining candidates are deferred to the next iteration and rescored on the updated model. To replicate a flat query count of k, set `plate_size=k`, `wells_per_ps=1`, `wells_per_drc=1`. + **Per-fidelity loss monitoring:** `training_step` and `validation_step` log `train_drc_loss`, `train_ps_loss`, `val_drc_loss`, and `val_ps_loss` separately (in addition to the aggregate `train_loss` / `val_loss`), making it possible to detect if DRC regression degrades while PS labels keep the total loss deceptively low. ## Output Files diff --git a/examples/default_config.yaml b/examples/default_config.yaml index 69f8147..2377d6a 100644 --- a/examples/default_config.yaml +++ b/examples/default_config.yaml @@ -80,7 +80,11 @@ acquisition: # -------------------------------------------------------------------------- active_learning_loop: n_iterations: 20 # Number of active learning iterations (m) - k_per_iteration: 100 # Number of queries selected per iteration (k) + plate_size: 1536 # Maximum wells available per plate (per iteration). + # Selection stops when the next candidate would exceed this limit; + # remaining capacity is deferred to the next iteration. + wells_per_ps: 1 # Wells consumed by a single Primary Screen query (typically 1) + wells_per_drc: 13 # Wells consumed by a single DRC query (e.g., 13-point singlet DRC) # -------------------------------------------------------------------------- # ChemProp / CheMeleon model diff --git a/moal/acquisition.py b/moal/acquisition.py index e762ce1..02437e6 100644 --- a/moal/acquisition.py +++ b/moal/acquisition.py @@ -86,7 +86,7 @@ def _binary_entropy(p: np.ndarray) -> np.ndarray: class CostAwareGreedyAcquisition: - """Select k (compound, fidelity) query pairs per active-learning iteration. + """Select (compound, fidelity) query pairs that fit within a plate budget. Parameters ---------- @@ -181,20 +181,30 @@ def select( self, unlabeled_smiles: list[str], predictions: np.ndarray, - k: int, + plate_size: int, + wells_per_ps: int, + wells_per_drc: int, ps_labeled_smiles: list[str] | None = None, ps_labeled_predictions: np.ndarray | None = None, ) -> list[tuple[str, QueryType]]: - """Greedily select k (compound, fidelity) pairs. + """Greedily select queries that fit within a plate well budget. - Two pools are considered: + Candidates are ranked by acquisition score (highest first) across two + pools: - - *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS or - DRC. Both candidates enter the unified ranked list. + - *Unqueried* compounds (``unlabeled_smiles``): eligible for either PS + or DRC. Both candidates enter the unified ranked list. - *PS-labeled* compounds (``ps_labeled_smiles``): already screened with PS; eligible for a DRC upgrade only. Only a DRC candidate is generated for each. + The loop walks the ranked list from highest to lowest score. When the + next candidate's well cost would push the running total above + ``plate_size``, the loop stops and returns whatever has been selected so + far. No attempt is made to fill the remaining capacity with lower-ranked + candidates — unused wells are deferred to the next iteration, where all + candidates will be rescored on the updated labeled pool. + Parameters ---------- unlabeled_smiles : list[str] @@ -202,8 +212,13 @@ def select( predictions : np.ndarray Model pEC50 point estimates, shape ``(N,)``, aligned with ``unlabeled_smiles``. - k : int - Number of queries to select. + plate_size : int + Maximum total wells available on the plate. Selection stops as + soon as adding the next candidate would exceed this limit. + wells_per_ps : int + Number of wells consumed by a single PS query. + wells_per_drc : int + Number of wells consumed by a single DRC query. ps_labeled_smiles : list[str], optional Ground-truth keys for compounds that have a PS label but no DRC label (i.e., INTERVAL-censored hits eligible for a full @@ -216,7 +231,8 @@ def select( Returns ------- list[tuple[str, QueryType]] - Ordered list of (smiles, QueryType) pairs, highest-scoring first. + Ordered list of (smiles, QueryType) pairs, highest-scoring first, + whose cumulative well cost does not exceed ``plate_size``. Raises ------ @@ -229,7 +245,7 @@ def select( logger.warning("No unlabeled compounds available for acquisition.") return [] - if k == 0: + if plate_size == 0: return [] predictions = np.asarray(predictions, dtype=np.float32) @@ -271,22 +287,25 @@ def select( selected: list[tuple[str, QueryType]] = [] selected_smiles: set[str] = set() + wells_used = 0 for _score, smi, qt in candidates: if smi in selected_smiles: continue + cost = wells_per_drc if qt == QueryType.DOSE_RESPONSE else wells_per_ps + if wells_used + cost > plate_size: + break selected.append((smi, qt)) selected_smiles.add(smi) - if len(selected) >= k: - break + wells_used += cost - if len(selected) < k: + if wells_used == 0 and (unlabeled_smiles or ps_labeled_smiles): logger.warning( - "Could only select %d queries (requested %d); " - "%d unqueried and %d PS-labeled compounds available.", - len(selected), - k, - len(unlabeled_smiles), - len(ps_labeled_smiles), + "No candidates fit within plate_size=%d " + "(wells_per_ps=%d, wells_per_drc=%d). " + "No queries selected for this iteration.", + plate_size, + wells_per_ps, + wells_per_drc, ) return selected diff --git a/moal/cli.py b/moal/cli.py index c46ea37..1da7d9e 100644 --- a/moal/cli.py +++ b/moal/cli.py @@ -235,7 +235,9 @@ def simulate(config: Path, output_dir: Path | None, verbose: bool) -> None: try: results = loop.run( n_iterations=cfg.active_learning_loop.n_iterations, - k_per_iteration=cfg.active_learning_loop.k_per_iteration, + plate_size=cfg.active_learning_loop.plate_size, + wells_per_ps=cfg.active_learning_loop.wells_per_ps, + wells_per_drc=cfg.active_learning_loop.wells_per_drc, ) except ValueError as exc: raise click.ClickException(str(exc)) from exc diff --git a/moal/config.py b/moal/config.py index 31ac564..f471454 100644 --- a/moal/config.py +++ b/moal/config.py @@ -379,12 +379,26 @@ class ActiveLearningLoopConfig: ---------- n_iterations : int Number of active learning iterations (m). - k_per_iteration : int - Number of oracle queries issued per iteration (k). + plate_size : int + Maximum number of wells available per plate (i.e., per iteration). + The acquisition greedily selects ranked candidates in score order, + stopping as soon as the next candidate would push the total well + count over this limit. Any remaining plate capacity is accepted + and the unused candidates are deferred to the next iteration, where + the model will be re-scored on the updated labeled pool. + wells_per_ps : int + Number of wells consumed by a single Primary Screen (PS) query. + Typically 1 for a singlet primary screen. + wells_per_drc : int + Number of wells consumed by a single Dose-Response Curve (DRC) query. + For example, a 13-point DRC in duplicate consumes 26 wells; a compact + 8-point singlet DRC consumes 8. """ n_iterations: int = 20 - k_per_iteration: int = 10 + plate_size: int = 1536 + wells_per_ps: int = 1 + wells_per_drc: int = 13 @dataclass(frozen=True) diff --git a/moal/dashboard.py b/moal/dashboard.py index 99cdc30..27bd40c 100644 --- a/moal/dashboard.py +++ b/moal/dashboard.py @@ -347,6 +347,9 @@ def update( iter_drc_cost: float, iter_ps_cost: float, iter_upgrade_cost: float = 0.0, + iter_n_drc_new: int = 0, + iter_n_upgrades: int = 0, + iter_n_ps: int = 0, model_metric_value: float | None = None, ) -> None: """Append iteration data for live display and deferred GIF/HTML export. @@ -363,6 +366,12 @@ def update( Total PS cost incurred in the current iteration. iter_upgrade_cost : float, optional Portion of ``iter_drc_cost`` attributable to PS→DRC upgrades. + iter_n_drc_new : int, optional + Number of new (first-pass) DRC queries issued this iteration. + iter_n_upgrades : int, optional + Number of PS→DRC upgrade queries issued this iteration. + iter_n_ps : int, optional + Number of PS queries issued this iteration. model_metric_value : float, optional Held-out test-set metric for this iteration, or None if unavailable. """ @@ -390,6 +399,9 @@ def update( "iter_drc_cost": iter_drc_cost, "iter_ps_cost": iter_ps_cost, "iter_upgrade_cost": iter_upgrade_cost, + "iter_n_drc_new": iter_n_drc_new, + "iter_n_upgrades": iter_n_upgrades, + "iter_n_ps": iter_n_ps, "model_metric_value": model_metric_value, "n_ps_only": n_ps_only, "n_drc_new": n_drc_new, @@ -841,6 +853,9 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure: iter_drc_new = [it["iter_drc_cost"] - it["iter_upgrade_cost"] for it in iterations] iter_upgrades = [it["iter_upgrade_cost"] for it in iterations] iter_ps = [it["iter_ps_cost"] for it in iterations] + iter_n_drc_new = [it.get("iter_n_drc_new", 0) for it in iterations] + iter_n_upgrades = [it.get("iter_n_upgrades", 0) for it in iterations] + iter_n_ps = [it.get("iter_n_ps", 0) for it in iterations] # Scale to thousands so secondary y-axis ticks stay compact whole-number integers cum_total_costs = [ c / 1000 @@ -880,6 +895,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure: y=iter_drc_new, name="DRC", marker_color=_COLOUR_DRC, + customdata=iter_n_drc_new, + hovertemplate="DRC
%{customdata} queries
$%{y:.0f}", showlegend=False, ), row=1, @@ -891,6 +908,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure: y=iter_upgrades, name="PS→DRC", marker_color=_COLOUR_UPGRADE, + customdata=iter_n_upgrades, + hovertemplate="PS→DRC
%{customdata} upgrades
$%{y:.0f}", showlegend=False, ), row=1, @@ -902,6 +921,8 @@ def _build_figure(self, iterations: list[dict]) -> go.Figure: y=iter_ps, name="PS", marker_color=_COLOUR_PS, + customdata=iter_n_ps, + hovertemplate="PS
%{customdata} queries
$%{y:.0f}", showlegend=False, ), row=1, diff --git a/moal/loop.py b/moal/loop.py index 1419fa0..552214b 100644 --- a/moal/loop.py +++ b/moal/loop.py @@ -201,14 +201,20 @@ def __init__( # Main entry point # ------------------------------------------------------------------ - def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: + def run( + self, + n_iterations: int, + plate_size: int, + wells_per_ps: int, + wells_per_drc: int, + ) -> LoopResults: """Execute the full active learning campaign. Each iteration completes three sequential steps tracked in the Rich progress bar: - 1. **Query oracle** — issue ``k`` queries from the pre-computed - candidate list assembled at the end of the previous iteration. + 1. **Query oracle** — issue queries from the pre-computed candidate + list assembled at the end of the previous iteration. 2. **Refit model** — fine-tune the model on the growing labeled pool. 3. **Select compounds** — run model inference and acquisition scoring over the remaining pool to prepare the next iteration's queries. @@ -217,8 +223,15 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: ---------- n_iterations : int Total number of active learning iterations to run. - k_per_iteration : int - Number of oracle queries to issue per iteration. + plate_size : int + Maximum total wells available per plate (i.e., per iteration). + The acquisition greedily selects ranked candidates in score order, + stopping as soon as the next candidate would push the total well + count over this limit. + wells_per_ps : int + Number of wells consumed by a single PS query. + wells_per_drc : int + Number of wells consumed by a single DRC query. Returns ------- @@ -255,8 +268,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: _console.print( f"[bold]moal[/bold] campaign starting — " - f"[cyan]{n_iterations}[/cyan] iterations × " - f"[cyan]{k_per_iteration}[/cyan] queries | " + f"[cyan]{n_iterations}[/cyan] iterations | " + f"plate: [cyan]{plate_size}[/cyan] wells " + f"([cyan]{wells_per_ps}[/cyan] PS / [cyan]{wells_per_drc}[/cyan] DRC) | " f"[cyan]{n_total}[/cyan] compounds | " f"[cyan]{n_true_actives}[/cyan] true actives" ) @@ -283,7 +297,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: self.acquisition.select( unlabeled, unlabeled_preds, - k_per_iteration, + plate_size, + wells_per_ps, + wells_per_drc, ps_labeled_smiles=ps_labeled, ps_labeled_predictions=ps_labeled_preds if ps_labeled else None, ) @@ -359,6 +375,21 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: if r.fidelity == QueryType.DOSE_RESPONSE and r.canonical_smiles in ps_labeled_before ) + iter_n_ps = sum( + 1 for r in new_records if r.fidelity == QueryType.PRIMARY_SCREEN + ) + iter_n_drc_upgrade = sum( + 1 + for r in new_records + if r.fidelity == QueryType.DOSE_RESPONSE + and r.canonical_smiles in ps_labeled_before + ) + iter_n_drc_new = sum( + 1 + for r in new_records + if r.fidelity == QueryType.DOSE_RESPONSE + and r.canonical_smiles not in ps_labeled_before + ) progress.advance(task) # --- Refit model -------------------------------------- @@ -434,7 +465,7 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: task, description=( f"[green]Iter {iteration + 1}/{n_iterations}[/green] " - f"Selecting next {k_per_iteration} — " + f"Selecting (plate={plate_size}) — " f"[white]{len(remaining_unlabeled)} unqueried[/white], " f"[magenta]{len(remaining_ps_labeled)} PS hits[/magenta]" " eligible for upgrade" @@ -450,7 +481,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: pending_queries = self.acquisition.select( remaining_unlabeled, unlabeled_preds, - k_per_iteration, + plate_size, + wells_per_ps, + wells_per_drc, ps_labeled_smiles=remaining_ps_labeled, ps_labeled_predictions=ps_labeled_preds if remaining_ps_labeled @@ -490,6 +523,9 @@ def run(self, n_iterations: int, k_per_iteration: int) -> LoopResults: iter_drc_cost=iter_drc_cost, iter_ps_cost=iter_ps_cost, iter_upgrade_cost=iter_upgrade_cost, + iter_n_drc_new=iter_n_drc_new, + iter_n_upgrades=iter_n_drc_upgrade, + iter_n_ps=iter_n_ps, model_metric_value=model_metric_value, ) diff --git a/tests/test_acquisition.py b/tests/test_acquisition.py index da8523a..68e61bd 100644 --- a/tests/test_acquisition.py +++ b/tests/test_acquisition.py @@ -82,10 +82,10 @@ class TestSelect: """Integration tests for CostAwareGreedyAcquisition.select().""" def test_returns_k_unique_queries(self, acq): - """Select must return exactly k pairs with no repeated SMILES, since a compound should only be queried once.""" + """Select must return no repeated SMILES, since a compound should only be queried once.""" smiles = [f"C{i}" for i in range(20)] preds = np.random.default_rng(0).normal(6.0, 1.5, 20).astype(np.float32) - selected = acq.select(smiles, preds, k=5) + selected = acq.select(smiles, preds, plate_size=5, wells_per_ps=1, wells_per_drc=1) assert len(selected) == 5 selected_smiles = [s for s, _ in selected] assert len(selected_smiles) == len(set(selected_smiles)) @@ -94,7 +94,7 @@ def test_fidelity_types_are_valid(self, acq): """Every selection must be either PS or DRC — no other query type should ever be emitted.""" smiles = [f"C{i}" for i in range(10)] preds = np.ones(10, dtype=np.float32) * 6.0 - selected = acq.select(smiles, preds, k=8) + selected = acq.select(smiles, preds, plate_size=8, wells_per_ps=1, wells_per_drc=1) for _, qt in selected: assert qt in (QueryType.PRIMARY_SCREEN, QueryType.DOSE_RESPONSE) @@ -102,32 +102,34 @@ def test_high_pec50_prefers_drc(self, acq): """Compounds with very high predicted pEC50 should prefer DRC.""" smiles = ["high", "low"] preds = np.array([9.5, 3.0], dtype=np.float32) - selected = acq.select(smiles, preds, k=1) + selected = acq.select(smiles, preds, plate_size=1, wells_per_ps=1, wells_per_drc=1) assert selected[0] == ("high", QueryType.DOSE_RESPONSE) def test_at_threshold_prefers_ps(self, acq): """Compounds at the PS threshold (max entropy) should prefer PS when cheap.""" smiles = ["at_threshold"] preds = np.array([5.0], dtype=np.float32) - selected = acq.select(smiles, preds, k=1) + selected = acq.select(smiles, preds, plate_size=1, wells_per_ps=1, wells_per_drc=1) assert selected[0][1] == QueryType.PRIMARY_SCREEN @pytest.mark.parametrize( - "smiles,preds,k", + "smiles,preds,plate_size", [ ([], np.array([]), 5), ([f"C{i}" for i in range(10)], np.ones(10, dtype=np.float32) * 6.0, 0), ], ) - def test_empty_selection_returns_empty(self, acq, smiles, preds, k): - """An empty pool or k=0 must return [] without error, as there is nothing to select.""" - assert acq.select(smiles, preds, k=k) == [] + def test_empty_selection_returns_empty(self, acq, smiles, preds, plate_size): + """An empty pool or plate_size=0 must return [] without error, as there is nothing to select.""" + assert ( + acq.select(smiles, preds, plate_size=plate_size, wells_per_ps=1, wells_per_drc=1) == [] + ) - def test_k_larger_than_pool(self, acq): - """When k exceeds the pool size, all available compounds must be returned rather than raising.""" + def test_plate_larger_than_pool(self, acq): + """When plate_size exceeds the pool's total well cost, all available compounds must be returned.""" smiles = ["A", "B"] preds = np.array([5.0, 6.0], dtype=np.float32) - selected = acq.select(smiles, preds, k=100) + selected = acq.select(smiles, preds, plate_size=100, wells_per_ps=1, wells_per_drc=1) assert len(selected) == 2 # limited by pool size def test_invalid_cost_raises(self): @@ -148,7 +150,7 @@ def test_degenerate_thresholds_still_selects(self, acq): ) smiles = [f"C{i}" for i in range(5)] preds = np.array([5.0, 6.0, 7.0, 8.0, 9.0], dtype=np.float32) - selected = degenerate.select(smiles, preds, k=3) + selected = degenerate.select(smiles, preds, plate_size=3, wells_per_ps=1, wells_per_drc=1) assert len(selected) == 3 for _, qt in selected: assert qt in (QueryType.PRIMARY_SCREEN, QueryType.DOSE_RESPONSE) @@ -162,7 +164,48 @@ def test_select_with_nan_predictions(self, acq): smiles = ["A", "B", "C"] preds = np.array([float("nan"), 6.0, 7.0], dtype=np.float32) with pytest.raises(ValueError, match="finite"): - acq.select(smiles, preds, k=2) + acq.select(smiles, preds, plate_size=2, wells_per_ps=1, wells_per_drc=1) + + def test_drc_cost_stops_at_plate_boundary(self, acq): + """Selection must stop when the next DRC candidate would overflow the plate. + + With plate_size=14, wells_per_drc=13, and wells_per_ps=1, the top + candidate (DRC, 13 wells) fills most of the plate. The second candidate + is also a DRC (13 wells), which would push total wells to 26 > 14, so + the loop must hard-stop and return only the first compound. + """ + acq_asymmetric = CostAwareGreedyAcquisition( + cost_ps=1.0, + cost_drc=10.0, + ps_threshold=5.0, + target_threshold=7.0, + tau=0.5, + ) + smiles = ["high1", "high2", "low"] + # Both high compounds score heavily for DRC; low scores for PS + preds = np.array([9.5, 9.0, 3.0], dtype=np.float32) + selected = acq_asymmetric.select( + smiles, preds, plate_size=14, wells_per_ps=1, wells_per_drc=13 + ) + # First candidate: DRC for "high1" (13 wells used; 1 well left) + # Second candidate: DRC for "high2" would use 13 more → 26 > 14 → stop + assert len(selected) == 1 + assert selected[0] == ("high1", QueryType.DOSE_RESPONSE) + + def test_wells_used_never_exceeds_plate_size(self): + """Total wells consumed by the selection must never exceed plate_size.""" + acq = CostAwareGreedyAcquisition(cost_ps=1.0, cost_drc=10.0) + rng = np.random.default_rng(7) + smiles = [f"C{i}" for i in range(50)] + preds = rng.normal(6.0, 1.5, 50).astype(np.float32) + plate_size, wells_ps, wells_drc = 100, 1, 13 + selected = acq.select( + smiles, preds, plate_size=plate_size, wells_per_ps=wells_ps, wells_per_drc=wells_drc + ) + total_wells = sum( + wells_drc if qt == QueryType.DOSE_RESPONSE else wells_ps for _, qt in selected + ) + assert total_wells <= plate_size class TestPSUpgradeCandidates: @@ -186,7 +229,9 @@ def test_ps_labeled_compounds_only_generate_drc_candidates(self, acq): selected = acq.select( [], np.array([]), - 1, + plate_size=1, + wells_per_ps=1, + wells_per_drc=1, ps_labeled_smiles=ps_smiles, ps_labeled_predictions=ps_preds, ) @@ -202,7 +247,9 @@ def test_ps_labeled_and_unlabeled_compete_correctly(self, acq): selected = acq.select( unlabeled, unlabeled_preds, - 1, + plate_size=1, + wells_per_ps=1, + wells_per_drc=1, ps_labeled_smiles=ps_labeled, ps_labeled_predictions=ps_labeled_preds, ) @@ -212,9 +259,15 @@ def test_empty_ps_labeled_behaves_like_no_pool(self, acq): """Passing ps_labeled_smiles=[] must not change behaviour vs omitting the argument.""" smiles = [f"C{i}" for i in range(5)] preds = np.ones(5, dtype=np.float32) * 6.0 - without_pool = acq.select(smiles, preds, 3) + without_pool = acq.select(smiles, preds, plate_size=3, wells_per_ps=1, wells_per_drc=1) with_empty_pool = acq.select( - smiles, preds, 3, ps_labeled_smiles=[], ps_labeled_predictions=None + smiles, + preds, + plate_size=3, + wells_per_ps=1, + wells_per_drc=1, + ps_labeled_smiles=[], + ps_labeled_predictions=None, ) assert without_pool == with_empty_pool @@ -224,7 +277,9 @@ def test_no_smiles_length_mismatch_assertion(self, acq): acq.select( [], np.array([]), - 1, + plate_size=1, + wells_per_ps=1, + wells_per_drc=1, ps_labeled_smiles=["A", "B"], ps_labeled_predictions=np.array([1.0]), ) diff --git a/tests/test_cli.py b/tests/test_cli.py index eb91756..5f5dea0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -227,7 +227,9 @@ def test_custom_column_names_accepted(self, tmp_path): " enabled: false\n" "active_learning_loop:\n" " n_iterations: 1\n" - " k_per_iteration: 1\n" + " plate_size: 1\n" + " wells_per_ps: 1\n" + " wells_per_drc: 1\n" ) runner = CliRunner() result = runner.invoke( @@ -643,7 +645,9 @@ def test_example_config_loads_as_pipeline_config(self): assert cfg.oracle.cost_ps == 1.0 assert cfg.oracle.cost_drc == 10.0 assert cfg.active_learning_loop.n_iterations == 20 - assert cfg.active_learning_loop.k_per_iteration == 100 + assert cfg.active_learning_loop.plate_size == 1536 + assert cfg.active_learning_loop.wells_per_ps == 1 + assert cfg.active_learning_loop.wells_per_drc == 13 assert cfg.model.fast is True assert cfg.model.reset_weights_on_refit is False assert cfg.data.simulate.input_csv == "" @@ -672,7 +676,9 @@ def _full_config(self, gt_path, pretrain_path=None, *, ps_threshold=5.0, test_se " enabled: false\n" "active_learning_loop:\n" " n_iterations: 1\n" - " k_per_iteration: 1\n" + " plate_size: 1\n" + " wells_per_ps: 1\n" + " wells_per_drc: 1\n" "oracle:\n" f" ps_threshold: {ps_threshold}\n" ) diff --git a/tests/test_loop.py b/tests/test_loop.py index 7f37d63..a715130 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -142,7 +142,9 @@ def loop(oracle, mock_model, acquisition, evaluator): # --------------------------------------------------------------------------- N_ITERATIONS = 3 -K = 5 +PLATE_SIZE = 5 +WELLS_PS = 1 +WELLS_DRC = 1 class TestLoopExecution: @@ -150,21 +152,36 @@ class TestLoopExecution: def test_correct_number_of_iterations(self, loop): """The results list must contain exactly n_iterations entries, confirming the loop ran the requested number of times.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) assert len(results.iterations) == N_ITERATIONS def test_labeled_pool_grows(self, loop, oracle): """The cumulative labeled count must be non-decreasing and total k × n_iterations compounds at the end.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) prev = 0 for iter_result in results.iterations: assert iter_result.cumulative_labeled >= prev prev = iter_result.cumulative_labeled - assert results.total_labeled == N_ITERATIONS * K + assert results.total_labeled == N_ITERATIONS * PLATE_SIZE def test_cost_is_monotonically_increasing(self, loop): """Cumulative cost must be non-decreasing, since assays can only add cost, not remove it.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) costs = results.costs() assert all(costs[i] <= costs[i + 1] for i in range(len(costs) - 1)) assert results.total_cost == pytest.approx(costs[-1]) @@ -175,13 +192,23 @@ def test_no_compound_labeled_twice_same_fidelity(self, loop, oracle): A compound may have two records if it was upgraded from PS to DRC, but must never have two records with the same fidelity. """ - loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) pairs = [(r.canonical_smiles, r.fidelity) for r in oracle.labeled_records] assert len(pairs) == len(set(pairs)) def test_model_refit_called_each_iteration(self, loop, mock_model): """model.refit must be called exactly once per iteration to ensure the model is updated with new labels.""" - loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) assert mock_model.refit.call_count == N_ITERATIONS def test_reset_weights_flag_forwarded_to_refit( @@ -196,7 +223,12 @@ def test_reset_weights_flag_forwarded_to_refit( reset_weights_on_refit=True, ) - loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) assert mock_model.refit.call_count == N_ITERATIONS assert all(call.kwargs["reset_weights"] is True for call in mock_model.refit.call_args_list) @@ -210,7 +242,12 @@ def test_predict_smiles_pool_never_grows(self, loop, mock_model): It can only strictly shrink when a DRC query is made (compound leaves both pools entirely). It must never grow. """ - loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) call_sizes = [len(c.args[0]) for c in mock_model.predict_smiles.call_args_list] assert all(s1 >= s2 for s1, s2 in zip(call_sizes, call_sizes[1:], strict=False)), ( f"Scorable pool grew between iterations; sizes: {call_sizes}" @@ -222,14 +259,24 @@ class TestMetrics: def test_metrics_are_finite(self, loop): """All numeric metrics must be finite after every iteration; nan or inf would indicate a data pipeline bug.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) for iter_result in results.iterations: for key, value in iter_result.metrics.items(): assert np.isfinite(value), f"Metric {key} is not finite: {value}" def test_total_cost_in_final_metrics(self, loop): """final_metrics must contain total_cost matching oracle.total_cost so downstream reporting is consistent.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) assert "total_cost" in results.final_metrics assert results.final_metrics["total_cost"] == pytest.approx(results.total_cost) @@ -242,7 +289,12 @@ def test_total_cost_in_final_metrics(self, loop): ) def test_metric_in_bounds(self, loop, key, lo, hi): """Recall must lie in [0, 1] and actives_per_dollar must be non-negative across all iterations.""" - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) for iter_result in results.iterations: assert key in iter_result.metrics, ( f"Expected metric '{key}' in iter_result.metrics; " @@ -277,7 +329,7 @@ def test_stops_when_all_labeled(self, ground_truth_df): ev = PipelineEvaluator() loop = ActiveLearningLoop(oracle=oracle, model=mock, acquisition=acq, evaluator=ev) - results = loop.run(n_iterations=100, k_per_iteration=5) + results = loop.run(n_iterations=100, plate_size=5, wells_per_ps=1, wells_per_drc=1) # At most 10 unique compounds could be queried; each may contribute 2 records # (PS + DRC upgrade), so total_labeled can exceed pool_size. Check unique compounds. unique_labeled = len(oracle._labeled) @@ -303,7 +355,12 @@ def test_dashboard_update_called_and_costs_correct( evaluator=evaluator, dashboard=mock_db, ) - loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) assert mock_db.update.call_count == N_ITERATIONS for call in mock_db.update.call_args_list: @@ -337,7 +394,12 @@ def test_model_metric_value_stored(self, oracle, mock_model, acquisition, evalua test_set=(test_smiles, test_pec50), model_metric=ModelMetric.MAE, ) - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) for ir in results.iterations: assert ir.model_metric_value is not None assert np.isfinite(ir.model_metric_value) @@ -350,7 +412,12 @@ def test_no_test_set_metric_value_none(self, oracle, mock_model, acquisition, ev acquisition=acquisition, evaluator=evaluator, ) - results = loop.run(n_iterations=N_ITERATIONS, k_per_iteration=K) + results = loop.run( + n_iterations=N_ITERATIONS, + plate_size=PLATE_SIZE, + wells_per_ps=WELLS_PS, + wells_per_drc=WELLS_DRC, + ) for ir in results.iterations: assert ir.model_metric_value is None @@ -411,7 +478,7 @@ def test_queries_succeed_with_kekule_smiles_and_is_canonical_true(self): final_error=0.0, ) - results = loop.run(n_iterations=2, k_per_iteration=2) + results = loop.run(n_iterations=2, plate_size=2, wells_per_ps=1, wells_per_drc=1) # The oracle must have labeled compounds — if the bug were present, # every query would have been silently skipped and the pool would be empty. @@ -478,7 +545,7 @@ def test_upgrade_produces_both_records(self, upgrade_loop, upgrade_oracle): """Running enough iterations must produce at least one compound with both a PS and a DRC record (confirming the upgrade path fires). """ - upgrade_loop.run(n_iterations=6, k_per_iteration=2) + upgrade_loop.run(n_iterations=6, plate_size=2, wells_per_ps=1, wells_per_drc=1) records = upgrade_oracle.labeled_records # Group by canonical SMILES by_smiles: dict = defaultdict(list) @@ -493,13 +560,13 @@ def test_upgrade_produces_both_records(self, upgrade_loop, upgrade_oracle): def test_no_duplicate_fidelity_pairs(self, upgrade_loop, upgrade_oracle): """Each (smiles, fidelity) pair must appear at most once in labeled_records.""" - upgrade_loop.run(n_iterations=4, k_per_iteration=2) + upgrade_loop.run(n_iterations=4, plate_size=2, wells_per_ps=1, wells_per_drc=1) pairs = [(r.canonical_smiles, r.fidelity) for r in upgrade_oracle.labeled_records] assert len(pairs) == len(set(pairs)) def test_cost_includes_both_ps_and_drc(self, upgrade_loop, upgrade_oracle): """Total cost must reflect both PS and DRC assays when upgrades occur.""" - upgrade_loop.run(n_iterations=6, k_per_iteration=2) + upgrade_loop.run(n_iterations=6, plate_size=2, wells_per_ps=1, wells_per_drc=1) manual_cost = sum(r.cost for r in upgrade_oracle.labeled_records) assert upgrade_oracle.total_cost == pytest.approx(manual_cost) @@ -508,7 +575,7 @@ def test_ps_labeled_pool_shrinks_as_upgrades_happen(self, upgrade_loop, upgrade_ as all INTERVAL-censored hits are upgraded to DRC. """ # Run enough iterations to exhaust the whole pool - upgrade_loop.run(n_iterations=10, k_per_iteration=2) + upgrade_loop.run(n_iterations=10, plate_size=2, wells_per_ps=1, wells_per_drc=1) # All compounds labeled; none remain eligible for PS→DRC upgrade assert upgrade_oracle.get_ps_labeled_smiles() == [] @@ -551,7 +618,9 @@ class TestNoisyOracleErrorRamp: """Verify that the per-iteration noise ramp is correctly computed and dispatched.""" N_ITER = 4 - K = 2 + PLATE_SIZE = 2 + WELLS_PS = 1 + WELLS_DRC = 1 def _make_loop(self, oracle, initial_error, final_error): model = NoisyOracleModel(oracle._ground_truth, seed=0) @@ -591,7 +660,12 @@ def _spy(smiles_list, noise_scale, batch_size=256): return real_predict(smiles_list, noise_scale, batch_size) with patch.object(model, "predict_smiles", side_effect=_spy): - loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K) + loop.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) # The first call is the pre-loop seed; calls 1..N_ITER are the per-iteration # Step 3 selections. Slice off the seed call and check the iteration calls. @@ -618,7 +692,12 @@ def _spy(smiles_list, noise_scale, batch_size=256): return real_predict(smiles_list, noise_scale, batch_size) with patch.object(model, "predict_smiles", side_effect=_spy): - loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K) + loop.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) assert len(captured) >= self.N_ITER + 1 for i, ns in enumerate(captured): @@ -638,7 +717,12 @@ def _spy(smiles_list, noise_scale, batch_size=256): return real_predict(smiles_list, noise_scale, batch_size) with patch.object(model, "predict_smiles", side_effect=_spy): - loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K) + loop.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) assert captured, "predict_smiles was never called" assert captured[0] == pytest.approx(initial, abs=1e-7), ( @@ -650,7 +734,9 @@ class TestPretrainRecords: """Tests for ActiveLearningLoop behaviour when pretrain_records are provided.""" N_ITER = 2 - K = 3 + PLATE_SIZE = 3 + WELLS_PS = 1 + WELLS_DRC = 1 @pytest.fixture def pretrain_loop(self, oracle, mock_model, acquisition, evaluator): @@ -702,7 +788,12 @@ def pretrain_loop(self, oracle, mock_model, acquisition, evaluator): def test_pretrain_records_included_in_refit_call(self, pretrain_loop, mock_model): """model.refit must be called with the exact pretrain SMILES in the records list.""" - pretrain_loop.run(n_iterations=self.N_ITER, k_per_iteration=self.K) + pretrain_loop.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) assert mock_model.refit.called # Every pretrain SMILES must appear among the records passed to the first refit first_call_records = mock_model.refit.call_args_list[0][1]["records"] @@ -743,8 +834,18 @@ def test_empty_pretrain_reproduces_no_pretrain_behaviour( loop_no_pretrain.model = m1 loop_empty.model = m2 - loop_no_pretrain.run(n_iterations=self.N_ITER, k_per_iteration=self.K) - loop_empty.run(n_iterations=self.N_ITER, k_per_iteration=self.K) + loop_no_pretrain.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) + loop_empty.run( + n_iterations=self.N_ITER, + plate_size=self.PLATE_SIZE, + wells_per_ps=self.WELLS_PS, + wells_per_drc=self.WELLS_DRC, + ) assert m1.refit.call_count == m2.refit.call_count