diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 4158b2e8c..5df173b93 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -38,15 +38,26 @@ def get_miles_extra_args_provider(add_custom_arguments=None): def add_miles_arguments(parser): # Ray def add_cluster_arguments(parser): - parser.add_argument("--actor-num-nodes", type=int, default=1, help="Number of nodes for training actor") parser.add_argument( - "--actor-num-gpus-per-node", type=int, default=8, help="Number of gpus per node for training actor" + "--actor-num-nodes", type=int, default=1, help="Number of nodes for training the Actor." ) parser.add_argument( - "--critic-num-nodes", type=int, default=None, help="Number of nodes for training actor" + "--actor-num-gpus-per-node", + type=int, + default=8, + help="Number of GPUs per node for training the Actor.", ) parser.add_argument( - "--critic-num-gpus-per-node", type=int, default=None, help="Number of gpus per node for training actor" + "--critic-num-nodes", + type=int, + default=None, + help="Number of nodes for the Critic. Defaults to `--actor-num-nodes`.", + ) + parser.add_argument( + "--critic-num-gpus-per-node", + type=int, + default=None, + help="Number of GPUs per node for the Critic. Defaults to `--actor-num-gpus-per-node`.", ) parser.add_argument( @@ -54,24 +65,28 @@ def add_cluster_arguments(parser): type=int, default=None, help=( - "Number of GPUs for inference. Note that when using --colocate, " - "i.e. the training and the inference engines are on the same gpus, this param will be ignored and will be set as " - "actor_num_gpus_per_node * actor_num_nodes." + "Total number of GPUs required for rollout (inference). " + "In `--colocate` mode, this is ignored and set to `actor-num-gpus-per-node * actor-num-nodes` (plus critic GPUs if enabled)." ), ) parser.add_argument( "--rollout-num-gpus-per-engine", type=int, default=1, - help="Number of GPUs per inference engine, just like the tp_size in sglang.", + help=( + "Number of GPUs per inference engine, same as `tp_size` in SGLang. " + "For multi-node serving, this should be the total GPU count / `tp_size` for each SGLang instance." + ), ) parser.add_argument( "--num-gpus-per-node", type=int, default=8, help=( - "Number of gpus per node for rollout." - "Notice: If you are going to use less than 8 gpus per node under colocate mode, you should set this number." + "Total GPUs per node on the physical machine. " + "This informs the Ray scheduler of the hardware capacity. " + "In **Colocate mode**, it is required that the machine has fewer than 8 GPUs to calculate correct VRAM offsets. " + "In **Disaggregated mode**, it ensures SGLang engines are distributed correctly across nodes without exceeding per-node GPU limits." ), ) parser.add_argument( @@ -79,31 +94,30 @@ def add_cluster_arguments(parser): action="store_true", default=False, help=( - "Whether to colocate the inference engines and the actor. " - "Turning this on will also set --offload to true." + "Deploy training and rollout on the same GPUs. " + "This mode automatically enables `--offload-train` and `--offload-rollout` to facilitate weight-swapping between the training actor and inference engine. " + "**Note:** The offload parameters are currently only used for AMD GPUs and will be removed soon. " + "**Memory Tip:** When colocating, it is highly recommended to set `--sglang-mem-fraction-static` to **0.8** (especially on **NVIDIA Blackwell B200/B300** GPUs). " + "This leaves sufficient VRAM (~20%%) for Megatron to initialize its structures before the first weight offload to CPU occurs. " + "On GB200/GB300, values up to 0.75 are safer for long-running jobs to prevent potential OOMs. " + "#TODO: Verify optimal fraction for Blackwell in production" ), ) parser.add_argument( "--offload", action="store_true", default=False, - help=("Equivalent to --offload-train + --offload-rollout. "), + help="Equivalent to --offload-train + --offload-rollout. ", ) parser.add_argument( "--offload-train", action=argparse.BooleanOptionalAction, - help=( - "Whether to offload the training actor to CPU during training. " - "This will always be true when --colocate is set." - ), + help="Whether to offload the training actor to CPU during training. This will always be true when --colocate is set.", ) parser.add_argument( "--offload-rollout", action=argparse.BooleanOptionalAction, - help=( - "Whether to offload the rollout generator to CPU during training. " - "This will always be true when --colocate is set." - ), + help="Whether to offload the rollout generator to CPU during training. This will always be true when --colocate is set.", ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") @@ -117,32 +131,40 @@ def add_train_arguments(parser): type=str, choices=["megatron", "fsdp"], default="megatron", - help="The backend for training.", + help="The backend for training. Highly suggest Megatron for numerical stability and efficiency.", ) parser.add_argument( "--qkv-format", type=str, choices=["thd", "bshd"], default="thd", - help="The qkv layout.", + help=( + "Whether to pack variable-length sequences into the token dimension for training. " + "`thd` (T-H-D, a.k.a. varlen / packed sequence) concatenates sequences and uses `cu_seqlens` to avoid padding; it is the default and is usually faster by reducing padding overhead. " + "`bshd` (B-S-H-D) uses fixed-shape padded batches; use it for newer models with novel attention architectures (e.g., sparse attention, attention sink) where the training backend does not support `thd`." + ), ) parser.add_argument( "--true-on-policy-mode", action="store_true", default=False, - help="Whether to enable true-on-policy mode.", + help=( + "Strictly align SGLang's log probs and training engine's log probs to bit-wise equal. " + "This parameter is only used for FSDP right now. " + "[Ref](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md#truly-on-policy-training)" + ), ) parser.add_argument( "--train-env-vars", type=json.loads, default="{}", - help="Extra environment variables for training process, e.g. PyTorch memory management ones.", + help="Extra environment variables for training process, e.g., PyTorch memory management ones.", ) parser.add_argument( "--train-memory-margin-bytes", type=int, default=1024**3, - help="Add margin for train memory allocation. By default we will reserve 1GB as margin.", + help="Reserved memory margin for training in bytes. Defaults to 1GB.", ) parser.add_argument( "--disable-weights-backuper", @@ -150,36 +172,41 @@ def add_train_arguments(parser): dest="enable_weights_backuper", help=( "Applies to `megatron` training backend only. " - "Disables the system that backups model weights (Actor, Ref, Old Actor) to CPU RAM. " - "Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing KL-divergence against a reference model. " - "Note: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper." + "Disables the system that backs up model weights (Actor, Ref, Old Actor) to CPU RAM. " + "Disabling saves significant host memory but prevents features that rely on weight-swapping, such as computing the KL-divergence against a reference model. " + "**Note**: do not set `--ref-load` and `--keep-old-actor` if disable weights backuper." ), ) parser.add_argument( "--megatron-to-hf-mode", choices=["raw", "bridge"], default="raw", - help="The method to convert megatron weights to hugging face weights for SGLang.", + help="Method to convert Megatron weights to HuggingFace format for SGLang integration.", ) parser.add_argument( "--custom-model-provider-path", type=str, default=None, help=( - "Path to a custom model provider function. " - "If set, we will use this function instead of the default model provider. " - "The function should have the signature " - "`def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel`. " - "Example: 'my_module.my_model_provider'." + "Path to a custom function that replaces the default model provider. " + "[Ref](../get_started/customization.md#20-model-provider---custom-model-provider-path)" ), ) parser.add_argument( "--recompute-loss-function", action="store_true", - help="Whether to enable recompute loss function to save memory during training.", + help="Enable recomputing the loss function to save memory during training.", ) parser.add_argument( - "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" + "--log-probs-chunk-size", + type=int, + default=-1, + help=( + "Specifies the chunk size for logprobs computation to reduce peak memory usage. " + "Processing logits in smaller batches, it prevents CUDA OOM errors during long-context prefilling or re-computation. " + "Set to `-1` to disable chunking. " + "[Ref](https://github.com/sgl-project/sglang/pull/6318)" + ), ) return parser @@ -190,22 +217,16 @@ def add_rollout_arguments(parser): "--hf-checkpoint", type=str, default=None, - help=( - "The huggingface checkpoint of the trained model. " - "This is used to initialize sglang and also provide the tokenizer. " - "Note that, we will always update the parameters in sglang with that of megatron before training, " - "so you only need to provide a huggingface checkpoint that has the same architecture as the model you want to train. " - "It doesn't necessary need to contain the most up-to-date parameters." - ), + help="Path to the HuggingFace checkpoint used to initialize SGLang and provide the tokenizer.", ) parser.add_argument( "--model-name", type=str, default=None, help=( - "The name of the model, this is used to convert the megatron weights into huggingface format. " - "If not set, we will use `type(AutoConfig.from_pretrained(args.hf_checkpoint)).__name__.lower()` as model_name. " - "Also, sometimes this will help alleviate the bug that transformers cannot find certain model." + "The name of the model that is used to convert the Megatron weights into HuggingFace format. " + "If not set, we will use `type(AutoConfig.from_pretrained(args.hf_checkpoint)).__name__.lower()` as `model_name`. " + "Providing this argument can also help in cases where transformers cannot find certain models." ), ) parser.add_argument( @@ -217,34 +238,34 @@ def add_rollout_arguments(parser): else "miles.rollout.sglang_rollout.generate_rollout" ), help=( - "Path to the rollout generation function." - "You should use this model to create your own custom rollout function, " - "and then set this to the path of your custom rollout function. " - "The signature of the function should be " - "`def generate_rollout(args, rollout_id, *, evaluation=False) -> list[list[Sample]]`" - "and within the output sample, you should at least set `tokens`, `response_length`, `reward` " - "and `truncated`." + "Path to the rollout generation function. " + "Use this to inject custom logic (e.g., for multi-turn or tool use). " + "[Ref](../get_started/customization.md#1-rollout-function---rollout-function-path)" ), ) parser.add_argument( "--rollout-temperature", type=float, default=1.0, - help="the temperature for the inference engine during rollout.", + help="Sampling temperature for the inference engine during rollout.", ) parser.add_argument( - "--rollout-top-p", type=float, default=1.0, help="the top-p for the inference engine during rollout." + "--rollout-top-p", type=float, default=1.0, help="Top-p (nucleus) sampling threshold during rollout." ) parser.add_argument( - "--rollout-top-k", type=int, default=-1, help="the top-k for the inference engine during rollout." + "--rollout-top-k", + type=int, + default=-1, + help="Top-k sampling threshold during rollout. `-1` means disabled.", ) parser.add_argument( "--rollout-max-context-len", type=int, default=None, help=( - "The maximum context size for the inference engine during rollout." - "It should no exceed the `max_position_embeddinds` in Huggingface model's `config.json`" + "The maximum context size for the inference engine during rollout. " + "It should not exceed the `max_position_embeddings` in the HuggingFace model's `config.json`. " + "**Note:** This acts as a hard cap for the total tokens (Prompt + Response)." ), ) parser.add_argument( @@ -252,9 +273,10 @@ def add_rollout_arguments(parser): type=int, default=None, help=( - "The maximum length of the prompt for the inference engine during rollout. " - "If set, we will filter out the long prompts during initialization of the global dataset. " - "This is not recommended if the dataset is large." + "Maximum length of the prompt. " + "Longer prompts are filtered during dataset initialization. " + "This is not recommended if the dataset is large. " + "**Note:** Defaults to `rollout-max-context-len - 1` if not set, ensuring at least one token can be generated." ), ) parser.add_argument( @@ -262,8 +284,8 @@ def add_rollout_arguments(parser): type=int, default=None, help=( - "The maximum length of the response for the inference engine during rollout. " - "It is basically `max_tokens` in sglang." + "Maximum length of the response (`max_tokens` in SGLang). " + "**Note:** Generation will stop when either this limit is reached or the total session length hits `rollout-max-context-len`." ), ) parser.add_argument( @@ -271,8 +293,8 @@ def add_rollout_arguments(parser): action="store_true", default=False, help=( - "Whether to skip special tokens in the response during rollout. " - "This is useful when you want to use the response as a prompt for the next rollout." + "Skip special tokens (e.g., `<\\|im_end\\|>`, `<\\|endoftext\\|>`) in the decoded response string. " + "**Critical for Multi-Turn RL:** Ensures that when a response is appended to the conversation history for the next turn, it doesn't include terminal special tokens that would interfere with chat template formatting or cause early termination in subsequent turns." ), ) parser.add_argument( @@ -280,11 +302,7 @@ def add_rollout_arguments(parser): type=str, nargs="+", default=None, - help=( - "The stop words for the inference engine during rollout. " - "It can be a list of strings or a single string. " - "It may be hard to pass special tokens in command line, in that case rollout_stop_token_ids can be used." - ), + help='A list of strings that trigger termination of generation if they appear in the output (e.g., `"\\nUser:"`).', ) parser.add_argument( "--rollout-stop-token-ids", @@ -292,24 +310,21 @@ def add_rollout_arguments(parser): nargs="+", default=None, help=( - "The stop token ids for the inference engine during rollout. " - "It can be a list of integers or a single integer." + "A list of numerical token IDs that trigger termination. " + "This is the token-level equivalent of `--rollout-stop` and is preferred for special control tokens that are difficult to input as strings." ), ) parser.add_argument( "--rollout-shuffle", action="store_true", default=False, - help=("Whether to shuffle the prompts during rollout."), + help="Shuffle the prompts during rollout.", ) parser.add_argument( "--rollout-seed", type=int, default=42, - help=( - "The seed for the random number generator during rollout. " - "This is used to shuffle the prompts and also for the random sampling of the prompts." - ), + help="Seed for the random number generator during rollout (used for shuffling and sampling).", ) # sampling @@ -318,12 +333,11 @@ def add_rollout_arguments(parser): type=int, default=None, help=( - "This defines the granularity of the sampling batch in the rollout function. " - "When the number of available samples falls below the target, a sampling " - "operation of size over_sampling_batch_size will be triggered." - "Regardless of whether partial rollout is used or filters are applied, " - "the sampling granularity is always determined by this value. " - "If this value is None, rollout_batch_size will be used as the default over_sampling_batch_size." + "Number of prompts requested in each **oversampling** round when **dynamic sampling** is enabled. " + "Miles samples `over_sampling_batch_size` prompts, generates `--n-samples-per-prompt` responses per prompt asynchronously, and then keeps/discards each prompt group via `--dynamic-sampling-filter-path`. " + "If filtering is strict and the remaining accepted batch size drops below the target `--rollout-batch-size`, Miles automatically triggers another oversampling round of the same size. " + "If unset, defaults to `--rollout-batch-size`. " + "See [Dynamic Sampling](../get_started/quick_start.md#dynamic-sampling)." ), ) parser.add_argument( @@ -331,10 +345,8 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "This is the filter function for dynamic sampling. " - "It should be able to judge whether the result of a prompt should be selected or not." - "We will do dynamic filter for sampling as in DAPO. e.g. not all correct or all wrong samples." - "You could use `miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std` as an example." + "Path to the filter function for dynamic sampling. " + "[Ref](../get_started/customization.md#4-dynamic-sampling-filter---dynamic-sampling-filter-path)" ), ) @@ -344,9 +356,9 @@ def add_rollout_arguments(parser): action="store_true", default=False, help=( - "Whether to use partial rollout. " - "If set, the unfinished samples during dynamic sampling will be recycled back to data buffer. " - "This is useful for long responses." + "Enable partial rollout for **dynamic sampling**: cache partially generated (aborted/unfinished) samples and resume generation in later rollout steps, reducing wasted compute for long responses. " + "Cached samples are stored in the rollout buffer and can be prioritized/selected via `--buffer-filter-path` (default FIFO behavior). " + "See [Partial Rollout](../get_started/quick_start.md#partial-rollout)." ), ) parser.add_argument( @@ -354,8 +366,9 @@ def add_rollout_arguments(parser): action="store_true", default=False, help=( - "Whether to mask previous generation in partial rollout. " - "If set, only on-policy generated tokens will be used in training" + "When using partial rollout, mask the previously generated (cached) response tokens so they do not contribute to the loss; only tokens generated after resuming are used for training. " + "This helps avoid training on a cached prefix produced by an older policy version. " + "See [Partial Rollout](../get_started/quick_start.md#partial-rollout)." ), ) parser.add_argument( @@ -363,8 +376,9 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "Only substitue the `def generate(args, sample, sampling_params)` function within the example rollout function. " - "This should be useful if you need to implement some special rollout logic, e.g. multi-turn, function calling." + "Path to override only the `generate` step within the default rollout function. " + "If your custom `generate` returns `list[Sample]` (multi-sample), make sure your rollout pipeline can handle it; the default rollout expects a flat `list[Sample]` of length `--n-samples-per-prompt` for each prompt group. " + "[Ref](../get_started/customization.md#2-custom-generate-function---custom-generate-function-path)" ), ) parser.add_argument( @@ -372,9 +386,8 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "The custom function for logging rollout data. The signature of the functions is: " - "def log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time) -> bool. " - "The return value indicates whether to skip the default logging. " + "Path to a custom function for logging training rollout data. " + "[Ref](../get_started/customization.md#14-logging-functions)" ), ) parser.add_argument( @@ -382,9 +395,8 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "The custom function for logging eval rollout data. " - "def log_eval_rollout_data(rollout_id, args, data, extra_metrics) -> bool. " - "The return value indicates whether to skip the default logging. " + "Path to a custom function for logging evaluation rollout data. " + "[Ref](../get_started/customization.md#14-logging-functions)" ), ) @@ -393,9 +405,8 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "Path to the buffer filter function. " - "It should be able to select the samples in the buffer. " - "The function should take list[list[Sample]] and return list[list[Sample]]." + "Path to the function to filter or sort samples in the rollout buffer before training. " + "[Ref](../get_started/customization.md#5-buffer-filter---buffer-filter-path)" ), ) # update weight @@ -404,20 +415,24 @@ def add_rollout_arguments(parser): type=int, default=512 * 1024**2, help=( - "buffer size for update weight, in bytes. " - "This is used for updating weights by chunk and should be useful for MoE models." + "Buffer size for updating weights, in bytes. " + "[Ref](https://hebiao064.github.io/rl-weight-sync#42-optimizing-sglang-server-calls-with-tensor-bucketing-from-50s-to-30s)" ), ) parser.add_argument( "--update-weights-interval", type=int, default=1, - help="Interval for updating the weights", + help="Interval (in rollout rounds) for syncing weights to inference engines. Set to `>1` for async RL.", ) parser.add_argument( "--keep-old-actor", action="store_true", - help="Whether to keep the rollout model on training process", + help=( + 'Maintains a "Model Queue" (Actor, Rollout Actor, Old Actor) to ensure importance sampling ratios are calculated against the exact policy version that generated the data. ' + "Essential for asynchronous RL where training and inference are decoupled, preventing mathematical incorrectness due to model staleness. " + "It consumes additional Host Memory (extra ~1x model size for `update_weights_interval > 1` or 2x for `update_weights_interval == 1`) depending on update interval." + ), ) parser.add_argument( @@ -425,8 +440,8 @@ def add_rollout_arguments(parser): type=str, default=None, help=( - "The called after we have all the rollout data including log_probs. " - "It may be helpful for updating loss mask." + "Path to a function called after all rollout data (including log probs) is ready. " + "[Ref](../get_started/customization.md#8-rollout-data-postprocess---rollout-data-postprocess-path)" ), ) parser.add_argument( @@ -440,7 +455,7 @@ def add_rollout_arguments(parser): type=str, default=None, nargs="+", - help="Address and ports of the external engines.", + help="Addresses and ports of the external engines.", ) return parser @@ -449,25 +464,29 @@ def add_fault_tolerance_arguments(parser): "--use-fault-tolerance", action="store_true", default=False, - help="Whether to enable the fault tolerance function during rollout.", + help="Enable fault tolerance for rollout engines. Periodically sends `/health_generate` heartbeats.", ) parser.add_argument( "--rollout-health-check-interval", type=float, default=30.0, - help="Interval in seconds between rollout engine /health_generate checks during generate/eval.", + help="Interval in seconds between rollout engine `/health_generate` checks during generate/eval.", ) parser.add_argument( "--rollout-health-check-timeout", type=float, default=30.0, - help="Timeout in seconds to wait for a rollout engine /health_generate response before killing it.", + help="Timeout in seconds to wait for a rollout engine `/health_generate` response before killing it.", ) parser.add_argument( "--rollout-health-check-first-wait", type=float, default=0, - help="Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm.", + help=( + "Initial grace period (in seconds) before starting health checks. " + "This allows time for model compilation and initialization. " + "Increase this value significantly when using deepgemm." + ), ) return parser @@ -479,7 +498,11 @@ def add_data_arguments(parser): "--num-rollout", type=int, default=None, - help="Number of rollout steps. If not set, we will calculate the number of rollout steps from the dataset size.", + help=( + "Number of rollout steps. " + "If not set, Miles will calculate the number of rollout steps from the dataset size. " + "**Note:** This value will be overwritten if `--num-epoch` is also set." + ), ) parser.add_argument( "--num-epoch", @@ -487,9 +510,8 @@ def add_data_arguments(parser): default=None, help=( "Number of epochs for the training. " - "This is used to calculate the number of rollout steps from the dataset size. " - "If set, we will calculate the number of rollout steps as `num_rollout = num_epoch * dataset_size // rollout_batch_size`." - "If both `--num-epoch` and `--num-rollout` are set, `--num-epoch` will be ignored." + "If set, `num_rollout` is calculated as `(num_epoch * dataset_size) // rollout_batch_size`. " + "**Note:** This argument takes precedence and will overwrite `--num-rollout` if both are specified." ), ) @@ -498,8 +520,10 @@ def add_data_arguments(parser): action="store_false", dest="rollout_global_dataset", help=( - "Disable the global dataset for rollout. By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. " - "Setting this flag turns off this behavior, Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently." + "Disable the global dataset for rollout. " + "By default, Miles loads `--prompt-data` into a global dataset and samples from it for rollout. " + "Setting this flag turns off this behavior. " + "Use this flag only when providing a custom `--rollout-function-path` (and usually a custom `--data-source-path`) that handles data loading independently." ), ) @@ -507,41 +531,65 @@ def add_data_arguments(parser): "--data-source-path", type=str, default="miles.rollout.data_source.RolloutDataSourceWithBuffer", - help="The data source class for rollout data.", + help=( + "Path to a custom Python class for the rollout data source. " + "[Ref](../get_started/customization.md#15-data-source---data-source-path)" + ), ) parser.add_argument( "--prompt-data", type=str, default=None, help=( - "The path to the prompt data. " - "Currently we only support jsonl format, and each line should contains --input-key and --label-key, " - "which will be used as the prompt and the label respectively." - "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " - "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. " + "Path to the prompt dataset (JSONL format), and each line should contain `--input-key` and `--label-key`, " + "which will be used as the prompt and the label, respectively." + ), + ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + default=False, + help=( + "Whether to apply the chat template to the input prompt. " + "The input should be the same structure as an OpenAI message, e.g., `[{'role': 'user', 'content': 'blabla'}]`." ), ) - parser.add_argument("--apply-chat-template", action="store_true", default=False) # Temporarily be JSON-serialized str, will be a real dict after using Omegaconf - parser.add_argument("--apply-chat-template-kwargs", type=json.loads, default="{}") - parser.add_argument("--input-key", type=str, default="input", help="JSON dataset key") - parser.add_argument("--label-key", type=str, default=None, help="JSON dataset key") + parser.add_argument( + "--apply-chat-template-kwargs", + type=json.loads, + default="{}", + help="Extra arguments for the chat template processing (JSON string).", + ) + parser.add_argument( + "--input-key", + type=str, + default="input", + help="Key in the JSONL data representing the user input/prompt.", + ) + parser.add_argument( + "--label-key", + type=str, + default=None, + help="Key in the JSONL data representing the label/ground truth.", + ) parser.add_argument( "--multimodal-keys", type=json.loads, default=None, - help=( - 'JSON string for multimodal data mapping media types to data keys. Example: \'{"image": "image_file"}\'' - ), + help='JSON string for multimodal data mapping media types to data keys. Example: `\'{"image": "image_file"}\'`', + ) + parser.add_argument( + "--metadata-key", + type=str, + default="metadata", + help="When adding tools during `apply_chat_template`, provide the key for the tools to the prompt dataset.", ) - parser.add_argument("--metadata-key", type=str, default="metadata", help="JSON dataset key") parser.add_argument( "--tool-key", type=str, default="tools", - help=( - "When need to add tools during apply_chat_template, you should provide the key for the tools in the prompt dataset." - ), + help="JSON key for tool definitions in the prompt dataset (used when applying chat templates).", ) parser.add_argument( @@ -549,8 +597,9 @@ def add_data_arguments(parser): type=int, default=None, help=( - "The starting rollout step, if not set, will try to load the step from --load when doing continue training, " - "otherwise will be set to 0, meaning training from start." + "The starting rollout step. " + "If not set, it is inferred from the --load checkpoint when resuming training. " + "Otherwise, if training is not continuous, Miles will start training from scratch" ), ) @@ -559,13 +608,16 @@ def add_data_arguments(parser): "--rollout-batch-size", type=int, required=True, - help=( - "The number of prompts in each rollout step. " - "The total data returned should be rollout_batch_size * n_samples_per_prompt. " - ), + help="Number of prompts per rollout batch. The total data returned should be `rollout_batch_size` * `n_samples_per_prompt`.", ) parser.add_argument( - "--n-samples-per-prompt", type=int, default=1, help="Number of responses for each prompt in generation" + "--n-samples-per-prompt", + type=int, + default=1, + help=( + "Number of responses to generate for each prompt, e.g., the group size of GRPO. " + "The default rollout pipeline expects each prompt group to contain exactly `n_samples_per_prompt` samples." + ), ) # gbs of the training, note that the gbs is of sample, not of prompts, @@ -577,8 +629,11 @@ def add_data_arguments(parser): type=int, default=None, help=( - "Number of steps per rollout, e.g. It is equivalent to setting gbs as " - "`rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`." + "The number of training steps to perform using the data collected in a single rollout round. " + "Setting this to `n` means the policy model will be updated `n` times using the same batch of rollout data. " + "Miles ensures that `(rollout-batch-size * n-samples-per-prompt) = (global-batch-size * num-steps-per-rollout)`. " + "If this value is not provided, you have to set `--global-batch-size` explicitly. " + "If both are provided, `--num-steps-per-rollout` will **override** the global batch size with `num_steps_per_rollout = (rollout_batch_size * n_samples_per_prompt) // num_steps_per_rollout`." ), ) # mbs for the training, will be ignored if `use_dynamic_batch_size` is set. @@ -588,8 +643,8 @@ def add_data_arguments(parser): action="store_true", default=False, help=( - "Repartition each rollout batch so each data-parallel rank gets a similar total token count via Karmarkar-Karp method. " - "It may be beneficial for training speed but changes per-rank sample grouping and adds a small CPU scheduling overhead." + "Repartition each rollout batch so each data-parallel rank gets a similar total token count via the Karmarkar-Karp method. " + "It may be beneficial for training speed, but changes per-rank sample grouping and adds a small CPU scheduling overhead." ), ) @@ -598,10 +653,11 @@ def add_data_arguments(parser): action="store_true", default=False, help=( - "Because the sample length varies, to maximize the GPU utilization, " - "we will use the dynamic batch size to adjust the micro batch size according to the maximum number of tokens each gpu can run. " - "For example, if we have 3 samples, with the length of 100, 200, and 300, and the max_tokens_per_gpu is 300, when enabling " - "dynamic batch size, miles will make 2 micro batches, i.e. [100, 200], [300]." + "Dynamically packs variable-length samples into micro-batches to maximize GPU utilization, ensuring the total token count per batch does not exceed `--max-tokens-per-gpu`. " + "For example, with a 300-token limit, samples of lengths 100, 200, and 300 would be packed into two batches: `[100, 200]` and `[300]`. " + "**Note:** Miles ensures that enabling this optimization does not affect the mathematical correctness of per-sample or per-token loss calculation. " + "It is **strongly recommended** to enable this for maximum efficiency. " + "**Compatibility:** only supported when `--qkv-format` is `thd` (does not work for `bshd`)." ), ) parser.add_argument( @@ -609,9 +665,9 @@ def add_data_arguments(parser): type=int, default=None, help=( - "The maximum number of tokens per GPU for dynamic batch size. " - "Note that when enabling context parallel (CP), the max tokens per gpu should be around " - "`max_response_len // cp_size` instead of `max_response_len`." + "The maximum number of tokens (Prompt + Response combined) per GPU for dynamic batch size. " + "This parameter defines the total sequence length budget for packing samples into micro-batches during training. " + "Note that when enabling context parallel (CP), the effective capacity is shared, so the value should be approximately `(Total_Sequence_Length) // cp_size`." ), ) parser.add_argument( @@ -620,8 +676,7 @@ def add_data_arguments(parser): default=None, help=( "The maximum number of tokens per GPU for calculating log probs. " - "This is used to calculate the log probs of the responses during rollout, " - "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. " + "This is used to calculate the log probs of the responses during rollout, and should be set to a larger value than `max_tokens_per_gpu` if you want better performance." ), ) return parser @@ -632,8 +687,8 @@ def add_eval_arguments(parser): type=str, default=None, help=( - "Path to the eval generation function." - "If not set, we will use rollout_function_path as the default. " + "Path to a custom evaluation function. " + "[Ref](../get_started/customization.md#16-evaluation-function---eval-function-path)" ), ) @@ -645,45 +700,79 @@ def add_eval_arguments(parser): type=str, default=None, nargs="+", - help=( - "Path to the evaluation prompt data, " - "should first input the name of the eval dataset and then the path, e.g. " - "aime /path/to/aime.jsonl" - ), + help="List of name and path pairs for evaluation datasets (e.g., `aime /path/to/aime.jsonl`).", ) parser.add_argument( "--eval-config", type=str, default=None, - help=( - "Path to an OmegaConf YAML/JSON file describing evaluation datasets. " - "When provided, this overrides --eval-prompt-data." - ), + help="Path to an OmegaConf YAML/JSON file describing evaluation datasets (overrides `--eval-prompt-data`).", ) parser.add_argument( "--skip-eval-before-train", action="store_true", default=False, - help="Whether to skip evaluation before training.", + help="Skip the evaluation step before training starts.", ) # The following keys are used to override the rollout version during eval. - parser.add_argument("--eval-input-key", type=str, default=None, help="JSON dataset key") - parser.add_argument("--eval-label-key", type=str, default=None, help="JSON dataset key") - parser.add_argument("--eval-tool-key", type=str, default=None, help="JSON dataset key") + parser.add_argument( + "--eval-input-key", type=str, default=None, help="JSON key for input text in evaluation datasets." + ) + parser.add_argument( + "--eval-label-key", + type=str, + default=None, + help="JSON key for ground truth labels in evaluation datasets.", + ) + parser.add_argument( + "--eval-tool-key", type=str, default=None, help="JSON key for tool definitions in evaluation datasets." + ) parser.add_argument( "--n-samples-per-eval-prompt", type=int, default=1, - help="number of responses for each prompt in generation", + help="Number of responses for each prompt in generation.", + ) + parser.add_argument( + "--eval-temperature", + type=float, + default=None, + help="Temperature for evaluation (defaults to rollout temperature if not set).", + ) + parser.add_argument( + "--eval-top-p", + type=float, + default=None, + help="Top-p sampling threshold for evaluation (defaults to rollout top-p if not set).", + ) + parser.add_argument( + "--eval-top-k", + type=int, + default=None, + help="Top-k sampling threshold for evaluation (defaults to rollout top-k if not set).", + ) + parser.add_argument( + "--eval-max-response-len", + type=int, + default=None, + help="Maximum response length for evaluation (defaults to rollout max response length if not set).", + ) + parser.add_argument( + "--eval-max-prompt-len", type=int, default=None, help="Maximum prompt length for evaluation." + ) + parser.add_argument( + "--eval-min-new-tokens", + type=int, + default=None, + help="Minimum tokens to generate for evaluation responses (Not used).", + ) + parser.add_argument( + "--eval-max-context-len", + type=int, + default=None, + help="Maximum context length for evaluation (defaults to rollout max context length if not set).", ) - parser.add_argument("--eval-temperature", type=float, default=None) - parser.add_argument("--eval-top-p", type=float, default=None) - parser.add_argument("--eval-top-k", type=int, default=None) - parser.add_argument("--eval-max-response-len", type=int, default=None) - parser.add_argument("--eval-max-prompt-len", type=int, default=None) - parser.add_argument("--eval-min-new-tokens", type=int, default=None) - parser.add_argument("--eval-max-context-len", type=int, default=None) return parser @@ -692,13 +781,10 @@ def add_algo_arguments(parser): "--ref-load", type=str, default=None, - help=( - "The checkpoint for reference model. " - "When --load is not set, this will be used as the initial checkpoint for training. " - ), + help="Path to the reference model checkpoint. Used as an initial checkpoint if `--load` is not set.", ) parser.add_argument( - "--ref-ckpt-step", type=int, default=None, help="The checkpoint step for reference model. " + "--ref-ckpt-step", type=int, default=None, help="The checkpoint step for the reference model." ) reset_arg(parser, "--load", type=str, default=None) reset_arg(parser, "--save", type=str, default=None) @@ -709,10 +795,7 @@ def add_algo_arguments(parser): "--no-save-optim", action="store_true", default=False, - help=( - "If set, do not save the optimizer state when saving checkpoints. " - "This reduces checkpoint size but disables training resumption from the saved checkpoint." - ), + help="If set, optimizer state is not saved with checkpoints to reduce size, but prevents resumption of training.", ) parser.add_argument( "--save-hf", @@ -720,7 +803,7 @@ def add_algo_arguments(parser): default=None, help=( "Path to save the model in HuggingFace format when using Megatron backend. " - "The model will be saved to `save_hf.format(rollout_id)`. " + "The model will be saved to `save_hf.format(rollout_id)`." ), ) reset_arg(parser, "--seed", type=int, default=1234) @@ -728,49 +811,63 @@ def add_algo_arguments(parser): reset_arg(parser, "--calculate-per-token-loss", action="store_true") reset_arg(parser, "--lr", type=float, default=1e-6) - parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps") - parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.") - parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.") - parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model") + parser.add_argument( + "--num-critic-only-steps", + type=int, + default=0, + help="Number of initial steps dedicated to training only the Critic.", + ) + parser.add_argument( + "--critic-load", type=str, default=None, help="Checkpoint to load for the critic model." + ) + parser.add_argument("--critic-save", type=str, default=None, help="Path to save the critic model.") + parser.add_argument( + "--critic-lr", type=float, default=None, help="Learning rate for the Critic. Defaults to `--lr`." + ) parser.add_argument( "--critic-lr-warmup-iters", type=int, default=0, - help="number of iterations to linearly warmup for critic model.", + help="Number of iterations for Critic learning rate linear warmup.", ) - parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range") - parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range") + parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range.") + parser.add_argument( + "--eps-clip-high", + type=float, + default=None, + help="PPO clip upper range (defaults to `--eps-clip` if not set).", + ) parser.add_argument( "--eps-clip-c", type=float, default=None, - help="lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729", + help="Lower bound for [Dual-clip PPO](https://arxiv.org/pdf/1912.09729).", ) - parser.add_argument("--value-clip", type=float, default=0.2, help="the clip for value loss") + parser.add_argument("--value-clip", type=float, default=0.2, help="Clip range for value loss.") parser.add_argument( "--kl-coef", type=float, default=0.00, - help="KL penalty coefficient for reward shaping. This is applied to the reward signal before advantage calculation.", + help=( + "KL penalty coefficient for reward shaping. " + "This is applied to the reward signal before advantage calculation for PPO and REINFORCE-style estimator." + ), ) parser.add_argument( "--loss-type", type=str, choices=["policy_loss", "sft_loss", "custom_loss"], default="policy_loss", - help=( - "Choose loss type, currently support ppo policy_loss or sft_loss, " - "if custom_loss is set, we will use the function path from `--custom-loss-function-path`." - ), + help="Type of loss function to use.", ) parser.add_argument( "--custom-loss-function-path", type=str, default=None, help=( - "Path to the custom loss function, if the loss_type is `custom_loss`, " - "we will use this function to calculate the loss. " + "Path to a custom loss calculation function (requires `--loss-type custom_loss`). " + "[Ref](../get_started/customization.md#9-custom-loss-function---custom-loss-function-path)" ), ) parser.add_argument( @@ -778,7 +875,10 @@ def add_algo_arguments(parser): type=str, choices=["k1", "k2", "k3", "low_var_kl"], default="k1", - help="Choose KL loss type: kl, k2, k3, low_var_kl", + help=( + "Selection of the KL loss implementation. " + "See [Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for more details." + ), ) parser.add_argument( "--advantage-estimator", @@ -792,76 +892,103 @@ def add_algo_arguments(parser): "on_policy_distillation", ], default="grpo", + help="Advantage estimator to use.", ) parser.add_argument( "--disable-compute-advantages-and-returns", action="store_false", dest="compute_advantages_and_returns", help=( - "Whether to disable computing advantages and returns. " - "If set, we will not compute the advantages and returns, " - "This is useful for sft or custom loss function." + "Disables the calculation of advantages and returns. " + "This is typically used for SFT or custom loss functions where value estimation is not required." ), ) parser.add_argument( - "--use-kl-loss", action="store_true", default=False, help="whether to use KL loss from GRPO" + "--use-kl-loss", + action="store_true", + default=False, + help="Enable KL loss term in the final objective (as in GRPO).", ) parser.add_argument( "--kl-loss-coef", type=float, default=0.0, - help="KL penalty coefficient for the loss function. This is added to the final PPO loss.", + help="Weight of the KL loss term in the final objective.", ) parser.add_argument( "--use-unbiased-kl", action="store_true", default=False, - help="Whether to enable unbiased KL estimation.", + help="Apply Importance Sampling (IS) correction to the KL estimator. Reduces bias from distribution shift.", ) parser.add_argument( "--ref-update-interval", type=int, default=None, - help="Interval (in rollout steps) to update ref model from actor. If None, ref model is not updated.", + help="Interval (in rollout steps) to update ref model from actor. If `None`, ref model is not updated.", + ) + parser.add_argument( + "--entropy-coef", + type=float, + default=0.0, + help=( + "Coefficient for entropy regularization term. " + "Penalizes low entropy to encourage exploration and prevent premature convergence." + ), + ) + parser.add_argument( + "--gamma", + type=float, + default=1.0, + help="Discount factor for future rewards. Used in PPO (GAE) and REINFORCE++.", + ) + parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambda.") + parser.add_argument( + "--normalize-advantages", + action="store_true", + default=False, + help=( + "Performs distributed masked whitening of advantages. " + "Normalization statistics are computed globally across the Data-Parallel group, ignoring padding tokens." + ), ) - parser.add_argument("--entropy-coef", type=float, default=0.0, help="Entropy loss coef") - parser.add_argument("--gamma", type=float, default=1.0, help="PPO GAE gamma") - parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") - parser.add_argument("--normalize-advantages", action="store_true", default=False) parser.add_argument( "--disable-grpo-std-normalization", action="store_false", dest="grpo_std_normalization", - help="from Dr.GRPO https://arxiv.org/pdf/2503.20783", + help="Disable standard deviation normalization for GRPO. From [Dr.GRPO](https://arxiv.org/pdf/2503.20783)", ) parser.add_argument( "--disable-rewards-normalization", action="store_false", dest="rewards_normalization", - help="Disable rewards normalization", + help=( + "Disable the default group-wise reward normalization for GRPO, GSPO, and REINFORCE++. " + "This effectively skips the baseline subtraction step." + ), ) parser.add_argument( "--use-rollout-entropy", action="store_true", default=False, help=( - "Whether to calculate the entropy when calculating the logprobs from actor and reference model. " - "This is useful for doing special loss mask." + "Enable entropy calculation when calculating the logprobs from actor and reference model. " + "This is useful for implementing custom entropy-based loss masking." ), ) parser.add_argument( "--get-mismatch-metrics", action="store_true", default=False, - help="Whether to calculate the mismatch metrics.", + help="Calculate mismatch metrics. If it is set, you need to provide a custom TIS function via `--custom-tis-function-path`.", ) parser.add_argument( "--reset-optimizer-states", action="store_true", default=False, help=( - "Whether to reset optimizer states after each rollout. " - "If enabled, the optimizer's history will be cleared at the end of each rollout, which can sometimes help with training stability or fulfill specific experiment requirements." + "Resets the optimizer state after each rollout round. " + "This clears the optimization history, which can improve stability or satisfy specific experimental requirements." ), ) parser.add_argument( @@ -869,8 +996,9 @@ def add_algo_arguments(parser): action="store_true", default=False, help=( - "Whether to use the rollout logprobs when calculating the importance sampling ratios. " - "If not set, we will use the logprobs from the actor model." + "Use rollout logprobs as the old-policy logprobs when computing importance sampling ratios / PPO-style KL in GRPO/GSPO/PPO. " + "If not set, Miles recomputes old-policy logprobs with the training actor (e.g., `old_actor` or `actor`, depending on configuration). " + "If `--get-mismatch-metrics` is set, the log probs will still be recomputed by the training engine (one more forward pass will be applied)." ), ) # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl @@ -878,7 +1006,10 @@ def add_algo_arguments(parser): "--use-tis", action="store_true", default=False, - help="Enable TIS from https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33.", + help=( + "Enable Token-level Importance Sampling (TIS) from this " + "[blog](https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33)." + ), ) parser.add_argument( "--tis-clip", @@ -896,32 +1027,49 @@ def add_algo_arguments(parser): "--custom-tis-function-path", type=str, default=None, - help="Path to the custom TIS/RS function (e.g., examples/train_infer_mismatch_helper/mis.py:compute_mis_weights_with_cp).", + help=( + "Path to a custom TIS or MIS function. " + "[Ref](../get_started/customization.md#10-custom-tisrs-function---custom-tis-function-path)" + ), ) parser.add_argument( "--custom-pg-loss-reducer-function-path", type=str, default=None, - help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", + help=( + "Custom reducer function for policy gradient loss. " + "[Ref](../get_started/customization.md#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path)" + ), ) parser.add_argument( "--use-routing-replay", action="store_true", default=False, - help="The routing replay technique from https://arxiv.org/abs/2507.18071", + help=( + "Enable R2 (Routing Replay) for MoE: record expert routing decisions during forward and replay them during backward. " + "[Paper](https://arxiv.org/abs/2507.18071) **Note:** automatically set to `True` when `--use-rollout-routing-replay` is enabled." + ), ) parser.add_argument( "--use-rollout-routing-replay", action="store_true", default=False, - help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370", + help=( + "Enable R3 (Rollout Routing Replay) for MoE: record expert routing decisions during rollout and replay them during training. " + "**Requires `--use-miles-router`**. " + "[Paper](https://arxiv.org/abs/2510.11370) [Ref](miles-router.md#22-rollout-routing-replay-r3-for-moe)" + ), ) parser.add_argument( "--use-opsm", action="store_true", default=False, - help="Whether to enable Off-Policy Sequence Masking (OPSM).", + help=( + "Enable Off-Policy Sequence Masking (OPSM). " + "Filters sequences that have **BOTH** negative advantages (bad results) AND high KL divergence (stale data). " + "This stabilizes training by preventing updates from unreliable, highly off-policy samples." + ), ) parser.add_argument( "--opsm-delta", @@ -936,19 +1084,27 @@ def add_router_arguments(parser): "--use-miles-router", action="store_true", default=False, - help="Whether to use MilesRouter for text-based routing instead of SGLang token-based routing", + help=( + "Use Miles Router (FastAPI passthrough proxy) instead of SGLang Model Gateway for rollout routing. " + "Required for features that depend on preserving extra rollout metadata (e.g., R3). " + "[Ref](miles-router.md)" + ), ) parser.add_argument( "--miles-router-middleware-paths", type=str, nargs="+", default="", + help=( + "Paths to custom MilesRouter middleware functions. " + "[Ref](../get_started/customization.md#18-miles-router-middleware---miles-router-middleware-paths)" + ), ) parser.add_argument( "--miles-router-timeout", type=float, default=None, - help="Timeout for MilesRouter HTTP requests in seconds.", + help="Timeout for router HTTP requests in seconds.", ) parser.add_argument( "--miles-router-max-connections", @@ -979,24 +1135,24 @@ def add_router_arguments(parser): # wandb def add_wandb_arguments(parser): # wandb parameters - parser.add_argument("--use-wandb", action="store_true", default=False) + parser.add_argument("--use-wandb", action="store_true", default=False, help="Enable WandB logging.") parser.add_argument( "--wandb-mode", type=str, default=None, choices=["online", "offline", "disabled"], - help="W&B mode: online (default), offline (local only), or disabled. Overrides WANDB_MODE env var.", + help="WandB operating mode. Overrides `WANDB_MODE`.", ) parser.add_argument( "--wandb-dir", type=str, default=None, - help="Directory to store wandb logs. Default is ./wandb in current directory.", + help="Directory to store WandB logs. Default is `./wandb` in current directory.", ) - parser.add_argument("--wandb-key", type=str, default=None) - parser.add_argument("--wandb-host", type=str, default=None) - parser.add_argument("--wandb-team", type=str, default=None) - parser.add_argument("--wandb-group", type=str, default=None) + parser.add_argument("--wandb-key", type=str, default=None, help="WandB API key.") + parser.add_argument("--wandb-host", type=str, default=None, help="WandB host address.") + parser.add_argument("--wandb-team", type=str, default=None, help="WandB team name.") + parser.add_argument("--wandb-group", type=str, default=None, help="WandB group name.") reset_arg(parser, "--wandb-project", type=str, default=None) parser.add_argument( "--disable-wandb-random-suffix", @@ -1004,7 +1160,7 @@ def add_wandb_arguments(parser): dest="wandb_random_suffix", default=True, help=( - "Whether to add a random suffix to the wandb run name. " + "Disable adding a random suffix to the WandB run name. " "By default, we will add a random 6 length string with characters to the run name." ), ) @@ -1012,31 +1168,27 @@ def add_wandb_arguments(parser): "--wandb-always-use-train-step", action="store_true", default=False, - help=( - "Whether to always use train step as the step metric in wandb. " - "If set, we will always use the train steps for wandb logging, " - "otherwise, will use rollout step for most info other than train/*. " - ), + help="Use training steps instead of rollout steps for the x-axis.", ) parser.add_argument( "--log-multi-turn", action="store_true", default=False, - help="Whether to log information for multi-turn rollout.", + help="Log detailed information for multi-turn conversations.", ) parser.add_argument( "--log-passrate", action="store_true", default=False, - help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", + help="Enable logging of `pass@n` metrics.", ) parser.add_argument( "--log-reward-category", type=str, default=None, help=( - "Log statistics of the category of reward, such as why the reward function considers it as failed. " - "Specify the key in the reward dict using this argument.", + "Log reward-category statistics (e.g., why the reward function marked a failure). " + "Use this argument to specify the key in the reward dict." ), ) parser.add_argument( @@ -1045,20 +1197,22 @@ def add_wandb_arguments(parser): default=False, help="Explicitly log metrics for correct samples.", ) - parser.add_argument("--wandb-run-id", type=str, default=None) + parser.add_argument("--wandb-run-id", type=str, default=None, help="Specific WandB run ID to resume.") return parser # tensorboard def add_tensorboard_arguments(parser): # tb_project_name, tb_experiment_name - parser.add_argument("--use-tensorboard", action="store_true", default=False) + parser.add_argument( + "--use-tensorboard", action="store_true", default=False, help="Enable Tensorboard logging." + ) parser.add_argument( "--tb-project-name", type=str, default=None, - help="Directory to store tensorboard logs. Default is os.environ.get('TENSORBOARD_DIR') directory.", + help="Tensorboard project directory.", ) - parser.add_argument("--tb-experiment-name", type=str, default=None) + parser.add_argument("--tb-experiment-name", type=str, default=None, help="Tensorboard experiment name.") return parser @@ -1068,70 +1222,53 @@ def add_debug_arguments(parser): "--save-debug-rollout-data", type=str, default=None, - help=( - "Save the rollout data to this path for debugging. " - "The file will be saved to `save_debug_rollout_data.format(rollout_id)`." - ), + help="Path to save rollout data for offline analysis. [Ref](../developer_guide/debug.md)", ) parser.add_argument( "--load-debug-rollout-data", type=str, default=None, - help=( - "Load the rollout data from this path for debugging. " - "The file will be loaded from `load_debug_rollout_data.format(rollout_id)`. " - "When this is enabled, miles will not instantiate sglang servers." - ), + help="Path to load debug rollout data (bypasses SGLang). [Ref](../developer_guide/debug.md)", ) parser.add_argument( "--load-debug-rollout-data-subsample", type=float, default=None, - help="Subsample a portion of the debug rollout data for faster debugging.", + help="Percentage of debug data to load (0.0 to 1.0). [Ref](../developer_guide/debug.md)", ) parser.add_argument( "--debug-rollout-only", action="store_true", default=False, - help=( - "Whether to only run the rollout generation without training. " - "This is useful for debugging the rollout generation function." - ), + help="Run the rollout phase only without training. [Ref](../developer_guide/debug.md)", ) parser.add_argument( "--debug-train-only", action="store_true", default=False, - help=( - "Whether to only run the training without sglang servers. " - "This is useful for debugging the rollout generation function." - ), + help="Run the training phase only without launching SGLang servers. [Ref](../developer_guide/debug.md)", ) parser.add_argument( "--save-debug-train-data", type=str, default=None, - help=( - "Save the train data to this path for debugging. " - "The file will be saved to `save_debug_train_data.format(rollout_id)`." - ), + help="Path to save training batches for offline math debugging.", ) parser.add_argument( "--dump-details", type=str, default=None, - help=("Dump all details of training for post-hoc analysis and visualization."), + help="Dump exhaustive training details for post-hoc visualization.", ) # use together with --record-memory-history and --memory-snapshot-path (defined in Megatron) parser.add_argument( - "--memory-snapshot-dir", - type=str, - default=".", + "--memory-snapshot-dir", type=str, default=".", help="Directory for PyTorch memory snapshots." ) parser.add_argument( "--memory-snapshot-num-steps", type=int, default=None, + help="Number of steps to record before saving snapshot.", ) parser.add_argument( "--profile-target", @@ -1139,19 +1276,32 @@ def add_debug_arguments(parser): choices=["train_overall", "train_actor", "train_log_probs"], default=["train_overall"], nargs="+", + help="Training components to profile (accepts multiple).", ) parser.add_argument( "--memory-recorder", type=str, choices=["torch", "memray"], default="torch", + help="Selection of the memory recording backend.", + ) + parser.add_argument( + "--check-weight-update-equal", + action="store_true", + help="Use SGLang's weight checker to check and ensure that the loaded weight from HF checkpoint and received from Megatron are bit-wise equal.", ) - parser.add_argument("--check-weight-update-equal", action="store_true") return parser def add_network_arguments(parser): - parser.add_argument("--http-proxy", type=str, default=None) - parser.add_argument("--use-distributed-post", action="store_true", default=False) + parser.add_argument( + "--http-proxy", type=str, default=None, help="HTTP proxy server for remote reward model calls." + ) + parser.add_argument( + "--use-distributed-post", + action="store_true", + default=False, + help="Use distributed POST requests for remote reward models.", + ) return parser def add_reward_model_arguments(parser): @@ -1159,40 +1309,43 @@ def add_reward_model_arguments(parser): "--rm-type", type=str, default=None, - help="Type of the reward model", + help="Built-in reward model selection.", ) parser.add_argument( "--reward-key", type=str, default=None, - help=( - "Some reward model may return a dict instead of a value, " - "this is the key to extract the reward value from the dict. " - ), + help="JSON key to extract the numerical reward from a returned dictionary if reward model returns a dict instead of a value.", ) parser.add_argument( "--eval-reward-key", type=str, default=None, - help="The eval variant for --reward-key", + help="Evaluation variant for `--reward-key`.", ) parser.add_argument( - "--group-rm", action="store_true", default=False, help="Whether to do rm on a whole group." + "--group-rm", + action="store_true", + default=False, + help=( + "Defer reward computation to process the entire group of samples (per-prompt) at once. " + "Essential for comparative/ranking reward models and improves throughput. " + "**Not supported in eval**." + ), ) parser.add_argument( "--rm-url", type=str, default=None, - help="URL for the reward model service for --rm-type remote_rm, e.g. http://localhost:8000", + help="URL for the reward model service (used with `--rm-type remote_rm`).", ) parser.add_argument( "--custom-rm-path", type=str, default=None, help=( - "Path to the custom reward model function. " - "If set, we will use this function to calculate the reward instead of the default one. " - "The function should have the signature `def custom_rm(args, sample) -> float`." + "Path to a custom Python reward function. " + "[Ref](../get_started/customization.md#3-reward-model---custom-rm-path)" ), ) parser.add_argument( @@ -1200,7 +1353,8 @@ def add_reward_model_arguments(parser): type=str, default=None, help=( - "Path to the custom function that will post process reward, by default it will be the normalization for grpo. " + "Path to a custom reward post-processor. " + "[Ref](../get_started/customization.md#12-reward-post-processing---custom-reward-post-process-path)" ), ) parser.add_argument( @@ -1208,9 +1362,8 @@ def add_reward_model_arguments(parser): type=str, default=None, help=( - "Path to a custom function that converts samples to training data. " - "If set, this function will replace the default _convert_samples_to_train_data. " - "The function should have the signature `def convert_samples_to_train_data(args, samples) -> dict`." + "Path to a custom data format converter. " + "[Ref](../get_started/customization.md#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path)" ), ) return parser @@ -1220,49 +1373,47 @@ def add_rollout_buffer_arguments(parser): "--rollout-buffer-url", type=str, default=None, - help="URL for the rollout buffer", + help="URL for the rollout buffer service.", ) parser.add_argument( "--fetch-trajectory-retry-times", type=int, default=-1, - help="Number of times to retry fetching trajectory, -1 means unlimited retry", + help="Number of times to retry fetching trajectory, -1 means unlimited retry.", ) parser.add_argument( "--min-batch-collection-ratio", type=float, default=1, - help="Minimum batch collection ratio", - ) - parser.add_argument( - "--rollout-task-type", - type=str, - default="math", + help="Minimum batch collection ratio before proceeding.", ) + parser.add_argument("--rollout-task-type", type=str, default="math", help="Type of task being performed.") parser.add_argument( "--loss-mask-type", type=str, default="qwen", choices=["qwen", "qwen3", "distill_qwen"], - help="Loss mask type", + help="Selection of the token masking logic.", ) parser.add_argument( "--data-pad-size-multiplier", type=int, default=128, - help="Multiplier for data padding size in data processing.", + help=( + "Multiplier used to calculate the sequence padding boundary. " + "Miles rounds sequence lengths up to a multiple of `tensor_parallel_size * data_pad_size_multiplier`. " + "This optimization ensures that matrix dimensions are aligned with NVIDIA Tensor Core requirements, maximizing throughput and reducing VRAM fragmentation. " + "**Note:** better not change this; values `<128` may trigger accuracy loss under `--qkv-format thd` when `TP>=4`." + ), ) parser.add_argument( "--rollout-sample-filter-path", type=str, default=None, help=( - "Path to the rollout sample filter function. " - "This function determines whether a sample will participate in loss calculation. " - "The function should take args and samples (list[Sample]) as input, and return None. " - "Please directly modify the remove_sample attribute of Sample. " - "Note: This attribute does not determine whether the sample participates in advantage normalization." + "Path to the function that marks individual samples to be excluded from loss calculation. " + "[Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path)" ), ) parser.add_argument( @@ -1270,21 +1421,21 @@ def add_rollout_buffer_arguments(parser): type=str, default=None, help=( - "Path to the rollout all samples process function that " - "can process all samples including filtered ones." + "Path to the function to process all samples (including filtered ones) after rollout. " + "[Ref](../get_started/customization.md#7-rollout-all-samples-process---rollout-all-samples-process-path)" ), ) parser.add_argument( "--disable-rollout-trim-samples", action="store_true", default=False, - help="disable trim samples in rollout buffer when converting samples to train data", + help="Disable trim samples in rollout buffer when converting samples to train data.", ) parser.add_argument( "--use-dynamic-global-batch-size", action="store_true", default=False, - help="enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data", + help="Enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data.", ) return parser @@ -1298,16 +1449,28 @@ def add_custom_megatron_plugins_arguments(parser): "--custom-megatron-init-path", type=str, default=None, + help=( + "Path to custom Megatron initialization logic. " + "[Ref](../get_started/customization.md#17-megatron-hooks)" + ), ) parser.add_argument( "--custom-megatron-before-log-prob-hook-path", type=str, default=None, + help=( + "Hook called before calculating log probabilities. " + "[Ref](../get_started/customization.md#17-megatron-hooks)" + ), ) parser.add_argument( "--custom-megatron-before-train-step-hook-path", type=str, default=None, + help=( + "Hook called before each training step. " + "[Ref](../get_started/customization.md#17-megatron-hooks)" + ), ) return parser @@ -1319,7 +1482,7 @@ def add_mtp_training_arguments(parser): "--enable-mtp-training", action="store_true", default=False, - help="Enable MTP layer parameter updates during training", + help="Enable MTP layer parameter updates during training.", ) return parser @@ -1329,38 +1492,29 @@ def add_prefill_decode_disaggregation_arguments(parser): "--prefill-num-servers", type=int, default=None, - help="Number of prefill servers for disaggregation.", + help="Number of dedicated prefill servers for PD disaggregation.", ) return parser def add_ci_arguments(parser): + parser.add_argument("--ci-test", action="store_true", help="Enable Continuous Integration testing mode.") parser.add_argument( - "--ci-test", - action="store_true", + "--ci-disable-kl-checker", action="store_true", help="Disable KL divergence sanity checks in CI." ) parser.add_argument( - "--ci-disable-kl-checker", - action="store_true", - ) - parser.add_argument( - "--ci-metric-checker-key", - type=str, - default=None, + "--ci-metric-checker-key", type=str, default=None, help="Metric key to monitor for pass/fail in CI." ) parser.add_argument( "--ci-metric-checker-threshold", type=float, default=None, + help="Pass/fail threshold (minimum value) for the monitored metric.", ) parser.add_argument( - "--ci-save-grad-norm", - type=str, - default=None, + "--ci-save-grad-norm", type=str, default=None, help="Path to save gradient norms for CI comparison." ) parser.add_argument( - "--ci-load-grad-norm", - type=str, - default=None, + "--ci-load-grad-norm", type=str, default=None, help="Path to load gradient norms for CI verification." ) return parser