
Add a full SFT GPU demo for Llama 3_1 #3146

Open

katjasrz wants to merge 8 commits into AI-Hypercomputer:main from katjasrz:main

Conversation

@katjasrz (Collaborator) commented Feb 17, 2026

Description

  • Adds src/MaxText/examples/sft_llama3_gpu.ipynb: a GPU-focused, end-to-end notebook for SFT of Llama 3.1-8B on NVIDIA GPUs (HF authentication with a gated-access note → HF-to-MaxText checkpoint conversion on CPU → SFT run → TensorBoard → inference sanity check).
  • Motivation: complements the existing notebooks and docs that emphasize TPU SFT flows (e.g., sft_llama3_demo.ipynb) with a clear NVIDIA GPU path.

Tests

Executed the notebook in the NGC container nvcr.io/nvidia/jax:26.01-maxtext-py3 on a cluster node with 8 NVIDIA H100 GPUs, CUDA 13.1, driver 580.105.08, and JAX 0.8.1.dev20260217.

Verified:

  • Checkpoint conversion completes.
  • SFT runs for 100 steps and writes checkpoints/logs.
  • TensorBoard logs are created.
  • Inference sanity check runs and produces output.

To reproduce, execute as follows:

cd <YOUR_ROOT_DIRECTORY>

# clone this repo
git clone https://github.com/katjasrz/maxtext.git

export ROOT_DIR=$(pwd)
export PROJECT_DIR=$ROOT_DIR/maxtext/src/maxtext/examples
export HF_CACHE_DIR=$ROOT_DIR/huggingface

docker run -it --rm --ipc=host \
  --gpus=all \
  -p 8889:8889 \
  -p 6006:6006 \
  --shm-size=16g \
  --ulimit memlock=-1 \
  -v "$PROJECT_DIR":/workspace \
  -v "$HF_CACHE_DIR":/hf_cache \
  -e HF_HOME=/hf_cache \
  -e LOCAL_UID=$(id -u) \
  -e LOCAL_GID=$(id -g) \
  nvcr.io/nvidia/jax:26.01-maxtext-py3 \
  bash -lc 'set -e
    groupadd -g $LOCAL_GID hostgrp 2>/dev/null || true
    useradd -u $LOCAL_UID -g $LOCAL_GID -M -d /workspace hostusr 2>/dev/null || true
    
    python3 -m pip install --upgrade pip
    pip install jupyterlab ipywidgets
    pip install -U git+https://github.com/google/tunix
    pip install torch --index-url https://download.pytorch.org/whl/cpu
    
    su hostusr -c "cd /workspace && HOME=/workspace HF_HOME=/hf_cache \
      jupyter lab --ip=0.0.0.0 --port=8889 --no-browser"'
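
The container maps port 8889 for JupyterLab and port 6006 for TensorBoard, so once the server starts you can open http://localhost:8889 on the host (using the token Jupyter prints), and http://localhost:6006 once TensorBoard is launched from the notebook.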

Then follow the instructions in the Jupyter notebook src/maxtext/examples/sft_llama3_gpu.ipynb.

Two container-specific fixes, needed because the container bundles an older maxtext version, were excluded from the notebook:

Fix 1. The code below works around a known container issue where create_nnx_model defaults model_mode to None instead of "train"; the default is patched at runtime.

# Fix for container bug: model_creation_utils.create_nnx_model defaults
# model_mode to None, but it should default to "train". Overwrite the
# function's defaults tuple; this assumes model_mode is the third of four
# defaulted parameters in the container's version of the signature.
from MaxText import model_creation_utils
model_creation_utils.create_nnx_model.__defaults__ = (None, None, "train", None)
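
Overwriting __defaults__ is position-sensitive, so it breaks silently if the signature changes. A keyword-based sketch (assuming model_mode is accepted as a keyword argument) is more robust:

# Alternative sketch: wrap the function instead of rewriting its positional
# defaults. kwargs.setdefault only fills model_mode when the caller omits it.
import functools
from MaxText import model_creation_utils

_orig_create_nnx_model = model_creation_utils.create_nnx_model

@functools.wraps(_orig_create_nnx_model)
def _create_nnx_model(*args, **kwargs):
    kwargs.setdefault("model_mode", "train")
    return _orig_create_nnx_model(*args, **kwargs)

model_creation_utils.create_nnx_model = _create_nnx_model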

Fix 2. The code below works around a known container issue where empty-string defaults for hf_train_files/hf_eval_files/hf_data_dir cause datasets.load_dataset to fail; these are patched to None at runtime.

# Fix for container bug: empty string defaults for hf_train_files/hf_eval_files/hf_data_dir
# cause datasets.load_dataset to fail. Monkey-patch to convert empty strings to None.
# Guard against multiple applications to avoid recursion.
import datasets
if not hasattr(datasets, '_original_load_dataset'):
    datasets._original_load_dataset = datasets.load_dataset

    def _patched_load_dataset(*args, **kwargs):
        for key in ['data_files', 'data_dir']:
            if key in kwargs and kwargs[key] == '':
                kwargs[key] = None
        return datasets._original_load_dataset(*args, **kwargs)

    datasets.load_dataset = _patched_load_dataset
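
With the patch applied, empty strings are normalized before they reach the real loader. A quick smoke test (hypothetical; "stanfordnlp/imdb" is just an example dataset id, and the call needs network access):

# data_dir="" is converted to None by the patch, so this behaves like a
# plain load_dataset call instead of failing.
ds = datasets.load_dataset("stanfordnlp/imdb", split="train[:16]", data_dir="")
print(len(ds))  # 16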

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


google-cla bot commented Feb 17, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@katjasrz (Collaborator, Author):

I've submitted the CLA.

@shralex (Collaborator) left a comment:

Thank you for the contribution!

Please move this file to src/maxtext/examples (we're deprecating the uppercase MaxText directory).

Let's make sure it's being run by our CI: https://github.com/AI-Hypercomputer/maxtext/blob/main/.github/workflows/run_jupyter_notebooks.yml

@A9isha (Collaborator) left a comment:

This is awesome - thank you so much!

Also for consistency, could you please rename this to maxtext/src/maxtext/examples/sft_llama3_demo_gpu.ipynb, and rename https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb to https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo_tpu.ipynb?

And delete src/MaxText/examples/sft_llama3_gpu.ipynb, since src/MaxText is going to be deprecated in favor of src/maxtext.

"id": "1687aa03-1549-429a-8156-571c7493ca3d",
"metadata": {},
"source": [
"### Define model paths and run configuration\n",

just curious - do we need jax.distributed.is_initialized() like
https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb here?

Collaborator (Author) replied:

I've added a section explaining how it works on different GPU setups.
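
For reference, a minimal sketch of such a guard (mirroring the TPU notebook; on a single-host multi-GPU node the initialize call can be skipped entirely, since jax.devices() already sees all local GPUs):

# Guard so the cell can be re-run safely: initialize the distributed
# runtime only once per process.
import jax

if not jax.distributed.is_initialized():
    # Needed on multi-host clusters; under SLURM/Open MPI the arguments
    # are auto-detected, otherwise pass coordinator_address, num_processes,
    # and process_id explicitly.
    jax.distributed.initialize()

print(f"Devices visible to this process: {jax.devices()}")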

"metadata": {},
"source": [
"### Define model paths and run configuration\n",
"\n",

For consistency, how much of code cell [5] of https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb do you think is applicable here?

Collaborator (Author) replied:

This implementation is based on the original inline conversion snippet and preserves its core behavior, but improves structure and robustness by wrapping the logic in a reusable function, adding clearer error handling, controlled logging, and safer output directory handling. I’ve also moved the torch installation step into the setup description above, so users can install it once ahead of time and avoid triggering unnecessary reinstallation if it’s already present.
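
For illustration, a sketch of that structure (run_conversion and its arguments are placeholders; the actual conversion command lives in the notebook):

import logging
import os
import subprocess

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("ckpt_conversion")

def run_conversion(cmd, output_dir):
    """Run a checkpoint-conversion command once, with controlled logging.

    `cmd` is the conversion command from the notebook (a list of argv
    strings); `output_dir` is where the MaxText checkpoint lands.
    """
    if os.path.isdir(output_dir) and os.listdir(output_dir):
        log.info("Output %s already populated; skipping conversion.", output_dir)
        return output_dir
    os.makedirs(output_dir, exist_ok=True)
    log.info("Running: %s", " ".join(cmd))
    # check=True turns a failed conversion into an exception instead of
    # letting later cells run against a half-written checkpoint.
    subprocess.run(cmd, check=True)
    return output_dir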

"id": "6b762437-1edb-4123-8257-90cb98028e97",
"metadata": {},
"source": [
"### Download and convert the Llama 3.1 checkpoint\n",

For consistency, could you please add your better logging/documentation to https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb?

Collaborator (Author) replied:

Sure!
