diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 5d01e8751..eb0fee097 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -183,6 +183,8 @@ def create_sharded_state(): metadata = ckptr.metadata(config.load_parameters_path) is_nnx_checkpoint = True + has_base_key = False + if ( "params" in metadata.item_metadata.tree.keys() and "params" in metadata.item_metadata.tree.get("params", {}).keys() @@ -197,6 +199,16 @@ def create_sharded_state(): item_to_restore = {"params": {"params": target_for_restore}} restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}} + elif "base" in metadata.item_metadata.tree.keys(): + # structure of nnx-rl checkpoint: {'base': {'decoder': {..., 'value': ...}}} + has_base_key = True + target_for_restore = jax.tree.map( + lambda v: {"value": v.value}, + sharded_state, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + item_to_restore = {"base": target_for_restore} + restore_args = {"base": ocp.checkpoint_utils.construct_restore_args(target_for_restore)} else: # structure of nnx checkpoint: {'decoder': {'value': ...}} target_for_restore = jax.tree.map( @@ -215,6 +227,10 @@ def create_sharded_state(): ) if is_nnx_checkpoint: + # Unwrap 'base' key if present (NNX-RL format) + if has_base_key: + restored = restored.get("base", restored) + checkpoint = jax.tree.map( lambda v: v["value"], restored, diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 5ba40a170..950b10a37 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -59,6 +59,7 @@ from MaxText.layers import models from MaxText.layers import quantizations from maxtext.checkpoint_conversion.utils.hf_utils import convert_jax_weight_to_torch +from maxtext.checkpoint_conversion.utils.utils import load_orbax_checkpoint from maxtext.utils 
def unwrap_nnx_values(tree):
  """Recursively strip NNX ``{'value': array}`` wrappers from a checkpoint pytree.

  Entries whose key contains ``'to_nnx__rngs'`` (NNX RNG state) are dropped
  entirely — they are not model parameters. A node is treated as a wrapper
  only when it is exactly ``{'value': <jax or numpy array>}``; anything else
  is recursed into (dicts) or returned unchanged (non-dict leaves).

  Args:
    tree: a (possibly nested) dict loaded from an NNX Orbax checkpoint.

  Returns:
    The same structure with wrapper dicts replaced by their bare arrays and
    RNG-state subtrees removed.
  """
  # Non-dict leaves pass through untouched.
  if not isinstance(tree, dict):
    return tree
  # Drop NNX RNG state at every level of the tree.
  filtered = {k: v for k, v in tree.items() if "to_nnx__rngs" not in k}
  # A value wrapper is a single-key dict holding array data — unwrap it.
  if (
      len(filtered) == 1
      and "value" in filtered
      and isinstance(filtered["value"], (jax.Array, np.ndarray))
  ):
    return filtered["value"]
  # Otherwise recurse into every remaining child.
  return {k: unwrap_nnx_values(v) for k, v in filtered.items()}


def to_numpy(tree):
  """Convert every ``jax.Array`` in *tree* to ``np.ndarray``.

  Host-side numpy copies carry no device/sharding metadata, so this strips
  any stale sharding recorded in the checkpoint before re-sharding.
  Non-``jax.Array`` leaves are returned unchanged.
  """
  return jax.tree_util.tree_map(
      lambda x: np.array(x) if isinstance(x, jax.Array) else x, tree
  )
def reshard_to_match(target_tree, source_tree):
  """Device-put each array in *source_tree* using the matching leaf's sharding
  from *target_tree*.

  Bug fix: the previous check required ``isinstance(target_leaf, jax.Array)``,
  but abstract-state leaves (from ``get_abstract_state`` / ``jax.eval_shape``)
  are ``jax.ShapeDtypeStruct``, not ``jax.Array`` — so the reshard silently
  never ran. We now key off a non-``None`` ``.sharding`` attribute, which
  covers both concrete sharded arrays and annotated abstract leaves.

  Args:
    target_tree: pytree whose leaves carry the desired sharding (abstract
      state or concrete arrays).
    source_tree: pytree of host/numpy or jax arrays to be placed.

  Returns:
    *source_tree* with every array leaf device_put to the target's sharding;
    leaves whose target carries no sharding pass through unchanged.
  """

  def copy_sharding(target_leaf, source_leaf):
    if isinstance(source_leaf, (jax.Array, np.ndarray)):
      # NOTE(review): works for jax.Array and jax.ShapeDtypeStruct targets;
      # both expose `.sharding` (None when unannotated).
      sharding = getattr(target_leaf, "sharding", None)
      if sharding is not None:
        return jax.device_put(source_leaf, sharding)
    return source_leaf

  return jax.tree_util.tree_map(copy_sharding, target_tree, source_tree)