Conversation
Usage Examples:
[Mode 1: HF/PyTorch]
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
[Mode 2: MaxText Arch]
Nit: Shall we update MaxText Arch to Maxtext or MaxText Architecture to avoid potential confusion?
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"model_name={args.model_name}",
f"scan_layers={args.scan_layers}",
"attention=dot_product",
parser_mt.add_argument(
    "--scan_layers",
    type=str,
Do you think it's better we leverage the existing helper in maxtext/benchmarks/benchmark_utils.py (line 33 in 2b06b9c)?
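Assuming the referenced helper is a string-to-bool converter (a common utility for flags like scan_layers; the actual contents of benchmark_utils.py line 33 are not shown here), a minimal sketch of the pattern:

```python
# Hedged sketch: a hypothetical str2bool argparse type, mirroring a common
# utility pattern; the real helper in benchmark_utils.py may differ.
import argparse

def str2bool(value: str) -> bool:
  """Convert CLI strings like 'true'/'false' into real booleans."""
  if value.lower() in ("true", "1", "yes"):
    return True
  if value.lower() in ("false", "0", "no"):
    return False
  raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

parser_mt = argparse.ArgumentParser()
parser_mt.add_argument("--scan_layers", type=str2bool, default=True)
args = parser_mt.parse_args(["--scan_layers", "false"])
assert args.scan_layers is False  # a real bool, not the string "false"
```

This avoids passing the raw string through to the config, where `scan_layers=False` and `scan_layers="false"` can behave differently.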
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces a unified checkpoint inspector tool for HuggingFace, MaxText architecture, and Orbax. The tool is a great addition for debugging model bring-ups. However, there is a significant bug in the MaxText mode where layer indices are ignored during path flattening, leading to incomplete output. Additionally, there are opportunities to improve memory efficiency when inspecting large checkpoints.
🔍 General Feedback
- Bug Fix Required: The MaxText architecture inspection currently collapses all layers into a single set of keys because it ignores `SequenceKey` indices. This needs to be addressed to provide a complete view of the model structure.
- Memory Efficiency: For safetensors, using `get_slice` instead of `get_tensor` avoids unnecessary data loading. For PyTorch checkpoints, better handling of large files and common state-dict wrappers would make the tool more robust.
- Consistency: Standardizing separators across different modes (e.g., using `.` everywhere) would improve the user experience.
from safetensors import safe_open
except ImportError:
  sys.exit("Error: 'safetensors' is required. `pip install safetensors`")
🟡 f.get_tensor(k) loads the entire tensor into memory. Since you only need the shape, using f.get_slice(k) is much more memory-efficient, especially for large models.
Suggested change:
chkpt_vars_raw[k] = f.get_slice(k).get_shape()
Sounds like a good idea if it works.
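For reference, a minimal sketch of the get_slice approach (the filename is a placeholder; note that the lazy slice exposes get_shape() rather than a .shape attribute):

```python
# Hedged sketch: read tensor shapes from a safetensors shard without
# materializing the tensor data itself.
from safetensors import safe_open

chkpt_vars_raw = {}
# "model.safetensors" is a hypothetical local shard path.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
  for k in f.keys():
    # get_slice() returns a lazy view; get_shape() reads only metadata.
    chkpt_vars_raw[k] = tuple(f.get_slice(k).get_shape())
```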
elif args.format == "pth":
  for i, ckpt_path in enumerate(ckpt_paths):
    print(f"Loading {ckpt_path.name} ({i+1}/{len(ckpt_paths)})...")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
🟡 Loading a full .pth checkpoint into CPU memory can be very memory-intensive for large models. If the torch version allows, consider using mmap=True. Also, .pth files often wrap the state_dict in a dictionary (e.g., under a 'model' or 'state_dict' key).
Suggested change:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
# Some checkpoints wrap the state_dict
if isinstance(checkpoint, dict) and "model" in checkpoint:
  checkpoint = checkpoint["model"]
elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
  checkpoint = checkpoint["state_dict"]
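Beyond the suggestion above, a version-gated sketch that also tries mmap, assuming a reasonably recent torch (the helper name and signature check are illustrative, not part of this PR):

```python
# Hedged sketch: memory-friendlier .pth loading with optional mmap/weights_only,
# plus unwrapping of common state-dict containers.
import inspect
import torch

def load_state_dict(ckpt_path):
  load_kwargs = {"map_location": "cpu"}
  params = inspect.signature(torch.load).parameters
  if "weights_only" in params:
    load_kwargs["weights_only"] = True  # restrict unpickling to tensors/containers
  if "mmap" in params:
    load_kwargs["mmap"] = True  # map tensor storage from disk instead of copying into RAM
  checkpoint = torch.load(ckpt_path, **load_kwargs)
  # Some checkpoints wrap the state_dict under 'model' or 'state_dict'.
  if isinstance(checkpoint, dict):
    for wrapper in ("model", "state_dict"):
      if isinstance(checkpoint.get(wrapper), dict):
        return checkpoint[wrapper]
  return checkpoint
```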
# Initialize without heavyweight runtime
config = pyconfig.initialize(argv)
devices_array = maxtext_utils.create_device_mesh(config)
🟠 This line only extracts the key attribute from the path components, which means SequenceKey (used for list indices, like layer numbers) is ignored. This will cause all layers in a model to have the same flattened key, leading to them overwriting each other in the flat_shapes dictionary and resulting in incomplete output.
Suggested change:
key_parts = [str(getattr(k, "key", getattr(k, "idx", k))) for k in path_tuple]
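A sketch of path flattening that keeps list indices, assuming the parameters form a standard JAX pytree (function and variable names are illustrative):

```python
# Hedged sketch: flatten a parameter pytree into "a.b.0.c" keys so that layer
# indices (SequenceKey entries) are preserved instead of collapsing into one key.
import jax

def flatten_param_shapes(params):
  flat_shapes = {}
  for path_tuple, leaf in jax.tree_util.tree_flatten_with_path(params)[0]:
    parts = []
    for entry in path_tuple:
      if isinstance(entry, jax.tree_util.DictKey):
        parts.append(str(entry.key))
      elif isinstance(entry, jax.tree_util.SequenceKey):
        parts.append(str(entry.idx))  # keep the list/layer index
      else:
        parts.append(str(entry))
    flat_shapes[".".join(parts)] = getattr(leaf, "shape", None)
  return flat_shapes

# Example: two "layers" stay distinct instead of overwriting each other.
example = {"decoder": {"layers": [{"kernel": jax.numpy.ones((2, 3))}] * 2}}
print(flatten_param_shapes(example))
# {'decoder.layers.0.kernel': (2, 3), 'decoder.layers.1.kernel': (2, 3)}
```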
except ImportError:
  sys.exit("Error: 'orbax-checkpoint' or 'etils' not found. `pip install orbax-checkpoint etils[epath]`")
path = epath.Path(args.path)
🟡 The path_tuple in Orbax might contain integers for indices. Attempting to join them will raise a TypeError. Using map(str, k) ensures compatibility.
Suggested change:
key_str = ".".join(map(str, k))
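A quick illustration of the failure mode and the fix (the path tuple below is made up for illustration):

```python
# Hedged sketch: flattened Orbax paths can mix strings and integer indices,
# so a plain str.join raises; casting every component to str avoids it.
path_tuple = ("params", "decoder", "layers", 0, "kernel")  # hypothetical path

# ".".join(path_tuple) would raise:
# TypeError: sequence item 3: expected str instance, int found
key_str = ".".join(map(str, path_tuple))
print(key_str)  # params.decoder.layers.0.kernel
```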
# ==============================================================================
# Main CLI Driver
# ==============================================================================
def main():
🟢 Consider adding bin to the choices, as HuggingFace often uses pytorch_model.bin for PyTorch weights.
Suggested change:
"--format", type=str, required=False, choices=["safetensors", "pth", "bin"], default="safetensors", help="File format"
config = pyconfig.initialize(argv)
devices_array = maxtext_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
quant = quantizations.configure_quantization(config)
🟢 For consistency with the Orbax inspection mode and typical MaxText parameter paths, consider using . as a separator instead of -.
Suggested change:
mt_param_key = "params." + ".".join(key_parts)
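For illustration, the resulting key then lines up with the dotted Orbax-style paths (the key parts below are hypothetical):

```python
# Hedged sketch: join MaxText key parts with "." so keys match the Orbax output.
key_parts = ["decoder", "layers", "0", "self_attention", "query", "kernel"]  # hypothetical
mt_param_key = "params." + ".".join(key_parts)
print(mt_param_key)  # params.decoder.layers.0.self_attention.query.kernel
```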
Description
Unified script to inspect checkpoint structure for HF/MaxText/Orbax, to help with model bring-up and debugging.
Fix: b/484416862
A unified tool to inspect checkpoint structures for:
- HuggingFace checkpoints (safetensors / pth)
- MaxText model architectures (initialized on the fly)
- Orbax checkpoints
Usage Examples:
python src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py hf --path <local_hf_path> --format <safetensors | pth>
Tests
1. HF checkpoint, locally downloaded
   https://paste.googleplex.com/5971112811954176
2. MaxText model, initialized on the fly
   https://paste.googleplex.com/5636102443630592
   https://paste.googleplex.com/5609907941408768
3. Orbax checkpoint
   https://paste.googleplex.com/5320246790586368
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.