From e3c389dff9ddf7b1ddfea35ad588bdbaa553b6c7 Mon Sep 17 00:00:00 2001 From: Surya Bhupatiraju Date: Fri, 1 Sep 2023 12:17:25 -0700 Subject: [PATCH] Replace deprecated jax.linear_util.wrap_init with jax.extend.linear_util.wrap_init. PiperOrigin-RevId: 562017973 --- distrax/_src/utils/transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distrax/_src/utils/transformations.py b/distrax/_src/utils/transformations.py index e754c054..8d845499 100644 --- a/distrax/_src/utils/transformations.py +++ b/distrax/_src/utils/transformations.py @@ -258,7 +258,7 @@ def write(var, val): # if primitive is an xla_call, get subexpressions and evaluate recursively call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params) if call_jaxpr: - subfuns = [jax.linear_util.wrap_init( + subfuns = [jax.extend.linear_util.wrap_init( functools.partial(_interpret_inverse, call_jaxpr, ()))] prim_inv = eqn.primitive