Skip to content

Commit 90c46ad

Browse files
smcolbyCopilot
andcommitted
feat: add from_foundation flag to ChemPropLightningModule
Adds a from_foundation: str | bool parameter (default 'chemeleon') that controls how the ChemProp message-passing encoder is initialised: - 'chemeleon': downloads CheMeleon weights from Zenodo (existing behavior) - '/path/to/weights.pt': loads a local checkpoint in the same {hyper_parameters, state_dict} format as CheMeleon - False: builds BondMessagePassing() with default ChemProp architecture and random weights; no checkpoint required Validation is performed at construction time via _validate_from_foundation() which checks against _KNOWN_FOUNDATION_MODELS and Path.exists(). Unknown names and non-existent paths raise ValueError with a helpful message. Changes: - moal/config.py: add from_foundation field to ModelConfig - moal/model.py: _KNOWN_FOUNDATION_MODELS, _validate_from_foundation, updated __init__ / _build_model dispatch, _load_foundation_weights (replaces _get_chemeleon_mp) - moal/cli.py: forward from_foundation in both model builders - examples/default_config.yaml: document all three modes - tests/test_model.py: update hparam assertions; add TestFromFoundation class (7 new tests) - README.md: generalize encoder description; add Foundation Model design note - .github/copilot-instructions.md: update ModelConfig table; add Foundation model section; retire outdated _CHEMPELEON_ATOM_FDIM note Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 822677d commit 90c46ad

7 files changed

Lines changed: 208 additions & 28 deletions

File tree

.github/copilot-instructions.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ All campaign parameters live in `moal/config.py` as frozen dataclasses. The YAML
151151
| YAML key | Dataclass | Notable fields |
152152
|---|---|---|
153153
| `oracle:` | `OracleConfig` | `cost_ps`, `cost_drc`, `ps_threshold`, `upper_bound`, `activity_threshold` |
154-
| `model:` | `ModelConfig` | `hidden_size`, `depth`, `ffn_hidden_size`, `ffn_num_layers`, `freeze_epochs`, `lr_encoder`, `lr_head`, `sigma`, `w_drc`, `w_ps`, `learnable_sigma`, `reset_weights_on_refit`, **`fast`**, **`initial_error`**, **`final_error`** |
154+
| `model:` | `ModelConfig` | `hidden_size`, `depth`, `ffn_hidden_size`, `ffn_num_layers`, `freeze_epochs`, `lr_encoder`, `lr_head`, `sigma`, `w_drc`, `w_ps`, `learnable_sigma`, `reset_weights_on_refit`, **`fast`**, **`initial_error`**, **`final_error`**, **`from_foundation`** |
155155
| `acquisition:` | `AcquisitionConfig` | `ps_threshold`, `target_threshold`, **`tau`** |
156156
| `trainer:` | `TrainerConfig` | `max_epochs`, `accelerator`, `enable_progress_bar`, `enable_model_summary`, `val_fraction`, `split_seed`, `num_workers`, `log_every_n_steps` |
157157
| `dashboard:` | `DashboardConfig` | `enabled`, `model_metric`, `port`, `export_width`, `export_height`, `theme` |
@@ -213,7 +213,7 @@ All modules use `logger = logging.getLogger(__name__)`. The `suppress_noisy_logg
213213

214214
### Freeze/unfreeze schedule
215215

216-
`ChemPropLightningModule` freezes the CheMeleon encoder for the first `freeze_epochs` training epochs, then unfreezes and adds a second optimizer for the encoder at `lr_encoder`. The epoch counter resets on every `trainer.fit()` call (every AL iteration). This is intentional — early iterations have tiny labeled pools where encoder fine-tuning would overfit.
216+
`ChemPropLightningModule` freezes the message-passing encoder for the first `freeze_epochs` training epochs, then unfreezes and adds a second optimizer for the encoder at `lr_encoder`. The epoch counter resets on every `trainer.fit()` call (every AL iteration). This is intentional — early iterations have tiny labeled pools where encoder fine-tuning would overfit.
217217

218218
### Scaffold split
219219

@@ -227,6 +227,14 @@ All modules use `logger = logging.getLogger(__name__)`. The `suppress_noisy_logg
227227

228228
`CliRunner.invoke()` does **not** sandbox file I/O by default. Any test that triggers CLI output-directory creation must pass `--output-dir str(tmp_path / "out")` (or use `runner.isolated_filesystem()`) to avoid leaking `results/` into the pytest CWD.
229229

