diff --git a/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py b/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py index 71e527754..05c241c20 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py @@ -149,15 +149,19 @@ def _colocated_cpu_sharding( sharding: jax.sharding.Sharding, ) -> jax.sharding.Sharding: """Returns a CPU sharding colocated with the given device sharding.""" + target_memory_kind = sharding.memory_kind + if target_memory_kind == 'device': + target_memory_kind = 'pinned_host' + if isinstance(sharding, jax.sharding.SingleDeviceSharding): cpu_devices = cp.colocated_cpu_devices(list(sharding.device_set)) return jax.sharding.SingleDeviceSharding( - cpu_devices[0], memory_kind=sharding.memory_kind + cpu_devices[0], memory_kind=target_memory_kind ) elif isinstance(sharding, jax.sharding.NamedSharding): cpu_mesh = cp.colocated_cpu_devices(sharding.mesh) return jax.sharding.NamedSharding( - cpu_mesh, sharding.spec, memory_kind=sharding.memory_kind + cpu_mesh, sharding.spec, memory_kind=target_memory_kind ) else: raise TypeError(