From 1a1ad1f83c01c09707d36c4233f3dd8b74b4ced2 Mon Sep 17 00:00:00 2001 From: Alexander Cai Date: Thu, 3 Jul 2025 02:12:49 -0400 Subject: [PATCH] index Distribution using jnp.indices instead of np.indices --- distrax/_src/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index ec33f6dc..0e501d71 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -351,7 +351,7 @@ def to_batch_shape_index( A new index that is only applied on the batch shape. """ try: - new_index = [x[index] for x in np.indices(batch_shape)] + new_index = [x[index] for x in jnp.indices(batch_shape)] return tuple(new_index) except IndexError as e: raise IndexError(f'Batch shape `{batch_shape}` not compatible with index '