From b13b343c8f3c3ab91e44aab9da6e6dd1c5406a89 Mon Sep 17 00:00:00 2001 From: Charles Li Date: Thu, 19 Feb 2026 00:02:03 +0000 Subject: [PATCH] RL: Add chips_per_vm hyperparameter and fix an issue of missing config Fix b/471046638 --- docs/tutorials/posttraining/rl.md | 14 ++++++++++++-- src/maxtext/configs/types.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md index 91024344c0..d3d286cbe6 100644 --- a/docs/tutorials/posttraining/rl.md +++ b/docs/tutorials/posttraining/rl.md @@ -127,8 +127,16 @@ export HF_TOKEN= export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S) + +export CHIPS_PER_VM= # depends on hardware, for v5p this is 4, for v6e this is 8 ``` +For the value of `CHIPS_PER_VM` on different TPU hardware, refer to the official documentation: + +- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, chips_per_vm=8) +- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, chips_per_vm=4) +- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, chips_per_vm=8) + ## Get your model checkpoint ### Option 1: Using an existing MaxText checkpoint @@ -159,7 +167,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ run_name=${RUN_NAME} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - hf_access_token=${HF_TOKEN} + hf_access_token=${HF_TOKEN} \ + chips_per_vm=${CHIPS_PER_VM} ``` The overview of what this run will do is as follows: @@ -183,7 +192,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ run_name=${RUN_NAME} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ hf_access_token=${HF_TOKEN} \ - loss_algo=gspo-token + loss_algo=gspo-token \ + chips_per_vm=${CHIPS_PER_VM} ``` The overview of what this run will do is as follows: diff --git a/src/maxtext/configs/types.py 
b/src/maxtext/configs/types.py index cde89b0c7e..799a125736 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1809,6 +1809,7 @@ class MaxTextConfig( # Reinforcement Learning RLHardware, VLLM, + RL, RLDataset, RLEvaluation, Reward,