From a2249b2c5d472697c6194a8d3b07a46097bcafbf Mon Sep 17 00:00:00 2001 From: Francisco Ruiz Date: Tue, 26 Apr 2022 01:20:58 -0700 Subject: [PATCH] Improve the tests in `bernoulli_test.py`. PiperOrigin-RevId: 444487876 --- distrax/_src/distributions/bernoulli_test.py | 351 +++++++------------ 1 file changed, 121 insertions(+), 230 deletions(-) diff --git a/distrax/_src/distributions/bernoulli_test.py b/distrax/_src/distributions/bernoulli_test.py index ac2d3025..09f15f07 100644 --- a/distrax/_src/distributions/bernoulli_test.py +++ b/distrax/_src/distributions/bernoulli_test.py @@ -20,7 +20,6 @@ import chex from distrax._src.distributions import bernoulli from distrax._src.utils import equivalence -import jax import jax.numpy as jnp import numpy as np from scipy import special as sp_special @@ -34,15 +33,24 @@ def setUp(self): self.p = np.asarray([0.2, 0.4, 0.6, 0.8]) self.logits = sp_special.logit(self.p) - def test_parameters_from_probs(self): - dist = self.distrax_cls(probs=self.p) - self.assertion_fn(rtol=1e-2)(dist.logits, self.logits) - self.assertion_fn(rtol=1e-2)(dist.probs, self.p) - - def test_parameters_from_logits(self): - dist = self.distrax_cls(logits=self.logits) - self.assertion_fn(rtol=1e-2)(dist.logits, self.logits) - self.assertion_fn(rtol=1e-2)(dist.probs, self.p) + @parameterized.named_parameters( + ('0d probs', (), True), + ('0d logits', (), False), + ('1d probs', (4,), True), + ('1d logits', (4,), False), + ('2d probs', (3, 4), True), + ('2d logits', (3, 4), False), + ) + def test_properties(self, shape, from_probs): + rng = np.random.default_rng(42) + probs = rng.uniform(size=shape) + logits = sp_special.logit(probs) + dist_kwargs = {'probs': probs} if from_probs else {'logits': logits} + dist = self.distrax_cls(**dist_kwargs) + self.assertion_fn(rtol=1e-3)(dist.logits, logits) + self.assertion_fn(rtol=1e-3)(dist.probs, probs) + self.assertEqual(dist.event_shape, ()) + self.assertEqual(dist.batch_shape, shape) @parameterized.named_parameters( ('probs and logits', {'logits': [0.1, -0.2], 'probs': [0.5, 0.4]}), @@ -54,41 +62,6 @@ def test_raises_on_invalid_inputs(self, dist_params): with self.assertRaises(ValueError): self.distrax_cls(**dist_params) - @chex.all_variants(with_pmap=False) - @parameterized.named_parameters( - ('from_logits', True), - ('from_probs', False)) - def test_log_probs_parameter(self, from_logits): - distr_params = {'logits': self.logits} if from_logits else {'probs': self.p} - dist = self.distrax_cls(**distr_params) - log_probs0, log_probs1 = self.variant(dist._log_probs_parameter)() - self.assertion_fn(rtol=1e-2)(log_probs1, np.log(self.p)) - self.assertion_fn(rtol=1e-2)(log_probs0, np.log(1 - self.p)) - - @chex.all_variants(with_pmap=False) - @parameterized.named_parameters( - ('from_logits', True), - ('from_probs', False)) - def test_probs_and_log_probs(self, from_logits): - distr_params = {'logits': self.logits} if from_logits else {'probs': self.p} - dist = self.distrax_cls(**distr_params) - probs0, probs1, log_probs0, log_probs1 = self.variant( - bernoulli._probs_and_log_probs)(dist) - self.assertion_fn(rtol=1e-2)(probs1, self.p) - self.assertion_fn(rtol=1e-2)(probs0, 1.0 - self.p) - self.assertion_fn(rtol=1e-2)(log_probs1, np.log(self.p)) - self.assertion_fn(rtol=1e-2)(log_probs0, np.log(1.0 - self.p)) - - @parameterized.named_parameters( - ('1d logits', {'logits': [0.0, 1.0, -0.5]}), - ('1d probs', {'probs': [0.1, 0.5, 0.3]}), - ('2d logits', {'logits': [[0.0, 1.0, -0.5], [-0.1, 0.3, 0.0]]}), - ('2d probs', {'probs': [[0.1, 0.4, 0.5], [0.5, 0.25, 0.25]]}), - ) - def test_event_shape(self, distr_params): - distr_params = {k: jnp.asarray(v) for k, v in distr_params.items()} - super()._test_event_shape((), distr_params) - @chex.all_variants @parameterized.named_parameters( ('1d logits, no shape', {'logits': [0.0, 1.0, -0.5]}, ()), @@ -123,6 +96,28 @@ def test_sample_shape(self, distr_params, sample_shape): dist_kwargs=distr_params, sample_shape=sample_shape) + @chex.all_variants + @parameterized.named_parameters( + ('sample, from probs', 'sample', True), + ('sample, from logits', 'sample', False), + ('sample_and_log_prob, from probs', 'sample_and_log_prob', True), + ('sample_and_log_prob, from logits', 'sample_and_log_prob', False), + ) + def test_sample_values(self, method, from_probs): + probs = np.array([0., 0.2, 0.5, 0.8, 1.]) # Includes edge cases (0 and 1). + logits = sp_special.logit(probs) + n_samples = 100000 + dist_kwargs = {'probs': probs} if from_probs else {'logits': logits} + dist = self.distrax_cls(**dist_kwargs) + sample_fn = self.variant( + lambda key: getattr(dist, method)(seed=key, sample_shape=n_samples)) + samples = sample_fn(self.key) + samples = samples[0] if method == 'sample_and_log_prob' else samples + self.assertEqual(samples.shape, (n_samples,) + probs.shape) + self.assertTrue(np.all(np.logical_or(samples == 0, samples == 1))) + self.assertion_fn(rtol=0.1)(np.mean(samples, axis=0), probs) + self.assertion_fn(rtol=0.1)(np.std(samples, axis=0), dist.stddev()) + @chex.all_variants @parameterized.named_parameters( ('1d logits, no shape', {'logits': [0.0, 1.0, -0.5]}, ()), @@ -160,201 +155,90 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): @chex.all_variants @parameterized.named_parameters( - ('bool', jnp.bool_), - ('uint32', jnp.uint32), - ('uint64', jnp.uint64), - ('int32', jnp.int32), - ('int64', jnp.int64), - ('float32', jnp.float32), - ('float64', jnp.float64)) - def test_sample_dtype(self, dtype): + ('sample, bool', 'sample', jnp.bool_), + ('sample, uint16', 'sample', jnp.uint16), + ('sample, uint32', 'sample', jnp.uint32), + ('sample, int16', 'sample', jnp.int16), + ('sample, int32', 'sample', jnp.int32), + ('sample, float16', 'sample', jnp.float16), + ('sample, float32', 'sample', jnp.float32), + ('sample_and_log_prob, bool', 'sample_and_log_prob', jnp.bool_), + ('sample_and_log_prob, uint16', 'sample_and_log_prob', jnp.uint16), + ('sample_and_log_prob, uint32', 'sample_and_log_prob', jnp.uint32), + ('sample_and_log_prob, int16', 'sample_and_log_prob', jnp.int16), + ('sample_and_log_prob, int32', 'sample_and_log_prob', jnp.int32), + ('sample_and_log_prob, float16', 'sample_and_log_prob', jnp.float16), + ('sample_and_log_prob, float32', 'sample_and_log_prob', jnp.float32), + ) + def test_sample_dtype(self, method, dtype): dist_params = {'logits': self.logits, 'dtype': dtype} dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) + samples = self.variant(getattr(dist, method))(seed=self.key) + samples = samples[0] if method == 'sample_and_log_prob' else samples self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) - - @chex.all_variants - @parameterized.named_parameters( - ('from probs', False), - ('from logits', True)) - def test_sample_unique_values(self, from_logits): - dist_params = {'logits': self.logits} if from_logits else {'probs': self.p} - dist = self.distrax_cls(**dist_params) - sample_fn = self.variant( - lambda key: dist.sample(seed=key, sample_shape=100)) - samples = sample_fn(self.key) - np.testing.assert_equal(np.unique(samples), np.asarray([0, 1])) - - @chex.all_variants - @parameterized.named_parameters( - ('zeros, float', 0.), - ('zeros, int', 0), - ('ones, float', 1.), - ('ones, int', 1)) - def test_sample_extreme_probs(self, p_extreme): - dist_params = {'probs': p_extreme} - dist = self.distrax_cls(**dist_params) - sample_fn = self.variant(lambda k: dist.sample(seed=k, sample_shape=100)) - samples = sample_fn(self.key) - np.testing.assert_equal(np.unique(samples), - np.asarray(p_extreme).astype(np.int32)) - - @chex.all_variants - @parameterized.named_parameters( - ('plus_inf', jnp.inf, 1), - ('minus_inf', -jnp.inf, 0)) - def test_sample_extreme_logits(self, l_extreme, expected): - dist_params = {'logits': l_extreme} - dist = self.distrax_cls(**dist_params) - sample_fn = self.variant(lambda k: dist.sample(seed=k, sample_shape=100)) - samples = sample_fn(self.key) - np.testing.assert_equal(np.unique(samples), np.array(expected)) + self.assertEqual(samples.dtype, dtype) @chex.all_variants @parameterized.named_parameters( - ('log_prob; 1d logits, int value', - 'log_prob', - {'logits': [0.0, 0.5, -0.5]}, - 1), - ('log_prob; 1d probs, int value', - 'log_prob', - {'probs': [0.3, 0.2, 0.5]}, - 1), - ('log_prob; 1d logits, 1d value', - 'log_prob', - {'logits': [0.0, 0.5, -0.5]}, - [1, 0, 1]), - ('log_prob; 1d probs, 1d value', - 'log_prob', - {'probs': [0.3, 0.2, 0.5]}, - [1, 0, 1]), - ('log_prob; 1d logits, 2d value', - 'log_prob', - {'logits': [0.0, 0.5, -0.5]}, - [[1, 0, 0], [0, 1, 0]]), - ('log_prob; 1d probs, 2d value', - 'log_prob', - {'probs': [0.3, 0.2, 0.5]}, - [[1, 0, 0], [0, 1, 0]]), - ('log_prob; 2d logits, 1d value', - 'log_prob', - {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, - [1, 0, 1]), - ('log_prob; 2d probs, 1d value', - 'log_prob', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, - [1, 0, 1]), - ('log_prob; 2d logits, 2d value', - 'log_prob', - {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, - [[1, 0, 0], [1, 1, 0]]), - ('log_prob; 2d probs, 2d value', - 'log_prob', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, - [[1, 0, 0], [1, 1, 0]]), - ('log_prob; extreme probs', - 'log_prob', - {'probs': [0.0, 0.0, 1.0, 1.0]}, - [0, 1, 0, 1]), - ('prob; 1d logits, int value', - 'prob', - {'logits': [0.0, 0.5, -0.5]}, - 1), - ('prob; 1d probs, int value', - 'prob', - {'probs': [0.3, 0.2, 0.5]}, - 1), - ('prob; 1d logits, 1d value', - 'prob', - {'logits': [0.0, 0.5, -0.5]}, - [1, 0, 1]), - ('prob; 1d probs, 1d value', - 'prob', - {'probs': [0.3, 0.2, 0.5]}, - [1, 0, 1]), - ('prob; 1d logits, 2d value', - 'prob', - {'logits': [0.0, 0.5, -0.5]}, + ('1d logits, int value', {'logits': [0.0, 0.5, -0.5]}, 1), + ('1d probs, int value', {'probs': [0.3, 0.2, 0.5]}, 1), + ('1d logits, 1d value', {'logits': [0.0, 0.5, -0.5]}, [1, 0, 1]), + ('1d probs, 1d value', {'probs': [0.3, 0.2, 0.5]}, [1, 0, 1]), + ('1d logits, 2d value', {'logits': [0.0, 0.5, -0.5]}, [[1, 0, 0], [0, 1, 0]]), - ('prob; 1d probs, 2d value', - 'prob', - {'probs': [0.3, 0.2, 0.5]}, + ('1d probs, 2d value', {'probs': [0.3, 0.2, 0.5]}, [[1, 0, 0], [0, 1, 0]]), - ('prob; 2d logits, 1d value', - 'prob', - {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, + ('2d logits, 1d value', {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, [1, 0, 1]), - ('prob; 2d probs, 1d value', - 'prob', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, + ('2d probs, 1d value', {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, [1, 0, 1]), - ('prob; 2d logits, 2d value', - 'prob', - {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, + ('2d logits, 2d value', {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, [[1, 0, 0], [1, 1, 0]]), - ('prob; 2d probs, 2d value', - 'prob', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, + ('2d probs, 2d value', {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, [[1, 0, 0], [1, 1, 0]]), - ('prob; extreme probs', - 'prob', - {'probs': [0.0, 0.0, 1.0, 1.0]}, + ('edge cases with logits', {'logits': [-np.inf, -np.inf, np.inf, np.inf]}, [0, 1, 0, 1]), - ('cdf; from 2d logits', - 'cdf', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, - [[1, 0, 0], [1, 1, 0]]), - ('cdf; from 2d probs', - 'cdf', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, - [[1, 0, 0], [1, 1, 0]]), - ('log_cdf; from 2d logits', - 'log_cdf', - {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}, - [[1, 0, 0], [1, 1, 0]]), - ('log_cdf; from 2d probs', - 'log_cdf', - {'probs': [[0.1, 0.5, 0.4], [0.3, 0.3, 0.4]]}, - [[1, 0, 0], [1, 1, 0]]), + ('edge cases with probs', {'probs': [0.0, 0.0, 1.0, 1.0]}, [0, 1, 0, 1]), ) - def test_pdf(self, function_string, distr_params, value): + def test_method_with_value(self, distr_params, value): distr_params = {k: jnp.asarray(v) for k, v in distr_params.items()} value = jnp.asarray(value) - super()._test_attribute( - attribute_string=function_string, - dist_kwargs=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn(rtol=1e-2)) + for method in ['prob', 'log_prob', 'cdf', 'log_cdf', + 'survival_function', 'log_survival_function']: + with self.subTest(method=method): + super()._test_attribute( + attribute_string=method, + dist_kwargs=distr_params, + call_args=(value,), + assertion_fn=self.assertion_fn(rtol=1e-2)) @chex.all_variants(with_pmap=False) @parameterized.named_parameters( - ('entropy; from 2d logits', - 'entropy', {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}), - ('entropy; from 2d probs', - 'entropy', {'probs': [[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]]}), - ('mode; from 2d logits', - 'mode', {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}), - ('mode; from 2d probs', - 'mode', {'probs': [[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]]}), + ('from logits', {'logits': [[0.0, 0.5, -0.5], [-0.2, 0.3, 0.5]]}), + ('from probs', {'probs': [[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]]}), ) - def test_method(self, function_string, distr_params): + def test_method(self, distr_params): distr_params = {k: jnp.asarray(v) for k, v in distr_params.items()} - super()._test_attribute( - attribute_string=function_string, - dist_kwargs=distr_params, - call_args=(), - assertion_fn=self.assertion_fn(rtol=1e-2)) + for method in ['entropy', 'mode', 'mean', 'variance', 'stddev']: + with self.subTest(method=method): + super()._test_attribute( + attribute_string=method, + dist_kwargs=distr_params, + call_args=(), + assertion_fn=self.assertion_fn(rtol=1e-2)) @chex.all_variants(with_pmap=False) @parameterized.named_parameters( - ('from_logits', True), - ('from_probs', False)) - def test_median(self, from_logits): - distr_params = {'logits': self.logits} if from_logits else {'probs': self.p} - dist = self.distrax_cls(**distr_params) - self.assertion_fn(rtol=1e-2)( - self.variant(dist.median)(), self.variant(dist.mean)()) + ('from probs', True), + ('from logits', False), + ) + def test_median(self, from_probs): + rng = np.random.default_rng(42) + probs = rng.uniform(size=(4, 5)) + logits = sp_special.logit(probs) + dist_kwargs = {'probs': probs} if from_probs else {'logits': logits} + dist = self.distrax_cls(**dist_kwargs) + self.assertion_fn(rtol=1e-3)(self.variant(dist.median)(), probs) @chex.all_variants(with_pmap=False) @parameterized.named_parameters( @@ -363,33 +247,40 @@ def test_median(self, from_logits): ('kl tfp_to_distrax', 'kl_divergence', 'tfp_to_distrax'), ('cross-ent distrax_to_distrax', 'cross_entropy', 'distrax_to_distrax'), ('cross-ent distrax_to_tfp', 'cross_entropy', 'distrax_to_tfp'), - ('cross-ent tfp_to_distrax', 'cross_entropy', 'tfp_to_distrax')) + ('cross-ent tfp_to_distrax', 'cross_entropy', 'tfp_to_distrax'), + ) def test_with_two_distributions(self, function_string, mode_string): super()._test_with_two_distributions( attribute_string=function_string, mode_string=mode_string, - dist1_kwargs={'probs': jnp.asarray([[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]])}, + dist1_kwargs={ + 'probs': jnp.asarray([[0.1, 0.5, 0.4], [0.2, 0.4, 0.8]])}, dist2_kwargs={'logits': jnp.asarray([0.0, -0.1, 0.1]),}, assertion_fn=self.assertion_fn(rtol=1e-2)) def test_jittable(self): super()._test_jittable( (np.array([0., 4., -1., 4.]),), - assertion_fn=self.assertion_fn(rtol=1e-2)) + assertion_fn=self.assertion_fn(rtol=1e-3)) @parameterized.named_parameters( - ('single element', 2), - ('range', slice(-1)), - ('range_2', (slice(None), slice(-1))), - ('ellipsis', (Ellipsis, -1)), + ('single element, from probs', 2, True), + ('single element, from logits', 2, False), + ('range, from probs', slice(-1), True), + ('range, from logits', slice(-1), False), + ('range_2, from probs', (slice(None), slice(-1)), True), + ('range_2, from logits', (slice(None), slice(-1)), False), + ('ellipsis, from probs', (Ellipsis, -1), True), + ('ellipsis, from logits', (Ellipsis, -1), False), ) - def test_slice(self, slice_): - logits = jnp.array(np.random.randn(3, 4, 5)) - probs = jax.nn.softmax(jnp.array(np.random.randn(3, 4, 5)), axis=-1) - dist1 = self.distrax_cls(logits=logits) - dist2 = self.distrax_cls(probs=probs) - self.assertion_fn(rtol=1e-2)(dist1[slice_].logits, logits[slice_]) - self.assertion_fn(rtol=1e-2)(dist2[slice_].probs, probs[slice_]) + def test_slice(self, slice_, from_probs): + rng = np.random.default_rng(42) + probs = rng.uniform(size=(3, 4, 5)) + logits = sp_special.logit(probs) + dist_kwargs = {'probs': probs} if from_probs else {'logits': logits} + dist = self.distrax_cls(**dist_kwargs) + self.assertion_fn(rtol=1e-3)(dist[slice_].logits, logits[slice_]) + self.assertion_fn(rtol=1e-3)(dist[slice_].probs, probs[slice_]) if __name__ == '__main__':