I sometimes get an error when indexing into a `Distribution` object with square brackets. For example, here's a snippet that breaks:
```python
from jax import jit
import jax.numpy as jnp
import distrax


def foo(logits):
    d = distrax.Categorical(logits=logits)
    mask = jnp.arange(logits.shape[0])
    return d[mask].entropy()


logits = jnp.zeros((3, 3, 3))
foo(logits)  # fine
jit(foo)(logits)  # raises
```
Specifically, I get the following traceback:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 3, in foo
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/categorical.py", line 156, in __getitem__
    index = distribution.to_batch_shape_index(self.batch_shape, index)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/distribution.py", line 354, in to_batch_shape_index
    new_index = [x[index] for x in np.indices(batch_shape)]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/distribution.py", line 354, in <listcomp>
    new_index = [x[index] for x in np.indices(batch_shape)]
                 ~^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape bool[3,3]
The error occurred while tracing the function foo at <stdin>:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=True] b
    from line <stdin>:3:11 (foo)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
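It seems the error comes from line 354 of `distrax/_src/distributions/distribution.py` (in `to_batch_shape_index`), where the NumPy array returned by `np.indices(batch_shape)` is indexed with `index`, which is a tracer under `jit`, so NumPy tries to convert it to an `ndarray`.
The error goes away if I change `np.indices` to `jnp.indices` in that line.
Would that be an acceptable fix?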
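To illustrate the proposed change in isolation, here is a minimal sketch that reproduces only the index-conversion step, without distrax. The helpers `batch_index_np` and `batch_index_jnp` are hypothetical stand-ins for the failing line in `to_batch_shape_index`, not distrax's actual function (which may do more processing of `index`); they differ only in using `np.indices` versus `jnp.indices`.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for the failing line in distrax's `to_batch_shape_index`;
# not the library's actual implementation.
def batch_index_np(batch_shape, index):
    # Treat a bare index as a 1-tuple so it applies to the leading batch axis.
    index = index if isinstance(index, tuple) else (index,)
    # NumPy index grids: indexing them with a JAX tracer makes NumPy call the
    # tracer's __array__(), which raises under jit.
    return tuple(x[index] for x in np.indices(batch_shape))


def batch_index_jnp(batch_shape, index):
    index = index if isinstance(index, tuple) else (index,)
    # Same computation with jnp.indices: the grids are JAX arrays, so indexing
    # them with a tracer is recorded as a gather and stays traceable.
    return tuple(x[index] for x in jnp.indices(batch_shape))


batch_shape = (3, 3)  # batch shape of Categorical(logits=jnp.zeros((3, 3, 3)))
mask = jnp.arange(batch_shape[0])

# Works: the jnp grids can be indexed by the traced `mask`.
fixed = jax.jit(lambda m: batch_index_jnp(batch_shape, m))(mask)

# Fails: NumPy cannot convert the traced `mask` into an ndarray.
try:
    jax.jit(lambda m: batch_index_np(batch_shape, m))(mask)
except TypeError as err:  # jax.errors.TracerArrayConversionError is a TypeError
    print("np.indices version failed:", type(err).__name__)
```

Under these assumptions, the `jnp.indices` version returns the same index arrays that the NumPy version produces eagerly, so the change should not alter behavior outside `jit`. One possible trade-off is that the resulting index arrays become JAX arrays instead of NumPy arrays, which other code paths in distrax may or may not expect.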