230-
### CheMeleon feature dimensions
230+
### Foundation model (`from_foundation`)
231231

232-
`_CHEMPELEON_ATOM_FDIM = 72` and `_CHEMPELEON_BOND_FDIM = 14` in `model.py` are hardcoded to match the CheMeleon pretraining feature spec. These are verified at model initialization. Do not change them without updating the checkpoint.
232+
`ChemPropLightningModule` accepts a `from_foundation: str | bool` constructor parameter (default `"chemeleon"`) that controls encoder initialisation:
233+
234+
- `"chemeleon"` — downloads the CheMeleon checkpoint from Zenodo, caches it at `~/.chemprop/chemeleon_mp.pt`, and loads its weights.
235+
- Any other string — treated as a filesystem path; the checkpoint must have `{"hyper_parameters": ..., "state_dict": ...}` format (same as CheMeleon).
236+
- `False` — builds `BondMessagePassing()` with default ChemProp architecture and random weights; no checkpoint is required.
237+
238+
The known-name registry lives in `_KNOWN_FOUNDATION_MODELS: frozenset[str]` at module level. `_validate_from_foundation(value)` is called at `__init__` time; it raises `ValueError` for unknown names and non-existent paths. `from_foundation` is included in `save_hyperparameters()`.
239+
240+
`True` is not a valid value — the validator catches it because `True` is not in `_KNOWN_FOUNDATION_MODELS` and `Path(True)` is not a valid path expression.

