diff --git a/profold2/data/dataset.py b/profold2/data/dataset.py index 4831c02d..2ce16976 100644 --- a/profold2/data/dataset.py +++ b/profold2/data/dataset.py @@ -2021,6 +2021,7 @@ def _split_args(args): if 'weights' in kwargs: weights = kwargs.pop('weights') if weights: + assert len(dataset) == len(weights) kwargs['sampler'] = WeightedRandomSampler(weights, num_samples=len(weights)) if 'shuffle' in kwargs: kwargs.pop('shuffle')