
TracerArrayConversionError when indexing into a distribution #291

@adzcai


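I get a `TracerArrayConversionError` when indexing into a `Distribution` object with square brackets and the index is a traced array (for example, inside a `jit`-compiled function).
For example, this snippet works eagerly but breaks under `jit`: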

```python
from jax import jit
import jax.numpy as jnp
import distrax

def foo(logits):
  d = distrax.Categorical(logits=logits)  # batch_shape (3, 3), 3 categories
  mask = jnp.arange(logits.shape[0])      # integer index over the first batch axis
  return d[mask].entropy()

logits = jnp.zeros((3, 3, 3))
foo(logits)  # fine
jit(foo)(logits)  # raises
```

Specifically, I get the following traceback:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 3, in foo
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/categorical.py", line 156, in __getitem__
    index = distribution.to_batch_shape_index(self.batch_shape, index)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/distribution.py", line 354, in to_batch_shape_index
    new_index = [x[index] for x in np.indices(batch_shape)]
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/distrax/_src/distributions/distribution.py", line 354, in <listcomp>
    new_index = [x[index] for x in np.indices(batch_shape)]
                 ~^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape bool[3,3]
The error occurred while tracing the function foo at <stdin>:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=True] b
    from line <stdin>:3:11 (foo)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
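The same failure can be reproduced without distrax. Below is a minimal sketch using only NumPy and JAX; the helper names `index_with_np` and `index_with_jnp` are hypothetical, and `BATCH_SHAPE = (3, 3)` is chosen to mirror the batch shape in the repro. Indexing a concrete `np.ndarray` with a traced index makes NumPy call `__array__()` on the tracer, which JAX rejects during `jit` tracing, while the `jnp.indices` variant keeps the indexing inside JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: mirrors the batch shape of Categorical(logits=jnp.zeros((3, 3, 3))).
BATCH_SHAPE = (3, 3)


def index_with_np(index):
    # Hypothetical helper: np.indices gives concrete NumPy arrays, so indexing
    # them with a JAX array makes NumPy convert that array via __array__().
    return [x[index] for x in np.indices(BATCH_SHAPE)]


def index_with_jnp(index):
    # Hypothetical helper: jnp.indices gives JAX arrays, so the indexing is done
    # by JAX and also works when `index` is a tracer.
    return [x[index] for x in jnp.indices(BATCH_SHAPE)]


mask = jnp.arange(BATCH_SHAPE[0])

# Outside jit, `mask` is concrete and NumPy can convert it.
eager_np_result = index_with_np(mask)

# Under jit, `mask` is a tracer and the NumPy conversion is refused.
try:
    jax.jit(index_with_np)(mask)
    np_jit_failed = False
except jax.errors.TracerArrayConversionError:
    np_jit_failed = True

# The jnp.indices variant traces and compiles without error.
jnp_jit_result = jax.jit(index_with_jnp)(mask)
```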

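The error seems to come from the `new_index = [x[index] for x in np.indices(batch_shape)]` line in `to_batch_shape_index` (`distribution.py`, line 354): `np.indices` returns concrete NumPy arrays, so indexing them with a traced index makes NumPy call `__array__()` on the tracer, which JAX does not allow while tracing.
The error goes away if I change `np.indices` to `jnp.indices` in that line.
Would that be an acceptable fix?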
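A minimal sketch of the proposed change, written as a standalone hypothetical helper `batch_shape_index_jnp` rather than distrax's actual `to_batch_shape_index` (only the `np.indices` line is known from the traceback; the rest of that function is not reproduced here). It builds the per-axis index arrays with `jnp.indices`, and, as an assumption for illustration only, uses the resulting tuple to gather the selected batch entries of a `logits` array inside `jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batch_shape_index_jnp(batch_shape, index):
    """Hypothetical jit-friendly variant of the line from the traceback.

    Turns an index over `batch_shape` into one integer array per batch axis,
    using jnp.indices instead of np.indices so `index` may be a tracer.
    """
    return tuple(x[index] for x in jnp.indices(batch_shape))


@jax.jit
def select_logits(logits, mask):
    # Shapes are static under jit, so batch_shape is a plain tuple of ints.
    batch_shape = logits.shape[:-1]
    idx = batch_shape_index_jnp(batch_shape, mask)
    # Assumption for illustration: the per-axis index arrays gather the
    # selected batch entries and leave the trailing event axis untouched.
    return logits[idx]


logits = jnp.arange(27.0).reshape(3, 3, 3)
mask = jnp.array([2, 0])
selected = select_logits(logits, mask)  # same values as logits[mask]
```

With concrete inputs, the `jnp.indices` version returns the same values as the `np.indices` one (as JAX arrays instead of NumPy arrays), so eager indexing should behave as before.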
