Skip to content

Commit 90da06c

Browse files
DistraxDevDistraxDev
authored andcommitted
Migrate RLax distributions to use distrax.
PiperOrigin-RevId: 368473519
1 parent 941bb83 commit 90da06c

1 file changed

Lines changed: 0 additions & 2 deletions

File tree

distrax/_src/distributions/categorical.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ def __init__(self,
5050
"""
5151
super().__init__()
5252
chex.assert_exactly_one_is_none(probs, logits)
53-
chex.if_args_not_none(chex.assert_axis_dimension_gt, probs, axis=-1, val=1)
54-
chex.if_args_not_none(chex.assert_axis_dimension_gt, logits, axis=-1, val=1)
5553
if not (jnp.issubdtype(dtype, jnp.integer) or
5654
jnp.issubdtype(dtype, jnp.floating)):
5755
raise ValueError(

0 commit comments

Comments
 (0)