Weighted batch sampler to balance DRC and PS records per batch #23

@smcolby

Problem

In typical usage, PS records outnumber DRC records roughly 2:1 (e.g. 8,000 PS vs 4,000 DRC). With random shuffling, the expected fraction of DRC records per batch is ~33%, meaning most batches are dominated by low-fidelity PS labels. This slows convergence of the high-fidelity DRC objective and may cause the model to over-optimize for the censored PS landscape during early training.

Proposed Solution

Add a WeightedRandomSampler in MixedFidelityDataModule that oversamples DRC records so that the expected DRC fraction of each batch matches a configurable target (the ratio holds in expectation, not exactly in every batch):

from torch.utils.data import WeightedRandomSampler

def _make_sampler(records, target_drc_fraction=0.5):
    # Count records of each fidelity (QueryType is defined elsewhere in the project).
    n_drc = sum(1 for r in records if r.fidelity == QueryType.DOSE_RESPONSE)
    n_ps  = len(records) - n_drc
    # Each class gets a fixed share of the sampling mass, split evenly across its records.
    w_drc = target_drc_fraction / n_drc if n_drc > 0 else 0.0
    w_ps  = (1 - target_drc_fraction) / n_ps if n_ps > 0 else 0.0
    weights = [
        w_drc if r.fidelity == QueryType.DOSE_RESPONSE else w_ps
        for r in records
    ]
    # replacement=True lets a DRC record be drawn more than once per epoch.
    return WeightedRandomSampler(weights, num_samples=len(records), replacement=True)

New datamodule_kwargs option: drc_batch_fraction: float = 0.5
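
As a quick sanity check on the weighting, the sketch below reproduces the per-class weight arithmetic using only the standard library (no torch), with the 8,000 PS / 4,000 DRC counts from the Problem section. The helper name class_weights and the use of random.choices as a stand-in for WeightedRandomSampler are assumptions made for illustration; neither is part of moal.

```python
import random

def class_weights(n_drc, n_ps, target_drc_fraction=0.5):
    """Per-record sampling weights, mirroring the arithmetic in _make_sampler."""
    w_drc = target_drc_fraction / n_drc if n_drc > 0 else 0.0
    w_ps = (1 - target_drc_fraction) / n_ps if n_ps > 0 else 0.0
    return w_drc, w_ps

# Example counts from the Problem section (hypothetical dataset).
n_drc, n_ps = 4000, 8000
w_drc, w_ps = class_weights(n_drc, n_ps)

# Share of the total sampling mass that falls on DRC records.
drc_mass = n_drc * w_drc / (n_drc * w_drc + n_ps * w_ps)

# Simulate one "epoch" of weighted draws with replacement.
rng = random.Random(0)
labels = ["DRC"] * n_drc + ["PS"] * n_ps
weights = [w_drc] * n_drc + [w_ps] * n_ps
draws = rng.choices(labels, weights=weights, k=len(labels))
drc_fraction = draws.count("DRC") / len(draws)

print(f"DRC share of sampling mass: {drc_mass:.3f}")
print(f"Empirical DRC fraction over one epoch: {drc_fraction:.3f}")
```

The DRC share of the sampling mass equals the target by construction, while the empirical fraction in any finite set of draws fluctuates around it.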

Files

  • moal/dataset.py

Notes

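replacement=True is required for oversampling: without replacement, each record could be drawn at most once per epoch, so no record could be oversampled. With replacement, each epoch is biased so that DRC records are seen more frequently, complementing the w_drc > w_ps loss weighting (these are the loss weights, not the sampler weights of the same name in the snippet above).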
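
To make the "biased epoch" concrete, the following back-of-the-envelope sketch computes how often each record is expected to be drawn when num_samples equals the dataset size, again assuming the 8,000 PS / 4,000 DRC example and a 0.5 target. The variable names are illustrative only. Because the weights here sum to 1, each weight is also the per-draw probability of picking that record.

```python
# Example counts from the Problem section (hypothetical dataset).
n_drc, n_ps = 4000, 8000
target = 0.5
num_samples = n_drc + n_ps  # one epoch draws len(records) samples

# Per-record draw probabilities (the weights sum to 1 in this setup).
w_drc = target / n_drc
w_ps = (1 - target) / n_ps

# Expected number of times a single record is drawn in one epoch.
expected_drc = num_samples * w_drc  # about 1.5
expected_ps = num_samples * w_ps    # about 0.75

# Probability that a given record is never drawn in one epoch,
# since draws with replacement are independent.
p_unseen_drc = (1 - w_drc) ** num_samples
p_unseen_ps = (1 - w_ps) ** num_samples

print(f"Expected draws per DRC record: {expected_drc:.2f}")
print(f"Expected draws per PS record:  {expected_ps:.2f}")
print(f"P(DRC record unseen in an epoch): {p_unseen_drc:.3f}")
print(f"P(PS record unseen in an epoch):  {p_unseen_ps:.3f}")
```

Under these assumptions a sizeable share of PS records is skipped in any single epoch, so coverage of the PS set is reached only over several epochs; raising num_samples would be one way to trade epoch length for coverage.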
