From d4f26e5ed6d75d2475b51a78cf850cba0238ad1c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 31 Jul 2025 17:34:50 -0700 Subject: [PATCH] Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as a global setter or a context manager. PiperOrigin-RevId: 789530543 --- drjax/_src/api.py | 2 +- drjax/_src/impls_sharding_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/drjax/_src/api.py b/drjax/_src/api.py index ed45033..40d07de 100644 --- a/drjax/_src/api.py +++ b/drjax/_src/api.py @@ -264,7 +264,7 @@ def drjax_program( tracing. use_abstract_mesh: Whether to optionally search for jax's abstract mesh when adding drjax sharding constraints (e.g. making use of drjax compatible - with jax.sharding.use_mesh). + with jax.set_mesh). Returns: A decorated function enabling the calling of the DrJAX API. Interoperable diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py index 36b2b62..1f78bfc 100644 --- a/drjax/_src/impls_sharding_test.py +++ b/drjax/_src/impls_sharding_test.py @@ -104,7 +104,7 @@ def test_broadcast_clients_with_jax_use_mesh(self): [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] ) arg = jnp.zeros(shape=[_DATA_SIZE]) - with jax.sharding.use_mesh(global_mesh): + with jax.set_mesh(global_mesh): result = self._comp_factory.broadcast_to_placement( arg, _CLIENTS_AXIS,