Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion bindings/python/src/smg/router_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,17 @@ def add_cli_args(
"Control Plane Authentication", "API key and JWT/OIDC authentication"
)

if use_router_prefix:
parser.add_argument(
"--router-disable-arg-fallback",
action="store_true",
default=False,
help=(
"When set, only use explicitly provided --router-* arguments and do not"
" fall back to backend arguments with the same name."
),
)

# Worker configuration
if not exclude_host_port:
worker_group.add_argument(
Expand Down Expand Up @@ -1107,6 +1118,7 @@ def from_cli_args(cls, args: argparse.Namespace, use_router_prefix: bool = False
prefix = "router_" if use_router_prefix else ""
cli_args_dict = vars(args)
args_dict = {}
disable_arg_fallback = bool(cli_args_dict.get(f"{prefix}disable_arg_fallback", False))

for attr in dataclasses.fields(cls):
# Auto strip prefix from args.
Expand All @@ -1117,7 +1129,11 @@ def from_cli_args(cls, args: argparse.Namespace, use_router_prefix: bool = False
prefixed_key = f"{prefix}{attr.name}"
if prefixed_key in cli_args_dict and cli_args_dict[prefixed_key] is not None:
args_dict[attr.name] = cli_args_dict[prefixed_key]
elif attr.name in cli_args_dict and cli_args_dict[attr.name] not in (None, ""):
elif (
not disable_arg_fallback
and attr.name in cli_args_dict
and cli_args_dict[attr.name] not in (None, "")
Comment on lines +1133 to +1135
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep host/port fallback when router fallback is disabled

Disabling fallback here also drops host/port inheritance even though smg serve/launch_server register router args with exclude_host_port=True (bindings/python/src/smg/serve.py lines 540/554 and bindings/python/src/smg/launch_server.py line 161), so the router has no --router-host/--router-port source. With --router-disable-arg-fallback, RouterArgs.from_cli_args no longer reads unprefixed --host/--port, and the router silently reverts to dataclass defaults (0.0.0.0:30000) instead of the serve endpoint requested via CLI, which can bind the gateway to the wrong address/port.

Useful? React with 👍 / 👎.

):
Comment on lines +1132 to +1136
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve serve host/port fallback when disabling arg fallback

With --router-disable-arg-fallback, this branch blocks all unprefixed values, but serve builds router args with use_router_prefix=True and exclude_host_port=True, so router host/port can only come from unprefixed --host/--port. In that mode, enabling the new flag makes the router ignore the configured serve endpoint and silently fall back to RouterArgs defaults (0.0.0.0:30000), which can bind the gateway to the wrong address/port. Keep host/port fallback (or provide prefixed host/port args) when this flag is set.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

args_dict[attr.name] = cli_args_dict[attr.name]

# Special handling for CLI args with dashes vs dataclass fields with underscores
Expand Down
46 changes: 46 additions & 0 deletions bindings/python/tests/test_arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,52 @@ def test_from_cli_args_without_prefix(self):
assert router_args.policy == "random"
assert router_args.pd_disaggregation is False

def test_prefixed_args_fall_back_to_backend_args_by_default(self):
    """Unset router-prefixed args inherit the unprefixed values when fallback is on."""
    # Key order mirrors the order an argparse namespace would expose via vars().
    namespace_fields = {
        "router_model_path": None,
        "router_disable_arg_fallback": False,
        "model_path": "backend/model",
        "router_tokenizer_path": None,
        "tokenizer_path": "backend/tokenizer",
        "router_worker_urls": [],
        "worker_urls": [],
        "router_prefill": None,
        "router_decode": None,
        "router_selector": None,
        "router_prefill_selector": None,
        "router_decode_selector": None,
        "router_router_selector": None,
    }
    cli_namespace = SimpleNamespace(**namespace_fields)

    parsed = RouterArgs.from_cli_args(cli_namespace, use_router_prefix=True)

    # Both prefixed values are None, so the unprefixed backend values are expected.
    assert parsed.model_path == "backend/model"
    assert parsed.tokenizer_path == "backend/tokenizer"

def test_prefixed_args_can_disable_backend_fallback(self):
    """With router_disable_arg_fallback set, unprefixed values must not be inherited."""
    # Key order mirrors the order an argparse namespace would expose via vars().
    namespace_fields = {
        "router_model_path": None,
        "router_disable_arg_fallback": True,
        "model_path": "backend/model",
        "router_tokenizer_path": None,
        "tokenizer_path": "backend/tokenizer",
        "router_worker_urls": [],
        "worker_urls": [],
        "router_prefill": None,
        "router_decode": None,
        "router_selector": None,
        "router_prefill_selector": None,
        "router_decode_selector": None,
        "router_router_selector": None,
    }
    cli_namespace = SimpleNamespace(**namespace_fields)

    parsed = RouterArgs.from_cli_args(cli_namespace, use_router_prefix=True)

    # The unprefixed model/tokenizer paths are present but must be ignored.
    assert parsed.model_path is None
    assert parsed.tokenizer_path is None


class TestPolicyFromStr:
"""Test policy string to enum conversion."""
Expand Down
11 changes: 11 additions & 0 deletions bindings/python/tests/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ def test_trtllm_includes_router_args(self):
)
assert args.router_policy == "round_robin"

def test_trtllm_accepts_router_disable_arg_fallback_flag(self):
    """The --router-disable-arg-fallback switch parses and lands on the namespace as True."""
    argv = ["--backend", "trtllm", "--router-disable-arg-fallback"]

    # Only the middle element of the returned triple is inspected here.
    _first, namespace, _third = parse_serve_args(argv)

    assert namespace.router_disable_arg_fallback is True

def test_trtllm_router_args_defaults(self):
"""Router args should have sensible defaults."""
_, args, _ = parse_serve_args(["--backend", "trtllm"])
Expand Down
Loading