Skip to content

Commit c1b210d

Browse files
szaher and RobotSail authored
fix(torchrun): use correct datatypes for torchrun args (Red-Hat-AI-Innovation-Team#44)
* fix(torchrun): use correct datatypes for torchrun args

  Torchrun supports nproc_per_node and rdzv_id as str. TorchrunArgs only supports int, which is permissible by pytorch. This change will enable TorchrunArgs to support both str, int. Also, remove unset or empty parameters before passing it to torchrun args.

  Signed-off-by: Saad Zaher <szaher@redhat.com>

* Use python3.11 style for pydantic model

  Signed-off-by: Saad Zaher <szaher@redhat.com>

* replace - with _ for cli args

  Signed-off-by: Saad Zaher <szaher@redhat.com>

* make nproc_per_node to only accept gpu or int. Remove Defaults

  Signed-off-by: Saad Zaher <szaher@redhat.com>

* add master_{addr, port} validate args

  Signed-off-by: Saad Zaher <szaher@redhat.com>

* deep check if variables are set and not empty

  Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>

* Update src/mini_trainer/training_types.py

  Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>

* Update src/mini_trainer/api_train.py

  Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>

* does not automatically set --master-port

* Update api_train.py

* use standalone when neither rdzv_endpoint nor master_addr are provided

* Update training_types.py

* update tests

---------

Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 parent 37881b5 commit c1b210d

3 files changed

Lines changed: 48 additions & 15 deletions

File tree

src/mini_trainer/api_train.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,34 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
8585

8686
# Build torchrun command
8787
train_script = Path(__file__).parent / "train.py"
88-
88+
8989
command = [
9090
"torchrun",
9191
f"--nnodes={torch_args.nnodes}",
92-
f"--node_rank={torch_args.node_rank}",
93-
f"--nproc_per_node={torch_args.nproc_per_node}",
94-
f"--rdzv_id={torch_args.rdzv_id}",
95-
f"--rdzv_endpoint={torch_args.rdzv_endpoint}",
92+
f"--node-rank={torch_args.node_rank}",
93+
f"--nproc-per-node={torch_args.nproc_per_node}",
94+
f"--rdzv-id={torch_args.rdzv_id}",
95+
]
96+
97+
if torch_args.master_addr and torch_args.rdzv_endpoint:
98+
raise ValueError("Provide either `rdzv_endpoint` OR `master_addr`, not both.")
99+
100+
if torch_args.master_addr:
101+
# master-addr + master-port are only compatible with the static backend
102+
# so here we pass it explicitly
103+
command += [
104+
f"--master-addr={torch_args.master_addr}",
105+
"--rdzv-backend=static"
106+
]
107+
if torch_args.master_port:
108+
command += [f"--master-port={torch_args.master_port}"]
109+
110+
elif torch_args.rdzv_endpoint:
111+
command += [f"--rdzv-endpoint={torch_args.rdzv_endpoint}"]
112+
else:
113+
command += ["--standalone"]
114+
115+
command.extend([
96116
str(train_script),
97117
f"--model-name-or-path={train_args.model_name_or_path}",
98118
f"--data-path={train_args.data_path}",
@@ -109,7 +129,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
109129
f"--max-steps={train_args.max_steps}",
110130
f"--max-tokens={train_args.max_tokens}",
111131
f"--train-dtype={train_args.train_dtype}",
112-
]
132+
])
113133

114134

115135
# wandb-related arguments

src/mini_trainer/training_types.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dataclasses import dataclass, field
99
from enum import Enum
10-
from typing import Optional, Dict, Any
10+
from typing import Optional, Dict, Any, Literal
1111

1212

1313
class TrainingMode(str, Enum):
@@ -22,11 +22,24 @@ class TrainingMode(str, Enum):
2222
class TorchrunArgs:
2323
"""Arguments for torchrun distributed training configuration."""
2424
nnodes: int = 1
25-
nproc_per_node: int = 1
25+
nproc_per_node: Literal["gpu"] | int = 1
2626
node_rank: int = 0
27-
rdzv_id: int = 123
28-
rdzv_endpoint: str = "127.0.0.1:1738"
27+
rdzv_id: str | int = 123
2928

29+
# Optional rendezvous / master fields
30+
rdzv_endpoint: Optional[str] = None
31+
master_addr: Optional[str] = None
32+
master_port: Optional[int] = None
33+
34+
def __post_init__(self):
35+
# in order to support systems which are still relying on `master_addr`
36+
# to construct the rendezvous address, torchrun must not be given a non-empty value
37+
# for rdzv_endpoint:
38+
# https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/distributed/run.py#L799
39+
if self.rdzv_endpoint and self.master_addr:
40+
raise ValueError(
41+
"Provide either `rdzv_endpoint` OR both `master_addr` and `master_port`, not both."
42+
)
3043

3144
@dataclass
3245
class TrainingArgs:

tests/test_api_train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_torchrun_args_defaults(self):
2828
assert args.nproc_per_node == 1
2929
assert args.node_rank == 0
3030
assert args.rdzv_id == 123
31-
assert args.rdzv_endpoint == "127.0.0.1:1738"
31+
assert args.rdzv_endpoint == None
3232

3333
# Test with custom nproc_per_node only
3434
args = TorchrunArgs(nproc_per_node=8)
@@ -405,10 +405,10 @@ def test_run_training_command_construction(self, mock_popen_class):
405405
# Verify command structure
406406
assert command[0] == "torchrun"
407407
assert "--nnodes=2" in command
408-
assert "--node_rank=1" in command
409-
assert "--nproc_per_node=4" in command
410-
assert "--rdzv_id=999" in command
411-
assert "--rdzv_endpoint=master:1234" in command
408+
assert "--node-rank=1" in command
409+
assert "--nproc-per-node=4" in command
410+
assert "--rdzv-id=999" in command
411+
assert "--rdzv-endpoint=master:1234" in command
412412

413413
# Verify training arguments
414414
assert "--model-name-or-path=my-model" in command

0 commit comments

Comments (0)