README.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ A Python pipeline for maximizing the discovery of **active compounds** (pEC50 >
77
- **Primary Screen (PS):** Returns an inequality label (`< T` or `>= T`) at a configurable threshold. Cheap. A hit (`>= T`) is an INTERVAL-censored label — eligible for a DRC upgrade in a later iteration.
88
- **Dose-Response Curve (DRC):** Returns the exact continuous pEC50 value. Expensive. Can be run as a first-pass query *or* as a follow-up upgrade on a PS hit.
99

10-
The underlying predictive model is **ChemProp** initialized with **CheMeleon** pretrained weights, trained with a **Tobit (censored regression) loss** that correctly handles both label types.
10+
The underlying predictive model is **ChemProp** fine-tuned with a **Tobit (censored regression) loss** that correctly handles both label types. By default the ChemProp encoder is initialised with **CheMeleon** pretrained weights (see `model.from_foundation` below).
1111

1212
## Installation
1313

@@ -49,7 +49,8 @@ data:
4949
```
5050
5151
CheMeleon pretrained weights are downloaded automatically from Zenodo on first
52-
run and cached at `~/.chemprop/chemeleon_mp.pt` for subsequent use.
52+
run and cached at `~/.chemprop/chemeleon_mp.pt` for subsequent use (only when
53+
`model.from_foundation: chemeleon`, the default).
5354

5455
## Commands
5556

@@ -167,6 +168,14 @@ The campaign emits a rich progress bar with `n_iterations × 3` discrete steps:
167168

168169
## Key Design Notes
169170

171+
**Foundation model (`model.from_foundation`):** Controls which weights initialise the ChemProp message-passing encoder. Three values are accepted:
172+
173+
- `"chemeleon"` (default) — downloads the CheMeleon checkpoint from Zenodo and loads it.
174+
- A filesystem path string — loads a local checkpoint in the same `{hyper_parameters, state_dict}` format as CheMeleon.
175+
- `false` — builds the encoder with default ChemProp architecture and random weights; no checkpoint required. Useful for ablation studies or environments without network access.
176+
177+
Unknown named strings and non-existent paths raise `ValueError` at model construction. The `from_foundation` value is recorded in Lightning checkpoints alongside all other hyperparameters.
178+
170179
**Unified input format:** All three CSV inputs — `data.simulate.input_csv`, `data.simulate.pretrain.input_csv`, and `data.plan.input_csv` — use the same campaign state schema (`smiles`, `relation`, `value`). For `moal simulate`, only `==` rows are loaded as oracle ground truth; PS and blank rows are skipped. For `moal plan` and the pretrain input, all labeled rows (`<`, `>=`, `==`) become training records; unqueried rows (empty) are inference targets or skipped with a warning, respectively.
171180

172181
**Pretrain warm-starting:** `moal simulate` accepts a pretrain CSV (`data.simulate.pretrain.input_csv`) in the same mixed-fidelity format. Pretrain records are merged with oracle-acquired records before each `model.refit()` call. Oracle records always win on a same-fidelity duplicate; pretrain PS INTERVAL records are automatically dropped when the oracle upgrades that compound to DRC. See `data.simulate.pretrain.*` in the config reference for all fields.

examples/default_config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ model:
104104
initial_error: 1.2 # Fast mode: Starting noise magnitude (pEC50 units) for the error ramp in fast mode
105105
final_error: 0.65 # Fast mode: Ending noise magnitude for the ramp
106106
# Set equal to initial_error for constant noise
107+
from_foundation: chemeleon # Encoder initialisation: "chemeleon" (download CheMeleon from Zenodo),
108+
# a local path to a {hyper_parameters, state_dict} checkpoint file,
109+
# or false for random ChemProp weights (no checkpoint required)
107110

108111
# --------------------------------------------------------------------------
109112
# PyTorch Lightning trainer

moal/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ def _build_simulation_model(
710710
w_drc=cfg.model.w_drc,
711711
w_ps=cfg.model.w_ps,
712712
learnable_sigma=cfg.model.learnable_sigma,
713+
from_foundation=cfg.model.from_foundation,
713714
)
714715

715716

@@ -748,6 +749,7 @@ def _build_plan_model(cfg: PipelineConfig) -> ChemPropLightningModule:
748749
w_drc=cfg.model.w_drc,
749750
w_ps=cfg.model.w_ps,
750751
learnable_sigma=cfg.model.learnable_sigma,
752+
from_foundation=cfg.model.from_foundation,
751753
)
752754

753755

moal/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ class ModelConfig:
8282
mode. The ramp linearly interpolates from initial_error to final_error
8383
over all iterations. Set equal to initial_error for a constant noise
8484
level.
85+
from_foundation : str or bool
86+
Controls encoder initialisation. ``"chemeleon"`` (default) downloads
87+
and loads CheMeleon pretrained weights. A filesystem path string loads
88+
a local checkpoint in the same ``{hyper_parameters, state_dict}``
89+
format. ``False`` builds the encoder with default ChemProp architecture
90+
and random weights (no checkpoint required).
8591
"""
8692

8793
ffn_hidden_size: int = 300
@@ -97,6 +103,7 @@ class ModelConfig:
97103
fast: bool = False
98104
initial_error: float = 0.7
99105
final_error: float = 0.5
106+
from_foundation: str | bool = "chemeleon"
100107

101108

102109
@dataclass(frozen=True)

moal/model.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,41 @@
3434

3535
logger = logging.getLogger(__name__)
3636

37+
_KNOWN_FOUNDATION_MODELS: frozenset[str] = frozenset({"chemeleon"})
38+
39+
40+
def _validate_from_foundation(value: str | bool) -> None:
41+
"""Validate the ``from_foundation`` parameter value.
42+
43+
Parameters
44+
----------
45+
value : str or bool
46+
The value to validate.
47+
48+
Raises
49+
------
50+
ValueError
51+
If ``value`` is not ``False``, a known named model, or an existing
52+
file path.
53+
"""
54+
if value is False:
55+
return
56+
if not isinstance(value, str):
57+
raise ValueError(
58+
f"from_foundation must be False, a known model name, or a filesystem path; "
59+
f"got {value!r}. Known names: {sorted(_KNOWN_FOUNDATION_MODELS)}"
60+
)
61+
if value in _KNOWN_FOUNDATION_MODELS:
62+
return
63+
if Path(value).exists():
64+
return
65+
raise ValueError(
66+
f"from_foundation={value!r} is not a recognised foundation model name "
67+
f"and does not resolve to an existing file path. "
68+
f"Known names: {sorted(_KNOWN_FOUNDATION_MODELS)}. "
69+
"Pass False to use random ChemProp weights."
70+
)
71+
3772

3873
def download_chemeleon() -> None:
3974
"""Download the CheMeleon checkpoint if not already cached locally.
@@ -67,7 +102,7 @@ def download_chemeleon() -> None:
67102

68103

69104
class ChemPropLightningModule(L.LightningModule):
70-
"""ChemProp MPNN fine-tuned from CheMeleon pretrained weights.
105+
"""ChemProp MPNN with configurable foundation-model encoder initialisation.
71106
72107
Parameters
73108
----------
@@ -91,6 +126,12 @@ class ChemPropLightningModule(L.LightningModule):
91126
Primary screen loss weight. Default is 0.3.
92127
learnable_sigma : bool, optional
93128
If True, σ is a learned parameter. Default is False.
129+
from_foundation : str or bool, optional
130+
Controls encoder initialisation. ``"chemeleon"`` (default) downloads
131+
and loads CheMeleon pretrained weights. A filesystem path string loads
132+
a local checkpoint in ``{hyper_parameters, state_dict}`` format.
133+
``False`` builds the encoder with default ChemProp architecture and
134+
random weights.
94135
"""
95136

96137
def __init__(
@@ -104,8 +145,11 @@ def __init__(
104145
w_drc: float = 1.0,
105146
w_ps: float = 0.3,
106147
learnable_sigma: bool = False,
148+
from_foundation: str | bool = "chemeleon",
107149
) -> None:
108150
super().__init__()
151+
_validate_from_foundation(from_foundation)
152+
self._from_foundation = from_foundation
109153
self.save_hyperparameters()
110154

111155
self.freeze_epochs = freeze_epochs
@@ -132,7 +176,7 @@ def _build_model(
132176
ffn_hidden_size: int,
133177
ffn_num_layers: int,
134178
) -> nn.Module:
135-
"""Construct the MPNN with CheMeleon message-passing weights.
179+
"""Construct the MPNN, dispatching on ``self._from_foundation``.
136180
137181
Parameters
138182
----------
@@ -144,42 +188,43 @@ def _build_model(
144188
Returns
145189
-------
146190
nn.Module
147-
Fully assembled ``chemprop.models.MPNN`` with pretrained
148-
message-passing weights and a freshly initialised FFN head.
191+
Fully assembled ``chemprop.models.MPNN``.
149192
"""
150-
chemeleon_weights = self._get_chemeleon_mp()
193+
if self._from_foundation is False:
194+
logger.info("Building ChemProp encoder with random weights (from_foundation=False).")
195+
mp: nn.Module = BondMessagePassing()
196+
else:
197+
foundation_weights = self._load_foundation_weights()
198+
mp = BondMessagePassing(**foundation_weights["hyper_parameters"])
199+
mp.load_state_dict(foundation_weights["state_dict"])
151200

152-
# Mean aggregation
153201
agg = MeanAggregation()
154-
155-
# Message passing
156-
mp = BondMessagePassing(**chemeleon_weights["hyper_parameters"])
157-
mp.load_state_dict(chemeleon_weights["state_dict"])
158-
159-
# FFN predictor head
160202
ffn = RegressionFFN(
161-
input_dim=mp.output_dim, # Infer input dim from mp output
203+
input_dim=cast(BondMessagePassing, mp).output_dim,
162204
hidden_dim=ffn_hidden_size,
163205
n_layers=ffn_num_layers,
164206
)
165207
return cast(nn.Module, MPNN(message_passing=mp, agg=agg, predictor=ffn))
166208

167-
def _get_chemeleon_mp(self) -> dict:
168-
"""Load and return the CheMeleon pretrained message-passing weights.
209+
def _load_foundation_weights(self) -> dict:
210+
"""Load pretrained message-passing weights from a named model or local path.
169211
170-
Calls :func:`download_chemeleon` to ensure the checkpoint exists at
171-
``~/.chemprop/chemeleon_mp.pt``, then loads it with
172-
``weights_only=True``.
212+
When ``self._from_foundation == "chemeleon"`` the checkpoint is
213+
downloaded from Zenodo if not already cached. For any other string
214+
value it is treated as a local filesystem path.
173215
174216
Returns
175217
-------
176218
dict
177219
Checkpoint dictionary with ``hyper_parameters`` and
178220
``state_dict`` keys.
179221
"""
180-
# Ensure the CheMeleon checkpoint is downloaded
181-
download_chemeleon()
182-
ckpt_path = Path().home() / ".chemprop" / "chemeleon_mp.pt"
222+
if self._from_foundation == "chemeleon":
223+
download_chemeleon()
224+
ckpt_path = Path().home() / ".chemprop" / "chemeleon_mp.pt"
225+
else:
226+
ckpt_path = Path(str(self._from_foundation))
227+
logger.info("Loading foundation weights from local path: %s", ckpt_path)
183228
return cast(dict[str, Any], torch.load(ckpt_path, weights_only=True))
184229

185230
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)