feat: add from_foundation flag to ChemPropLightningModule

Adds a from_foundation: str | bool parameter (default 'chemeleon') that
controls how the ChemProp message-passing encoder is initialised:
- 'chemeleon': downloads CheMeleon weights from Zenodo (existing behavior)
- '/path/to/weights.pt': loads a local checkpoint in the same
  {hyper_parameters, state_dict} format as CheMeleon
- False: builds BondMessagePassing() with default ChemProp architecture
  and random weights; no checkpoint required

Validation is performed at construction time via _validate_from_foundation()
which checks against _KNOWN_FOUNDATION_MODELS and Path.exists(). Unknown
names and non-existent paths raise ValueError with a helpful message.

Changes:
- moal/config.py: add from_foundation field to ModelConfig
- moal/model.py: _KNOWN_FOUNDATION_MODELS, _validate_from_foundation,
  updated __init__ / _build_model dispatch, _load_foundation_weights
  (replaces _get_chemeleon_mp)
- moal/cli.py: forward from_foundation in both model builders
- examples/default_config.yaml: document all three modes
- tests/test_model.py: update hparam assertions; add TestFromFoundation
  class (7 new tests)
- README.md: generalize encoder description; add Foundation Model design note
- .github/copilot-instructions.md: update ModelConfig table; add
  Foundation model section; retire outdated _CHEMPELEON_ATOM_FDIM note

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@@ -213,7 +213,7 @@ All modules use `logger = logging.getLogger(__name__)`. The `suppress_noisy_logg

 ### Freeze/unfreeze schedule

-`ChemPropLightningModule` freezes the CheMeleon encoder for the first `freeze_epochs` training epochs, then unfreezes and adds a second optimizer for the encoder at `lr_encoder`. The epoch counter resets on every `trainer.fit()` call (every AL iteration). This is intentional — early iterations have tiny labeled pools where encoder fine-tuning would overfit.
+`ChemPropLightningModule` freezes the message-passing encoder for the first `freeze_epochs` training epochs, then unfreezes and adds a second optimizer for the encoder at `lr_encoder`. The epoch counter resets on every `trainer.fit()` call (every AL iteration). This is intentional — early iterations have tiny labeled pools where encoder fine-tuning would overfit.

 ### Scaffold split

@@ -227,6 +227,14 @@ All modules use `logger = logging.getLogger(__name__)`. The `suppress_noisy_logg

 `CliRunner.invoke()` does **not** sandbox file I/O by default. Any test that triggers CLI output-directory creation must pass `--output-dir str(tmp_path / "out")` (or use `runner.isolated_filesystem()`) to avoid leaking `results/` into the pytest CWD.

-### CheMeleon feature dimensions
+### Foundation model (`from_foundation`)

-`_CHEMPELEON_ATOM_FDIM = 72` and `_CHEMPELEON_BOND_FDIM = 14` in `model.py` are hardcoded to match the CheMeleon pretraining feature spec. These are verified at model initialization. Do not change them without updating the checkpoint.
+`ChemPropLightningModule` accepts a `from_foundation: str | bool` constructor parameter (default `"chemeleon"`) that controls encoder initialisation:
+
+- `"chemeleon"` — downloads the CheMeleon checkpoint from Zenodo, caches it at `~/.chemprop/chemeleon_mp.pt`, and loads its weights.
+- Any other string — treated as a filesystem path; the checkpoint must have `{"hyper_parameters": ..., "state_dict": ...}` format (same as CheMeleon).
+- `False` — builds `BondMessagePassing()` with default ChemProp architecture and random weights; no checkpoint is required.
+
+The known-name registry lives in `_KNOWN_FOUNDATION_MODELS: frozenset[str]` at module level. `_validate_from_foundation(value)` is called at `__init__` time; it raises `ValueError` for unknown names and non-existent paths. `from_foundation` is included in `save_hyperparameters()`.
+
+`True` is not a valid value — the validator catches it because `True` is not in `_KNOWN_FOUNDATION_MODELS` and `Path(True)` is not a valid path expression.
README.md: 11 additions & 2 deletions
@@ -7,7 +7,7 @@ A Python pipeline for maximizing the discovery of **active compounds** (pEC50 >
 - **Primary Screen (PS):** Returns an inequality label (`< T` or `>= T`) at a configurable threshold. Cheap. A hit (`>= T`) is an INTERVAL-censored label — eligible for a DRC upgrade in a later iteration.
 - **Dose-Response Curve (DRC):** Returns the exact continuous pEC50 value. Expensive. Can be run as a first-pass query *or* as a follow-up upgrade on a PS hit.

-The underlying predictive model is **ChemProp** initialized with **CheMeleon** pretrained weights, trained with a **Tobit (censored regression) loss** that correctly handles both label types.
+The underlying predictive model is **ChemProp** fine-tuned with a **Tobit (censored regression) loss** that correctly handles both label types. By default the ChemProp encoder is initialised with **CheMeleon** pretrained weights (see `model.from_foundation` below).

 ## Installation

@@ -49,7 +49,8 @@ data:
 ```

 CheMeleon pretrained weights are downloaded automatically from Zenodo on first
-run and cached at `~/.chemprop/chemeleon_mp.pt` for subsequent use.
+run and cached at `~/.chemprop/chemeleon_mp.pt` for subsequent use (only when
+`model.from_foundation: chemeleon`, the default).

 ## Commands

@@ -167,6 +168,14 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:

 ## Key Design Notes

+**Foundation model (`model.from_foundation`):** Controls which weights initialise the ChemProp message-passing encoder. Three values are accepted:
+
+- `"chemeleon"` (default) — downloads the CheMeleon checkpoint from Zenodo and loads it.
+- A filesystem path string — loads a local checkpoint in the same `{hyper_parameters, state_dict}` format as CheMeleon.
+- `false` — builds the encoder with default ChemProp architecture and random weights; no checkpoint required. Useful for ablation studies or environments without network access.
+
+Unknown named strings and non-existent paths raise `ValueError` at model construction. The `from_foundation` value is recorded in Lightning checkpoints alongside all other hyperparameters.
+
 **Unified input format:** All three CSV inputs — `data.simulate.input_csv`, `data.simulate.pretrain.input_csv`, and `data.plan.input_csv` — use the same campaign state schema (`smiles`, `relation`, `value`). For `moal simulate`, only `==` rows are loaded as oracle ground truth; PS and blank rows are skipped. For `moal plan` and the pretrain input, all labeled rows (`<`, `>=`, `==`) become training records; unqueried rows (empty) are inference targets or skipped with a warning, respectively.

 **Pretrain warm-starting:** `moal simulate` accepts a pretrain CSV (`data.simulate.pretrain.input_csv`) in the same mixed-fidelity format. Pretrain records are merged with oracle-acquired records before each `model.refit()` call. Oracle records always win on a same-fidelity duplicate; pretrain PS INTERVAL records are automatically dropped when the oracle upgrades that compound to DRC. See `data.simulate.pretrain.*` in the config reference for all fields.
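
For the local-checkpoint mode in the Foundation model design note above, a custom file has to carry the same two top-level keys as the CheMeleon checkpoint. Below is a minimal sketch of such a check, assuming the checkpoint has already been deserialised into a plain mapping. The helper name `check_foundation_checkpoint` and the constant `REQUIRED_KEYS` are hypothetical and not part of `moal`; the commit does not say whether the loader performs this check itself.

```python
from __future__ import annotations

from typing import Any, Mapping

# Top-level layout named in the README note: {hyper_parameters, state_dict}.
REQUIRED_KEYS = ("hyper_parameters", "state_dict")


def check_foundation_checkpoint(ckpt: Mapping[str, Any]) -> None:
    """Raise ValueError if `ckpt` lacks the CheMeleon-style top-level keys.

    Hypothetical helper for illustration only.
    """
    missing = [key for key in REQUIRED_KEYS if key not in ckpt]
    if missing:
        raise ValueError(
            "checkpoint is missing required key(s): " + ", ".join(missing)
        )


# Usage: a well-formed checkpoint passes silently.
check_foundation_checkpoint({"hyper_parameters": {}, "state_dict": {}})
```

Running a check like this before building the encoder would surface a malformed file as a clear error rather than a `KeyError` deeper in the loading code.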