Skip to content

Commit beeb4ec

Browse files
DistraxDevDistraxDev
authored andcommitted
Fix Distrax from Jax API change.
The method map_primitive was removed from Primitive class. PiperOrigin-RevId: 888552312
1 parent e72e21e commit beeb4ec

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

distrax/_src/utils/transformations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def write(var, val):
331331

332332

333333
def _extract_call_jaxpr(primitive, params):
334-
if not (primitive.call_primitive or primitive.map_primitive):
334+
if not (
335+
primitive.call_primitive or getattr(primitive, "map_primitive", False)
336+
):
335337
return None, params
336338
else:
337339
params = dict(params)

0 commit comments

Comments
 (0)