Tool to inspect checkpoint structure #3139

Open

shuningjin wants to merge 2 commits into main from shuningjin-ckpt-structure

Conversation

@shuningjin shuningjin commented Feb 13, 2026

Description

Unified script to inspect checkpoint structure for HF/MaxText/Orbax, to help with model bring-up and debugging.

Fix: b/484416862

A unified tool to inspect checkpoint structures for:

  1. HuggingFace/PyTorch
  • input: local path (safetensors or pth)
  • needs to load the weights
  • prints: param name / shape
  2. MaxText Model Architecture
  • input: model name, scan mode
  • lightweight: no significant memory or compute needed (runs on CPU or TPU)
  • how: jax.eval_shape (see the sketch after this list)
  • prints: param name / shape + param count
  3. Orbax Checkpoints
  • input: GCS path or local path
  • lightweight: no significant memory or compute needed (runs on CPU or TPU)
  • how: reads checkpoint metadata
  • prints: param name / shape
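
The reason Mode 2 is lightweight: jax.eval_shape traces a function abstractly, so parameter shapes come out without allocating any device memory. A minimal sketch of the idea on a toy Flax module (ToyMLP and all names below are illustrative, not the tool's actual code):

```python
# Sketch: derive parameter shapes without allocating memory.
import math

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyMLP(nn.Module):
  """Stand-in for a real model; only the eval_shape pattern matters."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(128)(x)
    return nn.Dense(16)(x)


model = ToyMLP()
dummy_input = jnp.ones((1, 32))

# eval_shape runs init abstractly: no weights are ever materialized.
abstract_params = jax.eval_shape(model.init, jax.random.PRNGKey(0), dummy_input)

total = 0
for path, leaf in jax.tree_util.tree_flatten_with_path(abstract_params)[0]:
  name = ".".join(str(getattr(k, "key", getattr(k, "idx", k))) for k in path)
  print(f"{name}: {leaf.shape}")  # leaf is a ShapeDtypeStruct
  total += math.prod(leaf.shape)
print(f"total params: {total}")
```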

Usage Examples:

```shell
[Mode 1: HF/PyTorch]
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>

[Mode 2: MaxText Arch]
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py maxtext --model_name <maxtext_model_name> --scan_layers <True | False>

[Mode 3: Orbax]
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py orbax --path <local_orbax_path | gcs_orbax_path>
```

Tests

1. HF checkpoint, locally downloaded

```shell
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py hf \
  --path ~/deepseek2-16b/hf-16b \
  --format safetensors
```

https://paste.googleplex.com/5971112811954176

2. MaxText model, built on the fly

```shell
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py maxtext \
  --model_name deepseek2-16b --scan_layers True
```

https://paste.googleplex.com/5636102443630592

```shell
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py maxtext \
  --model_name deepseek2-16b --scan_layers False
```

https://paste.googleplex.com/5609907941408768

3. Orbax checkpoint

```shell
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py orbax \
  --path gs://maxtext-deepseek/deepseek3-671b/2025-03-31/scanned/0/items
```

https://paste.googleplex.com/5320246790586368

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Feb 13, 2026

Codecov Report

❌ Patch coverage is 0% with 109 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...axText/utils/ckpt_conversion/inspect_checkpoint.py | 0.00% | 109 Missing ⚠️ |


@RissyRan RissyRan left a comment

Thanks for making our process better!

> Usage Examples:
> [Mode 1: HF/PyTorch]
> python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
> [Mode 2: MaxText Arch]

Nit: Shall we update MaxText Arch to Maxtext or MaxText Architecture to avoid potential confusion?

```python
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"model_name={args.model_name}",
f"scan_layers={args.scan_layers}",
"attention=dot_product",
```

Did you run into issues?

Comment on lines +206 to +208

```python
parser_mt.add_argument(
    "--scan_layers",
    type=str,
```

Do you think it's better we leverage `def str2bool(v: str) -> bool:`?
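
For reference, a common argparse pattern along these lines (a sketch; MaxText's actual str2bool helper may differ):

```python
import argparse


def str2bool(v: str) -> bool:
  """Parse common true/false spellings from the command line."""
  if isinstance(v, bool):
    return v
  if v.lower() in ("yes", "true", "t", "1"):
    return True
  if v.lower() in ("no", "false", "f", "0"):
    return False
  raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
# type=str2bool converts the flag before it reaches the tool's logic.
parser.add_argument("--scan_layers", type=str2bool, default=True)
print(parser.parse_args(["--scan_layers", "False"]).scan_layers)  # -> False
```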

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

## 📋 Review Summary

This Pull Request introduces a unified checkpoint inspector tool for HuggingFace, MaxText architecture, and Orbax. The tool is a great addition for debugging model bring-ups. However, there is a significant bug in the MaxText mode where layer indices are ignored during path flattening, leading to incomplete output. Additionally, there are opportunities to improve memory efficiency when inspecting large checkpoints.

🔍 General Feedback

  • Bug Fix Required: The MaxText architecture inspection currently collapses all layers into a single set of keys because it ignores SequenceKey indices. This needs to be addressed to provide a complete view of the model structure.
  • Memory Efficiency: For safetensors, using get_slice instead of get_tensor avoids unnecessary data loading. For PyTorch checkpoints, better handling of large files and common state-dict wrappers would make the tool more robust.
  • Consistency: Standardizing separators across different modes (e.g., using . everywhere) would improve the user experience.

```python
try:
  from safetensors import safe_open
except ImportError:
  sys.exit("Error: 'safetensors' is required. `pip install safetensors`")
```

🟡 f.get_tensor(k) loads the entire tensor into memory. Since you only need the shape, using f.get_slice(k) is much more memory-efficient, especially for large models.

Suggested change
```python
chkpt_vars_raw[k] = f.get_slice(k).get_shape()
```
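
A self-contained sketch of this pattern (the file name is a placeholder):

```python
# Sketch: read tensor shapes from a safetensors file without loading data.
from safetensors import safe_open

shapes = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
  for name in f.keys():
    # get_slice returns a lazy view; get_shape() reads only header info.
    shapes[name] = f.get_slice(name).get_shape()

for name, shape in sorted(shapes.items()):
  print(f"{name}: {shape}")
```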

Sounds like a good idea if it works.

```python
elif args.format == "pth":
  for i, ckpt_path in enumerate(ckpt_paths):
    print(f"Loading {ckpt_path.name} ({i+1}/{len(ckpt_paths)})...")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
```

🟡 Loading a full .pth checkpoint into CPU memory can be very memory-intensive for large models. If the torch version allows, consider using mmap=True. Also, .pth files often wrap the state_dict in a dictionary (e.g., under a 'model' or 'state_dict' key).

Suggested change
```diff
-checkpoint = torch.load(ckpt_path, map_location="cpu")
+checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+# Some checkpoints wrap the state_dict
+if isinstance(checkpoint, dict) and "model" in checkpoint:
+  checkpoint = checkpoint["model"]
+elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+  checkpoint = checkpoint["state_dict"]
```
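
Since the comment also mentions mmap, here is a sketch of the memory-mapped variant (requires PyTorch >= 2.1 and a zip-format .pth file; the path is a placeholder):

```python
# Sketch: memory-map a .pth checkpoint so tensor data is paged in lazily.
import torch

checkpoint = torch.load(
    "checkpoint.pth",  # placeholder path
    map_location="cpu",
    mmap=True,  # requires PyTorch >= 2.1
    weights_only=True,
)
for name, tensor in checkpoint.items():
  # Shapes are available without reading the full tensor contents.
  print(f"{name}: {tuple(tensor.shape)}")
```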


```python
# Initialize without heavyweight runtime
config = pyconfig.initialize(argv)
devices_array = maxtext_utils.create_device_mesh(config)
```

🟠 This line only extracts the key attribute from the path components, which means SequenceKey (used for list indices, like layer numbers) is ignored. This will cause all layers in a model to have the same flattened key, leading to them overwriting each other in the flat_shapes dictionary and resulting in incomplete output.

Suggested change
```python
key_parts = [str(getattr(k, "key", getattr(k, "idx", k))) for k in path_tuple]
```
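
For illustration, how DictKey and SequenceKey entries show up when flattening a tree whose layers live in a list (a toy tree, not the tool's code):

```python
# Sketch: flatten a nested tree, keeping both dict keys and list indices.
import jax

tree = {"params": {"embed": {"w": 3}, "layers": [{"w": 1}, {"w": 2}]}}

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
  # DictKey exposes .key, SequenceKey exposes .idx; fall back to the raw entry.
  parts = [str(getattr(k, "key", getattr(k, "idx", k))) for k in path]
  print(".".join(parts), "->", leaf)

# Output:
# params.embed.w -> 3
# params.layers.0.w -> 1
# params.layers.1.w -> 2
```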

```python
except ImportError:
  sys.exit("Error: 'orbax-checkpoint' or 'etils' not found. `pip install orbax-checkpoint etils[epath]`")

path = epath.Path(args.path)
```

🟡 The path_tuple in Orbax might contain integers for indices. Attempting to join them will raise a TypeError. Using map(str, k) ensures compatibility.

Suggested change
```python
key_str = ".".join(map(str, k))
```

```python
# ==============================================================================
# Main CLI Driver
# ==============================================================================
def main():
```

🟢 Consider adding bin to the choices, as HuggingFace often uses pytorch_model.bin for PyTorch weights.

Suggested change
```python
"--format", type=str, required=False, choices=["safetensors", "pth", "bin"], default="safetensors", help="File format"
```

```python
config = pyconfig.initialize(argv)
devices_array = maxtext_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
quant = quantizations.configure_quantization(config)
```

🟢 For consistency with the Orbax inspection mode and typical MaxText parameter paths, consider using . as a separator instead of -.

Suggested change
```python
mt_param_key = "params." + ".".join(key_parts)
```
