Problem
In typical usage, PS records outnumber DRC records roughly 2:1 (e.g. 8,000 PS vs 4,000 DRC). With random shuffling, the expected fraction of DRC records per batch is ~33%, meaning most batches are dominated by low-fidelity PS labels. This slows convergence of the high-fidelity DRC objective and may cause the model to over-optimize for the censored PS landscape during early training.
Proposed Solution
Add a WeightedRandomSampler in MixedFidelityDataModule that oversamples DRC records to achieve a configurable per-batch DRC:PS ratio:
from torch.utils.data import WeightedRandomSampler

def _make_sampler(records, target_drc_fraction=0.5):
    # Per-record weights are chosen so that each class's total weight equals its target fraction.
    n_drc = sum(1 for r in records if r.fidelity == QueryType.DOSE_RESPONSE)
    n_ps = len(records) - n_drc
    w_drc = target_drc_fraction / n_drc if n_drc > 0 else 0.0
    w_ps = (1 - target_drc_fraction) / n_ps if n_ps > 0 else 0.0
    weights = [
        w_drc if r.fidelity == QueryType.DOSE_RESPONSE else w_ps
        for r in records
    ]
    # replacement=True lets the minority DRC records be drawn more than once per epoch.
    return WeightedRandomSampler(weights, num_samples=len(records), replacement=True)
New datamodule_kwargs option: drc_batch_fraction: float = 0.5
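As a quick sanity check of the weighting scheme, the sketch below reproduces the weight computation on plain string labels and draws one epoch's worth of indices with the standard library. It is an illustration only: `class_balanced_weights` and the `"PS"`/`"DRC"` labels are hypothetical stand-ins for `_make_sampler` and `QueryType`, and `random.choices` stands in for `WeightedRandomSampler` (both draw with replacement in proportion to the weights).

```python
import random
from collections import Counter


def class_balanced_weights(labels, target_drc_fraction=0.5):
    """Per-record weights so the DRC class carries `target_drc_fraction` of the total weight."""
    n_drc = sum(1 for lab in labels if lab == "DRC")
    n_ps = len(labels) - n_drc
    w_drc = target_drc_fraction / n_drc if n_drc > 0 else 0.0
    w_ps = (1 - target_drc_fraction) / n_ps if n_ps > 0 else 0.0
    return [w_drc if lab == "DRC" else w_ps for lab in labels]


# Same 2:1 split as the example in the Problem section.
labels = ["PS"] * 8000 + ["DRC"] * 4000
weights = class_balanced_weights(labels, target_drc_fraction=0.5)

# Total weight assigned to the DRC class; should match the target fraction.
drc_mass = sum(w for w, lab in zip(weights, labels) if lab == "DRC")

# Simulate one epoch of weighted sampling with replacement.
rng = random.Random(0)
draws = rng.choices(labels, weights=weights, k=len(labels))
drc_share = Counter(draws)["DRC"] / len(draws)

print(f"DRC weight mass: {drc_mass:.3f}, empirical DRC share: {drc_share:.3f}")
```

The target fraction holds in expectation over draws, so individual batches will still fluctuate around it.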
Files
Notes
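replacement=True is required for oversampling: without it, no record can be drawn more than once per epoch. With num_samples=len(records) the epoch keeps its nominal length but becomes biased, so DRC records are seen more frequently than PS records and the target DRC fraction holds in expectation rather than exactly in every batch. This complements the w_drc > w_ps loss weighting (the loss weights, not the sampler weights of the same name in _make_sampler).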
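To make the "biased epoch" concrete, the following sketch computes, for the 8,000 PS / 4,000 DRC example, how often a record of each class is expected to be drawn and how likely it is to be skipped entirely in one epoch. It assumes independent draws with replacement and num_samples=len(records), as in the proposal; `epoch_coverage` is a hypothetical helper, not part of the codebase.

```python
import math


def epoch_coverage(n_class, class_fraction, num_samples):
    """Return (expected draws per record, probability a record is never drawn) for one epoch."""
    w = class_fraction / n_class           # per-draw probability of one specific record
    expected_draws = num_samples * w
    p_unseen = (1.0 - w) ** num_samples    # all draws miss this record
    return expected_draws, p_unseen


n_ps, n_drc = 8000, 4000
num_samples = n_ps + n_drc

drc_draws, drc_unseen = epoch_coverage(n_drc, 0.5, num_samples)
ps_draws, ps_unseen = epoch_coverage(n_ps, 0.5, num_samples)

print(f"DRC: {drc_draws:.2f} draws/record, {drc_unseen:.1%} unseen per epoch")
print(f"PS:  {ps_draws:.2f} draws/record, {ps_unseen:.1%} unseen per epoch")
```

Under these assumptions each DRC record is drawn 1.5 times per epoch on average and each PS record 0.75 times, and a sizeable share of both classes is skipped in any single epoch. Coverage of the full dataset therefore only evens out over several epochs, which may matter when comparing per-epoch metrics against runs that use plain shuffling.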