Skip to content

Commit d0c0d16

Browse files
jkr26copybara-github
authored andcommitted
Opens up API hole enabling jax_privacy to disable ability to read jax abstract mesh in drjax.
PiperOrigin-RevId: 766232793
1 parent 079ee0f commit d0c0d16

3 files changed

Lines changed: 22 additions & 5 deletions

File tree

drjax/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
@_functools.wraps(_api.drjax_program)
3333
def program(**kwargs):
3434
"""A decorator enabling calling the DrJAX API."""
35+
# TODO(b/421499635): Remove this API hole once we diagnose and fix clipping
36+
# ooming.
37+
use_abstract_mesh = kwargs.pop('use_abstract_mesh', True)
3538
try:
3639
[(placement, value)] = kwargs.items() # pylint: disable=unbalanced-dict-unpacking
3740
except ValueError as e:
@@ -50,10 +53,13 @@ def program(**kwargs):
5053
)
5154
# pylint: enable=f-string-without-interpolation
5255
return _api.drjax_program(
53-
placements=value, self_module=_sys.modules[__name__]
56+
placements=value,
57+
self_module=_sys.modules[__name__],
58+
use_abstract_mesh=use_abstract_mesh,
5459
)
5560
else:
5661
return _api.drjax_program(
5762
placements=kwargs,
5863
self_module=_sys.modules[__name__],
64+
use_abstract_mesh=use_abstract_mesh,
5965
)

drjax/_src/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def drjax_program(
233233
*,
234234
placements: Mapping[str, int],
235235
self_module,
236+
use_abstract_mesh: bool = True,
236237
):
237238
"""Patches symbols into current module and call `jax.jit` on the result.
238239
@@ -261,6 +262,9 @@ def drjax_program(
261262
collectives referencing this name results in undefined behavior).
262263
self_module: The Python module to patch the API when performing DrJAX
263264
tracing.
265+
use_abstract_mesh: Whether to optionally search for jax's abstract mesh when
266+
adding drjax sharding constraints (e.g. making use of drjax compatible
267+
with jax.sharding.use_mesh).
264268
265269
Returns:
266270
A decorated function enabling the calling of the DrJAX API. Interoperable
@@ -275,6 +279,7 @@ def drjax_program(
275279

276280
placed_computations = impls.PlacedComputations(
277281
placements_to_n_elements=placements,
282+
use_abstract_mesh=use_abstract_mesh,
278283
)
279284
prim_computations, primdefs = primitives.register_primitives(
280285
placements=placements

drjax/_src/impls.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,13 @@ def call_jaxpr(fn, arg):
4747

4848
# TODO(b/366437841): Remove use of pxla.thread_resources.env.physical_mesh,
4949
# which is a JAX internal API.
50-
def _global_mesh() -> jax.sharding.Mesh | jax.sharding.AbstractMesh | None:
50+
def _global_mesh(
51+
use_abstract: bool,
52+
) -> jax.sharding.Mesh | jax.sharding.AbstractMesh | None:
5153
"""Returns the JAX global mesh if installed, or `None` otherwise."""
52-
jax_global_mesh = jax.sharding.get_abstract_mesh()
54+
jax_global_mesh = None
55+
if use_abstract:
56+
jax_global_mesh = jax.sharding.get_abstract_mesh()
5357
if jax_global_mesh is None or jax_global_mesh.empty:
5458
jax_global_mesh = pxla.thread_resources.env.physical_mesh
5559
return None if jax_global_mesh.empty else jax_global_mesh
@@ -93,8 +97,10 @@ class PlacedComputations:
9397
def __init__(
9498
self,
9599
placements_to_n_elements: Mapping[str, int],
100+
use_abstract_mesh: bool = True,
96101
):
97102
self._placements_to_n_elements = placements_to_n_elements
103+
self._use_abstract_mesh = use_abstract_mesh
98104

99105
def broadcast_to_placement(
100106
self,
@@ -130,7 +136,7 @@ def broadcast_to_placement(
130136
A logically tiled array along the zeroth axis, as described above.
131137
"""
132138
if mesh is None:
133-
mesh = _global_mesh()
139+
mesh = _global_mesh(self._use_abstract_mesh)
134140

135141
arg = jnp.array(arg)
136142
n_elements = self._placements_to_n_elements[placement]
@@ -240,7 +246,7 @@ def map_to_placement(
240246
requirements specified above.
241247
"""
242248
if mesh is None:
243-
mesh = _global_mesh()
249+
mesh = _global_mesh(self._use_abstract_mesh)
244250

245251
def _constrain_at_placement_with_slices_like(x, y):
246252
pspec = P(placement, *([P.UNCONSTRAINED] * (len(x.shape) - 1)))

0 commit comments

Comments
 (0)