diff --git a/distrax/_src/distributions/von_mises_test.py b/distrax/_src/distributions/von_mises_test.py index 87d3e861..1bce94ec 100644 --- a/distrax/_src/distributions/von_mises_test.py +++ b/distrax/_src/distributions/von_mises_test.py @@ -390,8 +390,7 @@ def samples_grad(s, concentration): lambda conc: tfp.distributions.von_mises.von_mises_cdf(s, conc), broadcast_concentration) inv_prob = np.exp(-concentration * (np.cos(s) - 1.)) * ( - (2. * np.pi) * scipy.special.i0e(concentration) - ) + (2. * np.pi) * scipy.special.i0e(concentration)) # Computes the implicit derivative, # dz = dconc * -(dF(z; conc) / dconc) / p(z; conc) dsamples = -dcdf_dconcentration * inv_prob @@ -399,12 +398,10 @@ def samples_grad(s, concentration): for seed in range(10): sample, sample_grad = jax_sample_and_grad( - seed, jnp.array(locs), jnp.array(concentration) - ) + seed, jnp.array(locs), jnp.array(concentration)) comparison = samples_grad(sample, concentration) np.testing.assert_allclose( - comparison, sample_grad, rtol=1e-06, atol=1e-06 - ) + comparison, sample_grad, rtol=1e-06, atol=1e-06) def test_von_mises_sample_moments(self): locs_v = np.array([-1., 0.3, 2.3])