Skip to content

Commit 5517bfa

Browse files
Jake VanderPlascopybara-github
authored andcommitted
[drjax] Avoid call to deprecated batching.moveaxis
This is deprecated in JAX v0.7.1; `jnp.moveaxis` is a drop-in replacement. PiperOrigin-RevId: 796456784
1 parent d4f26e5 commit 5517bfa

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

drjax/_src/primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def _batch_agg(xs, batched_shape):
203203
# Certain jax libs can silently insert the 'batching' dim 'all the way at
204204
# the front'; we are about to destroy the front axis by agging, so move
205205
# that puppy to the back. Tell the rest of JAX what happened here.
206-
xs = batching.moveaxis(*xs, *batched_shape, -1)
206+
xs = jnp.moveaxis(*xs, *batched_shape, -1)
207207
return agg_prim_fn(xs), len(xs.shape) - 2
208208

209209
# Make sure this can also be batched / mapped. This happens when dispatching

0 commit comments

Comments
 (0)