diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md
index 91024344c..d3d286cbe 100644
--- a/docs/tutorials/posttraining/rl.md
+++ b/docs/tutorials/posttraining/rl.md
@@ -127,8 +127,16 @@
 export HF_TOKEN=
 export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory
 export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+
+export CHIPS_PER_VM= # depends on hardware: 4 for v5p, 8 for v5e/v6e
 ```
 
+For the value of `CHIPS_PER_VM` on different TPU hardware, refer to the official documentation:
+
+- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, `chips_per_vm=8`)
+- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, `chips_per_vm=4`)
+- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, `chips_per_vm=8`)
+
 ## Get your model checkpoint
 
 ### Option 1: Using an existing MaxText checkpoint
@@ -159,7 +167,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   load_parameters_path=${MAXTEXT_CKPT_PATH} \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-  hf_access_token=${HF_TOKEN}
+  hf_access_token=${HF_TOKEN} \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows:
@@ -183,7 +192,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
   hf_access_token=${HF_TOKEN} \
-  loss_algo=gspo-token
+  loss_algo=gspo-token \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows:
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index cde89b0c7..799a12573 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1809,6 +1809,7 @@ class MaxTextConfig(
     # Reinforcement Learning
     RLHardware,
     VLLM,
+    RL,
     RLDataset,
     RLEvaluation,
     Reward,
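
Not part of the patch itself — a minimal shell sketch of how `CHIPS_PER_VM` could be derived from the TPU accelerator type instead of being set by hand. The `ACCELERATOR_TYPE` variable (e.g. `v5litepod-16`, `v5p-8`, `v6e-8`) is an assumption for illustration; the chip counts mirror the table added to `rl.md` above.

```bash
# Hypothetical helper, not part of this patch: derive CHIPS_PER_VM from the
# TPU accelerator type rather than setting it by hand. ACCELERATOR_TYPE is
# assumed to look like "v5litepod-16", "v5p-8", or "v6e-8"; the generation
# prefix before the first dash selects the chip count documented above.
case "${ACCELERATOR_TYPE%%-*}" in
  v5litepod) export CHIPS_PER_VM=8 ;;  # TPU v5e hosts expose 8 chips per VM
  v5p)       export CHIPS_PER_VM=4 ;;  # TPU v5p hosts expose 4 chips per VM
  v6e)       export CHIPS_PER_VM=8 ;;  # TPU v6e hosts expose 8 chips per VM
  *)         echo "Unrecognized ACCELERATOR_TYPE: ${ACCELERATOR_TYPE}" >&2 ;;
esac
```

With `CHIPS_PER_VM` exported this way, both `train_rl` invocations in the patch pick it up unchanged via `chips_per_vm=${CHIPS_PER_VM}`.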