Problem
When val_fraction > 0, the current train/val split is a random shuffle. For imbalanced activity distributions (e.g. ~68% inactive and ~32% active
at a pEC50 ≥ 5.0 threshold, with <2% highly potent), a random split can produce a validation set whose activity distribution differs markedly from the training
set's. Validation loss then becomes an unreliable signal for early stopping or hyperparameter selection.
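To make the mismatch measurable, here is a minimal sketch that scores a split by the total variation distance between train and validation class fractions, then applies it to repeated random splits. The helper names (`class_fractions`, `distribution_gap`) and the synthetic 68/30/2 % inactive/active/highly-potent mix are illustrative assumptions, not part of the codebase.

```python
import numpy as np


def class_fractions(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Fraction of records falling in each activity class."""
    return np.bincount(labels, minlength=n_classes) / len(labels)


def distribution_gap(labels: np.ndarray, train_idx: np.ndarray, val_idx: np.ndarray) -> float:
    """Total variation distance between the train and val class distributions (0 = identical)."""
    n_classes = int(labels.max()) + 1
    p = class_fractions(labels[train_idx], n_classes)
    q = class_fractions(labels[val_idx], n_classes)
    return 0.5 * float(np.abs(p - q).sum())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic labels (assumption): 0 = inactive, 1 = active, 2 = highly potent.
    labels = rng.choice(3, size=500, p=[0.68, 0.30, 0.02])
    n_val = int(0.1 * len(labels))
    gaps = []
    for _ in range(200):
        perm = rng.permutation(len(labels))  # one random (unstratified) split
        gaps.append(distribution_gap(labels, perm[n_val:], perm[:n_val]))
    print(f"random-split TV gap: median={np.median(gaps):.3f}, max={max(gaps):.3f}")
```

The spread of the gap across seeds is the instability the issue describes; a stratified split should keep it close to zero.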
Proposed Solution
In MixedFidelityDataModule (or equivalent in dataset.py), implement stratified splitting for DRC records:
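import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
# Stratify DRC records by pEC50 quartile. duplicates="drop" merges quartile bins whose
# edges coincide (e.g. many tied values), which would otherwise raise a ValueError.
drc_values = np.array([r.value for r in drc_records])
quartile_labels = pd.qcut(drc_values, q=4, labels=False, duplicates="drop")
splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_fraction, random_state=seed)
# Only the labels drive stratification, so a placeholder X of matching length is enough.
train_idx, val_idx = next(splitter.split(np.zeros(len(drc_values)), quartile_labels))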
PS records are split at the same val_fraction by plain random shuffling, since they carry no exact value to stratify on.
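A minimal sketch of the unstratified PS split, assuming PS records are addressed by integer index. `random_ps_split` is a hypothetical helper name; rounding the validation count up mirrors how scikit-learn's `StratifiedShuffleSplit` converts a float `test_size` into a sample count.

```python
import math

import numpy as np


def random_ps_split(n_ps: int, val_fraction: float, seed: int):
    """Hypothetical helper: plain random train/val split of PS record indices."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_ps)
    # Round up, matching StratifiedShuffleSplit's handling of a float test_size.
    n_val = math.ceil(val_fraction * n_ps)
    return np.sort(perm[n_val:]), np.sort(perm[:n_val])
```

For example, `random_ps_split(50, 0.2, seed=0)` yields 40 train and 10 validation indices; with `val_fraction=0.0` every record stays in training.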
New parameter: stratify_split: bool = True (on by default; behavior is unchanged when val_fraction=0.0, since no split is made).
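One way the flag could gate the DRC split is sketched below, assuming the DRC pEC50 values are available as an array. `split_drc_indices` is a hypothetical name; `ShuffleSplit` and `StratifiedShuffleSplit` are real scikit-learn splitters, chosen here so the legacy and stratified paths share the same `test_size` semantics.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit


def split_drc_indices(values, val_fraction: float, seed: int, stratify_split: bool = True):
    """Hypothetical helper showing how stratify_split could select the split strategy."""
    n = len(values)
    if val_fraction == 0.0:
        # No validation set requested: identical behavior regardless of stratify_split.
        return np.arange(n), np.array([], dtype=int)
    X = np.zeros(n)  # placeholder; only the labels matter for stratification
    if stratify_split:
        labels = pd.qcut(np.asarray(values, dtype=float), q=4, labels=False, duplicates="drop")
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_fraction, random_state=seed)
        return next(splitter.split(X, labels))
    # Legacy path: plain random shuffle split.
    splitter = ShuffleSplit(n_splits=1, test_size=val_fraction, random_state=seed)
    return next(splitter.split(X))
```

The early return for `val_fraction == 0.0` is needed because both splitters reject a zero `test_size`.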
Files
Notes
sklearn is likely already in the dependency tree transitively; if not, the quartile binning can be implemented with numpy.percentile and manual
index assignment.
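A NumPy-only fallback could look like the following sketch: quartile edges from `numpy.percentile`, bin labels from `numpy.digitize` (right-closed, like `pd.qcut`), then a shuffled per-bin slice for validation. `numpy_stratified_split` is a hypothetical name. Because rounding happens per bin, the total validation size can differ by a record or two from scikit-learn's.

```python
import numpy as np


def numpy_stratified_split(values, val_fraction: float, seed: int):
    """Quartile-stratified train/val split using only NumPy (no scikit-learn)."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Quartile edges; right=True puts x in bin i when edges[i-1] < x <= edges[i] (labels 0..3).
    edges = np.percentile(values, [25, 50, 75])
    labels = np.digitize(values, edges, right=True)
    train_parts, val_parts = [], []
    # np.unique skips bins left empty by tied edges (e.g. many identical values).
    for label in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == label))
        n_val = round(val_fraction * len(idx))  # per-bin share of validation records
        val_parts.append(idx[:n_val])
        train_parts.append(idx[n_val:])
    return np.sort(np.concatenate(train_parts)), np.sort(np.concatenate(val_parts))
```

For 100 distinct values and `val_fraction=0.2`, each quartile of 25 records contributes 5 validation records.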