Replace jax.sharding.use_mesh with jax.set_mesh. jax.set_mesh can act as a global setter or a context manager.
#24
+2
−2