diff --git a/drjax/_src/api.py b/drjax/_src/api.py index ed45033..40d07de 100644 --- a/drjax/_src/api.py +++ b/drjax/_src/api.py @@ -264,7 +264,7 @@ def drjax_program( tracing. use_abstract_mesh: Whether to optionally search for jax's abstract mesh when adding drjax sharding constraints (e.g. making use of drjax compatible - with jax.sharding.use_mesh). + with jax.set_mesh). Returns: A decorated function enabling the calling of the DrJAX API. Interoperable diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py index 36b2b62..1f78bfc 100644 --- a/drjax/_src/impls_sharding_test.py +++ b/drjax/_src/impls_sharding_test.py @@ -104,7 +104,7 @@ def test_broadcast_clients_with_jax_use_mesh(self): [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] ) arg = jnp.zeros(shape=[_DATA_SIZE]) - with jax.sharding.use_mesh(global_mesh): + with jax.set_mesh(global_mesh): result = self._comp_factory.broadcast_to_placement( arg, _CLIENTS_AXIS,