diff --git a/distrax/_src/bijectors/bijector.py b/distrax/_src/bijectors/bijector.py index 05b76a2c..57e1230d 100644 --- a/distrax/_src/bijectors/bijector.py +++ b/distrax/_src/bijectors/bijector.py @@ -94,6 +94,7 @@ def __init__(self, Only set to True if you're absoltely sure the Jacobian determinant is constant; if you're not sure, set to None. """ + if event_ndims_out is None: event_ndims_out = event_ndims_in if event_ndims_in < 0: diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index f7d9fd6c..6f0b15d6 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -125,6 +125,7 @@ def sample(self, Returns: A sample of shape `sample_shape` + `batch_shape` + `event_shape`. """ + rng, sample_shape = convert_seed_and_sample_shape(seed, sample_shape) num_samples = functools.reduce(operator.mul, sample_shape, 1) # product