Add a full SFT GPU demo for Llama 3.1 #3146
katjasrz wants to merge 8 commits into AI-Hypercomputer:main from
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
I've submitted the CLA.
shralex left a comment
Thank you for the contribution!
Please move this file to src/maxtext/examples (we're deprecating the uppercase MaxText directory).
Let's make sure it's being run by our CI: https://github.com/AI-Hypercomputer/maxtext/blob/main/.github/workflows/run_jupyter_notebooks.yml
A9isha left a comment
This is awesome - thank you so much!
Also for consistency, could you please rename this to maxtext/src/maxtext/examples/sft_llama3_demo_gpu.ipynb, and rename https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb to https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo_tpu.ipynb?
And delete the src/MaxText/examples/sft_llama3_gpu.ipynb since src/MaxText is going to get deprecated in favor of src/maxtext
| "id": "1687aa03-1549-429a-8156-571c7493ca3d", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Define model paths and run configuration\n", |
just curious - do we need jax.distributed.is_initialized() like
https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb here?
I've added a section explaining how it works on different GPU setups
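As a hedged sketch of the guard being discussed (assuming a recent JAX release where jax.distributed.is_initialized() exists, as in the TPU notebook; the `dist` parameter is hypothetical and exists only so the guard logic can be exercised without a GPU cluster):

```python
import importlib


def maybe_init_distributed(dist=None):
    """Initialize JAX's distributed runtime at most once per process.

    On a single-host multi-GPU node this call is optional; on multi-host
    clusters it must happen before device queries. By default `dist` is
    resolved to jax.distributed at call time.
    """
    if dist is None:
        dist = importlib.import_module("jax").distributed
    if not dist.is_initialized():
        # On managed clusters (e.g., SLURM), coordinator address and
        # process index can usually be auto-detected from the environment.
        dist.initialize()
        return True
    return False
```

Calling it a second time in the same process is then a no-op, which keeps the notebook cell safe to re-run.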
| "metadata": {}, | ||
| "source": [ | ||
| "### Define model paths and run configuration\n", | ||
| "\n", |
For consistency, how much of code cell [5] of https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb do you think is applicable here?
This implementation is based on the original inline conversion snippet and preserves its core behavior, but improves structure and robustness by wrapping the logic in a reusable function, adding clearer error handling, controlled logging, and safer output directory handling. I’ve also moved the torch installation step into the setup description above, so users can install it once ahead of time and avoid triggering unnecessary reinstallation if it’s already present.
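A rough illustration of the structure described (the function name and signature are hypothetical, and the actual HF → MaxText conversion call is elided):

```python
import logging
import shutil
from pathlib import Path

logger = logging.getLogger("ckpt_convert")


def convert_checkpoint(src_dir, out_dir, overwrite=False):
    """Sketch of a reusable conversion wrapper: validate inputs, handle the
    output directory safely, and log progress. The conversion itself is
    elided; names here are illustrative, not the notebook's exact API."""
    src = Path(src_dir)
    out = Path(out_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"checkpoint source not found: {src}")
    if out.exists():
        if not overwrite:
            raise FileExistsError(f"output directory already exists: {out}")
        shutil.rmtree(out)
    out.mkdir(parents=True)
    logger.info("converting %s -> %s", src, out)
    # ... invoke the HF -> MaxText conversion script here (elided) ...
    return out
```

Failing fast on a pre-existing output directory (instead of silently overwriting) is the "safer output directory handling" mentioned above.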
| "id": "6b762437-1edb-4123-8257-90cb98028e97", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Download and convert the Llama 3.1 checkpoint\n", |
For consistency, could you please add your better logging/documentation to https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb?
Description
- src/MaxText/examples/sft_llama3_gpu.ipynb: a GPU-focused, end-to-end notebook for SFT of Llama 3.1-8B on NVIDIA GPUs (HF auth → gated access note → HF → MaxText checkpoint conversion (CPU) → SFT run → TensorBoard → inference sanity check).
- Complements the existing TPU notebook (sft_llama3_demo.ipynb) with a clear NVIDIA GPU path.

Tests
Executed the notebook in the NGC container nvcr.io/nvidia/jax:26.01-maxtext-py3 on a cluster node with 8 NVIDIA H100 GPUs, CUDA 13.1, driver 580.105.08, JAX 0.8.1.dev20260217.

Verified:
To reproduce, execute as follows:
Then follow the instructions in the jupyter notebook src/maxtext/examples/sft_llama3_gpu.ipynb
Two container-specific fixes, needed because the container ships an older maxtext version, are excluded from the notebook itself and applied at runtime instead:
Fix 1. The code below includes a workaround for a known container issue where create_nnx_model defaults model_mode to None instead of "train". This is patched at runtime.
Fix 2. The code below is a workaround for a known container issue where empty-string defaults for hf_train_files/hf_eval_files/hf_data_dir cause datasets.load_dataset to fail. These are patched to None at runtime.
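A minimal sketch of what the two runtime patches could look like (function and field names follow the description above; the actual container code may differ):

```python
import functools


def with_default_model_mode(create_nnx_model):
    """Fix 1 (sketch): wrap create_nnx_model so that a model_mode of None
    falls back to "train" instead of propagating the broken default."""
    @functools.wraps(create_nnx_model)
    def wrapper(*args, model_mode=None, **kwargs):
        return create_nnx_model(*args, model_mode=model_mode or "train", **kwargs)
    return wrapper


def patch_empty_hf_paths(config):
    """Fix 2 (sketch): datasets.load_dataset rejects empty-string paths,
    so map empty-string defaults for the HF dataset fields to None."""
    for field in ("hf_train_files", "hf_eval_files", "hf_data_dir"):
        if getattr(config, field, None) == "":
            setattr(config, field, None)
    return config
```

Both patches are applied once at notebook startup, so later cells can call the library unchanged.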
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.