Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def create_sharded_state():
metadata = ckptr.metadata(config.load_parameters_path)

is_nnx_checkpoint = True
has_base_key = False

if (
"params" in metadata.item_metadata.tree.keys()
and "params" in metadata.item_metadata.tree.get("params", {}).keys()
Expand All @@ -197,6 +199,16 @@ def create_sharded_state():

item_to_restore = {"params": {"params": target_for_restore}}
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
elif "base" in metadata.item_metadata.tree.keys():
# structure of nnx-rl checkpoint: {'base': {'decoder': {..., 'value': ...}}}
has_base_key = True
target_for_restore = jax.tree.map(
lambda v: {"value": v.value},
sharded_state,
is_leaf=lambda n: isinstance(n, nnx.Variable),
)
item_to_restore = {"base": target_for_restore}
restore_args = {"base": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}
else:
# structure of nnx checkpoint: {'decoder': {'value': ...}}
target_for_restore = jax.tree.map(
Expand All @@ -215,6 +227,10 @@ def create_sharded_state():
)

if is_nnx_checkpoint:
# Unwrap 'base' key if present (NNX-RL format)
if has_base_key:
restored = restored.get("base", restored)

checkpoint = jax.tree.map(
lambda v: v["value"],
restored,
Expand Down
57 changes: 56 additions & 1 deletion tests/utils/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from MaxText.layers import models
from MaxText.layers import quantizations
from maxtext.checkpoint_conversion.utils.hf_utils import convert_jax_weight_to_torch
from maxtext.checkpoint_conversion.utils.utils import load_orbax_checkpoint
from maxtext.utils import max_logging
from maxtext.utils import maxtext_utils

Expand Down Expand Up @@ -430,7 +431,61 @@ def main(config, test_args): # pylint: disable=W0621
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
quant = quantizations.configure_quantization(config)
maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)
# maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)

# Unwrap NNX 'value' layer from checkpoint
def unwrap_nnx_values(tree):
    """Recursively strip NNX ``{'value': array}`` wrappers from a pytree.

    Entries whose key contains ``to_nnx__rngs`` (NNX RNG state) are dropped
    entirely. A dict whose only surviving key is ``'value'`` holding an array
    collapses to that bare array; every other dict is rebuilt with its
    children unwrapped. Non-dict nodes are returned untouched.
    """
    if not isinstance(tree, dict):
        return tree

    # Drop NNX RNG-state entries before looking at the structure.
    kept = {name: child for name, child in tree.items() if "to_nnx__rngs" not in name}

    # Leaf wrapper: exactly one key, 'value', holding raw array data.
    if list(kept) == ["value"] and isinstance(kept["value"], (jax.Array, np.ndarray)):
        return kept["value"]

    return {name: unwrap_nnx_values(child) for name, child in kept.items()}

# Get abstract state with proper sharding specs
unboxed_abstract_state, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(
maxtext_model, None, config, rng1, mesh, False
)

# Load checkpoint and unwrap NNX value layer
loaded_params = load_orbax_checkpoint(config)
loaded_params = unwrap_nnx_values(loaded_params)
loaded_params = {'params': loaded_params['base']}

# Convert all arrays to numpy to strip old sharding information
def to_numpy(tree):
    """Materialize every ``jax.Array`` leaf of *tree* as a host numpy array.

    Pulling the data to host as ``np.ndarray`` discards the arrays' old
    sharding/placement information; all other leaves pass through unchanged.
    """
    return jax.tree_util.tree_map(
        lambda leaf: np.array(leaf) if isinstance(leaf, jax.Array) else leaf,
        tree,
    )

loaded_params = to_numpy(loaded_params)

# Reshard loaded params to match expected sharding from abstract state
def reshard_to_match(target_tree, source_tree):
    """Place each array leaf of *source_tree* per the matching target leaf's sharding.

    Args:
      target_tree: pytree whose leaves carry the desired sharding — typically
        abstract-state leaves (e.g. ``jax.ShapeDtypeStruct``) or concrete
        ``jax.Array``s.
      source_tree: pytree of the same structure holding the actual data
        (``jax.Array`` or ``np.ndarray`` leaves).

    Returns:
      A pytree mirroring *source_tree* where every array leaf whose target
      counterpart exposes a sharding has been ``jax.device_put`` onto it;
      all other leaves are returned unchanged.
    """
    def copy_sharding(target_leaf, source_leaf):
        if isinstance(source_leaf, (jax.Array, np.ndarray)):
            # Abstract-state leaves are typically jax.ShapeDtypeStruct, not
            # jax.Array, so key off the presence of a sharding rather than
            # the concrete leaf type (an isinstance(target, jax.Array) check
            # would never fire for an abstract target and silently skip
            # resharding).
            sharding = getattr(target_leaf, "sharding", None)
            if sharding is not None:
                return jax.device_put(source_leaf, sharding)
        return source_leaf

    return jax.tree_util.tree_map(copy_sharding, target_tree, source_tree)

# Apply the proper sharding to loaded params
params = reshard_to_match(unboxed_abstract_state.params, loaded_params)

maxtext_state = maxtext_utils.init_decode_state(maxtext_model.apply, params)

prompts = ["I love to", "Today is a", "What is the"]
all_data_to_save = []
Expand Down
Loading