diff --git a/src/tyro/_backends/_tyro_backend.py b/src/tyro/_backends/_tyro_backend.py index 13407c5f..d8be7654 100644 --- a/src/tyro/_backends/_tyro_backend.py +++ b/src/tyro/_backends/_tyro_backend.py @@ -25,6 +25,10 @@ from ._argparse_formatter import TyroArgumentParser from ._base import ParserBackend +_FLAG_ACTIONS = frozenset( + {"store_true", "store_false", "boolean_optional_action", "count"} +) + class KwargMap: """Look-up table for tracking keyword arguments. Due to aliases, each @@ -78,17 +82,33 @@ def contains_normalized(self, token_key: str) -> bool: return normalized in self._arg_from_kwarg return False - def is_counter_flag(self, token: str) -> bool: - """Check if a token like -vvv is a repeated counter short flag.""" - if len(token) <= 2 or not token.startswith("-") or token.startswith("--"): - return False - short_key = token[:2] - arg = self._arg_from_kwarg.get(short_key) - return ( - arg is not None - and arg.lowered.action == "count" - and all(token[i] == token[1] for i in range(2, len(token))) - ) + def expand_short_cluster(self, token: str) -> list[str] | None: + """POSIX-style expansion of clustered short flags. + + ``-abc`` -> ``[-a, -b, -c]``; ``-nfoo`` -> ``[-n, foo]`` when ``-n`` + takes a value; ``-vvv`` -> ``[-v, -v, -v]`` (the count-action handler + increments once per character). Returns ``None`` if the token isn't + a short cluster or contains an unknown character. + + The caller must strip any trailing ``=value`` before calling, and + must already have ruled out an exact alias match (so explicit + multi-char shorts like ``-cail`` are preferred). + """ + if not (token.startswith("-") and len(token) > 2 and token[1] != "-"): + return None + expanded: list[str] = [] + for i, ch in enumerate(token[1:], start=1): + arg = self._arg_from_kwarg.get("-" + ch) + if arg is None: + return None + expanded.append("-" + ch) + if arg.lowered.action not in _FLAG_ACTIONS: + # Value-taking short: the rest of the token is its value. + rest = token[i + 1 :] + if rest: + expanded.append(rest) + break + return expanded def get_boolean_value(self, kwarg: str) -> bool | None: return self._value_from_boolean_flag.get(kwarg, None) @@ -449,16 +469,8 @@ def _recurse( maybe_flag_delimiter_swapped ) full_arg = kwarg_map.get_kwarg(maybe_flag_delimiter_swapped) - short_counter_arg = kwarg_map.get_kwarg(arg_value[:2]) - enforce_mutex_group(short_counter_arg, maybe_flag_delimiter_swapped) enforce_mutex_group(full_arg, maybe_flag_delimiter_swapped) - if kwarg_map.is_counter_flag(arg_value): - assert short_counter_arg is not None - dest = short_counter_arg.lowered.dest - output[dest] = cast(int, output[dest]) + len(arg_value) - 1 - args_to_pop.append(short_counter_arg) - continue - elif boolean_value is not None: + if boolean_value is not None: assert full_arg is not None output[full_arg.lowered.dest] = boolean_value args_to_pop.append(full_arg) @@ -530,6 +542,19 @@ def _recurse( ) continue + # POSIX-style short flag clustering (-abc -> -a -b -c). + # Tried only after exact-match lookups, so registered + # multi-char shorts like -cail still win. + flag_token = ( + arg_value if equals_value is None else arg_value.partition("=")[0] + ) + expanded = kwarg_map.expand_short_cluster(flag_token) + if expanded is not None: + if equals_value is not None: + expanded.append(equals_value) + args_deque.extendleft(reversed(expanded)) + continue + # Implicitly select default subcommands. if CascadeSubcommandArgs in parser_spec.markers: # Note: maybe_flag_delimiter_swapped already has the "=value" @@ -870,8 +895,6 @@ def _consume_argument( token_key = args_deque[0].partition("=")[0] if kwarg_map.contains_normalized(token_key): break - if kwarg_map.is_counter_flag(token_key): - break # To match argparse behavior, any flag-like string # terminates positional nargs consumption. We check for diff --git a/tests/test_py311_generated/test_short_flag_clustering_generated.py b/tests/test_py311_generated/test_short_flag_clustering_generated.py new file mode 100644 index 00000000..dbe1679c --- /dev/null +++ b/tests/test_py311_generated/test_short_flag_clustering_generated.py @@ -0,0 +1,249 @@ +"""Tests for POSIX-style short flag clustering (issue #465). + +POSIX commands let you combine single-letter options: ``-abc`` is equivalent +to ``-a -b -c``. If a flag in the cluster takes a value, the rest of the +token (after an optional ``=``) becomes that value: ``-nfoo`` -> ``-n foo``. +""" + +from __future__ import annotations + +import dataclasses +from typing import Annotated, Optional, Tuple + +import pytest + +import tyro +from tyro.conf import UseCounterAction, arg + + +@dataclasses.dataclass +class Flags: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + + +def test_cluster_all_bool() -> None: + out = tyro.cli(Flags, args=["-abc"]) + assert out == Flags(a=True, b=True, c=True) + + +def test_cluster_partial_bool() -> None: + assert tyro.cli(Flags, args=["-ab"]) == Flags(a=True, b=True, c=False) + assert tyro.cli(Flags, args=["-ac"]) == Flags(a=True, b=False, c=True) + assert tyro.cli(Flags, args=["-bc"]) == Flags(a=False, b=True, c=True) + + +def test_cluster_with_separate_short() -> None: + assert tyro.cli(Flags, args=["-ab", "-c"]) == Flags(a=True, b=True, c=True) + assert tyro.cli(Flags, args=["-a", "-bc"]) == Flags(a=True, b=True, c=True) + + +def test_cluster_unknown_char_raises() -> None: + with pytest.raises(SystemExit): + tyro.cli(Flags, args=["-abz"]) + + +def test_cluster_repeated_char_is_idempotent_for_bool() -> None: + # -aa -> -a -a (both store_true, second overwrites; net effect a=True) + assert tyro.cli(Flags, args=["-aa"]) == Flags(a=True) + + +def test_value_taking_short_glued() -> None: + @dataclasses.dataclass + class C: + n: Annotated[str, arg(aliases=["-n"])] = "default" + + assert tyro.cli(C, args=["-nfoo"]).n == "foo" + # -n=foo also works. + assert tyro.cli(C, args=["-n=foo"]).n == "foo" + # Spaced form still works. + assert tyro.cli(C, args=["-n", "foo"]).n == "foo" + + +def test_value_taking_short_at_end_of_cluster() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + n: Annotated[str, arg(aliases=["-n"])] = "default" + + # -abn foo -> -a -b -n foo + out = tyro.cli(C, args=["-abn", "foo"]) + assert out == C(a=True, b=True, n="foo") + # Glued: -abnfoo -> -a -b -n foo + out = tyro.cli(C, args=["-abnfoo"]) + assert out == C(a=True, b=True, n="foo") + # Equals: -abn=foo + out = tyro.cli(C, args=["-abn=foo"]) + assert out == C(a=True, b=True, n="foo") + + +def test_value_taking_short_in_middle_consumes_rest() -> None: + """If a value-taking short appears mid-cluster, the rest is its value + (POSIX semantics), even if subsequent characters look like other flags.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + n: Annotated[str, arg(aliases=["-n"])] = "default" + + # -anb -> -a -n b (b becomes the value of -n, not a flag). + out = tyro.cli(C, args=["-anb"]) + assert out == C(a=True, b=False, n="b") + + +def test_counter_short_cluster() -> None: + @dataclasses.dataclass + class C: + verbose: Annotated[int, arg(aliases=["-v"]), UseCounterAction] = 0 + + assert tyro.cli(C, args=["-vvv"]).verbose == 3 + assert tyro.cli(C, args=["-v", "-v"]).verbose == 2 + assert tyro.cli(C, args=[]).verbose == 0 + + +def test_counter_mixed_with_bool_cluster() -> None: + @dataclasses.dataclass + class C: + verbose: Annotated[int, arg(aliases=["-v"]), UseCounterAction] = 0 + a: Annotated[bool, arg(aliases=["-a"])] = False + + assert tyro.cli(C, args=["-va"]) == C(verbose=1, a=True) + assert tyro.cli(C, args=["-vva"]) == C(verbose=2, a=True) + assert tyro.cli(C, args=["-avv"]) == C(verbose=2, a=True) + assert tyro.cli(C, args=["-vav"]) == C(verbose=2, a=True) + + +def test_registered_multichar_short_takes_precedence() -> None: + """If ``-cail`` is explicitly registered as an alias, it must win over + cluster expansion of ``-c -a -i -l``.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + i: Annotated[bool, arg(aliases=["-i"])] = False + l: Annotated[bool, arg(aliases=["-l"])] = False + cail: Annotated[bool, arg(aliases=["-cail"])] = False + + out = tyro.cli(C, args=["-cail"]) + # The exact alias wins. + assert out.cail is True + assert out.a is False + assert out.c is False + + +def test_double_dash_long_flag_not_clustered() -> None: + """``--abc`` is a long flag; never expanded as a cluster.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + + with pytest.raises(SystemExit): + tyro.cli(C, args=["--abc"]) + + +def test_cluster_after_double_dash_marker_treated_as_positional() -> None: + """Tokens after the ``--`` end-of-options marker are not flags.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + rest: tyro.conf.Positional[Tuple[str, ...]] = () + + out = tyro.cli(C, args=["-a", "--", "-bc"]) + assert out.a is True + assert out.rest == ("-bc",) + + +def test_negative_number_not_cluster() -> None: + """Negative numbers must still be parseable as positional/value args.""" + + @dataclasses.dataclass + class C: + n: int = 0 + + assert tyro.cli(C, args=["--n", "-3"]).n == -3 + + +def test_cluster_with_value_taking_first() -> None: + """If the first short in a cluster takes a value, the entire rest is + its value (no further flag interpretation).""" + + @dataclasses.dataclass + class C: + n: Annotated[str, arg(aliases=["-n"])] = "x" + a: Annotated[bool, arg(aliases=["-a"])] = False + + # -nabc -> -n abc, NOT -n -a -b -c. + assert tyro.cli(C, args=["-nabc"]).n == "abc" + + +def test_cluster_int_value() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + n: Annotated[int, arg(aliases=["-n"])] = 0 + + out = tyro.cli(C, args=["-an42"]) + assert out == C(a=True, n=42) + + +def test_cluster_optional_value_taking() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + n: Annotated[Optional[str], arg(aliases=["-n"])] = None + + out = tyro.cli(C, args=["-an", "hi"]) + assert out == C(a=True, n="hi") + + +def test_lone_short_not_affected() -> None: + """Sanity check: lone ``-a`` still works.""" + + out = tyro.cli(Flags, args=["-a"]) + assert out == Flags(a=True) + + +def test_cluster_does_not_match_long_flag_chars() -> None: + """Cluster expansion must use only registered single-letter shorts, not + arbitrary characters from long flag names.""" + + @dataclasses.dataclass + class C: + apple: bool = False + banana: bool = False + + with pytest.raises(SystemExit): + tyro.cli(C, args=["-ab"]) + + +def test_cluster_with_unrelated_short() -> None: + """If only some chars are registered shorts, the cluster as a whole + fails (we don't partial-expand).""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + # -b NOT registered. + + with pytest.raises(SystemExit): + tyro.cli(C, args=["-ab"]) + + +def test_help_short_unaffected() -> None: + """``-h`` still triggers help and is not interpreted as a cluster.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + + with pytest.raises(SystemExit) as exc_info: + tyro.cli(C, args=["-h"]) + assert exc_info.value.code == 0 diff --git a/tests/test_short_flag_clustering.py b/tests/test_short_flag_clustering.py new file mode 100644 index 00000000..6c114fba --- /dev/null +++ b/tests/test_short_flag_clustering.py @@ -0,0 +1,250 @@ +"""Tests for POSIX-style short flag clustering (issue #465). + +POSIX commands let you combine single-letter options: ``-abc`` is equivalent +to ``-a -b -c``. If a flag in the cluster takes a value, the rest of the +token (after an optional ``=``) becomes that value: ``-nfoo`` -> ``-n foo``. +""" + +from __future__ import annotations + +import dataclasses +from typing import Optional, Tuple + +import pytest +from typing_extensions import Annotated + +import tyro +from tyro.conf import UseCounterAction, arg + + +@dataclasses.dataclass +class Flags: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + + +def test_cluster_all_bool() -> None: + out = tyro.cli(Flags, args=["-abc"]) + assert out == Flags(a=True, b=True, c=True) + + +def test_cluster_partial_bool() -> None: + assert tyro.cli(Flags, args=["-ab"]) == Flags(a=True, b=True, c=False) + assert tyro.cli(Flags, args=["-ac"]) == Flags(a=True, b=False, c=True) + assert tyro.cli(Flags, args=["-bc"]) == Flags(a=False, b=True, c=True) + + +def test_cluster_with_separate_short() -> None: + assert tyro.cli(Flags, args=["-ab", "-c"]) == Flags(a=True, b=True, c=True) + assert tyro.cli(Flags, args=["-a", "-bc"]) == Flags(a=True, b=True, c=True) + + +def test_cluster_unknown_char_raises() -> None: + with pytest.raises(SystemExit): + tyro.cli(Flags, args=["-abz"]) + + +def test_cluster_repeated_char_is_idempotent_for_bool() -> None: + # -aa -> -a -a (both store_true, second overwrites; net effect a=True) + assert tyro.cli(Flags, args=["-aa"]) == Flags(a=True) + + +def test_value_taking_short_glued() -> None: + @dataclasses.dataclass + class C: + n: Annotated[str, arg(aliases=["-n"])] = "default" + + assert tyro.cli(C, args=["-nfoo"]).n == "foo" + # -n=foo also works. + assert tyro.cli(C, args=["-n=foo"]).n == "foo" + # Spaced form still works. + assert tyro.cli(C, args=["-n", "foo"]).n == "foo" + + +def test_value_taking_short_at_end_of_cluster() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + n: Annotated[str, arg(aliases=["-n"])] = "default" + + # -abn foo -> -a -b -n foo + out = tyro.cli(C, args=["-abn", "foo"]) + assert out == C(a=True, b=True, n="foo") + # Glued: -abnfoo -> -a -b -n foo + out = tyro.cli(C, args=["-abnfoo"]) + assert out == C(a=True, b=True, n="foo") + # Equals: -abn=foo + out = tyro.cli(C, args=["-abn=foo"]) + assert out == C(a=True, b=True, n="foo") + + +def test_value_taking_short_in_middle_consumes_rest() -> None: + """If a value-taking short appears mid-cluster, the rest is its value + (POSIX semantics), even if subsequent characters look like other flags.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + n: Annotated[str, arg(aliases=["-n"])] = "default" + + # -anb -> -a -n b (b becomes the value of -n, not a flag). + out = tyro.cli(C, args=["-anb"]) + assert out == C(a=True, b=False, n="b") + + +def test_counter_short_cluster() -> None: + @dataclasses.dataclass + class C: + verbose: Annotated[int, arg(aliases=["-v"]), UseCounterAction] = 0 + + assert tyro.cli(C, args=["-vvv"]).verbose == 3 + assert tyro.cli(C, args=["-v", "-v"]).verbose == 2 + assert tyro.cli(C, args=[]).verbose == 0 + + +def test_counter_mixed_with_bool_cluster() -> None: + @dataclasses.dataclass + class C: + verbose: Annotated[int, arg(aliases=["-v"]), UseCounterAction] = 0 + a: Annotated[bool, arg(aliases=["-a"])] = False + + assert tyro.cli(C, args=["-va"]) == C(verbose=1, a=True) + assert tyro.cli(C, args=["-vva"]) == C(verbose=2, a=True) + assert tyro.cli(C, args=["-avv"]) == C(verbose=2, a=True) + assert tyro.cli(C, args=["-vav"]) == C(verbose=2, a=True) + + +def test_registered_multichar_short_takes_precedence() -> None: + """If ``-cail`` is explicitly registered as an alias, it must win over + cluster expansion of ``-c -a -i -l``.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + i: Annotated[bool, arg(aliases=["-i"])] = False + l: Annotated[bool, arg(aliases=["-l"])] = False + cail: Annotated[bool, arg(aliases=["-cail"])] = False + + out = tyro.cli(C, args=["-cail"]) + # The exact alias wins. + assert out.cail is True + assert out.a is False + assert out.c is False + + +def test_double_dash_long_flag_not_clustered() -> None: + """``--abc`` is a long flag; never expanded as a cluster.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + b: Annotated[bool, arg(aliases=["-b"])] = False + c: Annotated[bool, arg(aliases=["-c"])] = False + + with pytest.raises(SystemExit): + tyro.cli(C, args=["--abc"]) + + +def test_cluster_after_double_dash_marker_treated_as_positional() -> None: + """Tokens after the ``--`` end-of-options marker are not flags.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + rest: tyro.conf.Positional[Tuple[str, ...]] = () + + out = tyro.cli(C, args=["-a", "--", "-bc"]) + assert out.a is True + assert out.rest == ("-bc",) + + +def test_negative_number_not_cluster() -> None: + """Negative numbers must still be parseable as positional/value args.""" + + @dataclasses.dataclass + class C: + n: int = 0 + + assert tyro.cli(C, args=["--n", "-3"]).n == -3 + + +def test_cluster_with_value_taking_first() -> None: + """If the first short in a cluster takes a value, the entire rest is + its value (no further flag interpretation).""" + + @dataclasses.dataclass + class C: + n: Annotated[str, arg(aliases=["-n"])] = "x" + a: Annotated[bool, arg(aliases=["-a"])] = False + + # -nabc -> -n abc, NOT -n -a -b -c. + assert tyro.cli(C, args=["-nabc"]).n == "abc" + + +def test_cluster_int_value() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + n: Annotated[int, arg(aliases=["-n"])] = 0 + + out = tyro.cli(C, args=["-an42"]) + assert out == C(a=True, n=42) + + +def test_cluster_optional_value_taking() -> None: + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + n: Annotated[Optional[str], arg(aliases=["-n"])] = None + + out = tyro.cli(C, args=["-an", "hi"]) + assert out == C(a=True, n="hi") + + +def test_lone_short_not_affected() -> None: + """Sanity check: lone ``-a`` still works.""" + + out = tyro.cli(Flags, args=["-a"]) + assert out == Flags(a=True) + + +def test_cluster_does_not_match_long_flag_chars() -> None: + """Cluster expansion must use only registered single-letter shorts, not + arbitrary characters from long flag names.""" + + @dataclasses.dataclass + class C: + apple: bool = False + banana: bool = False + + with pytest.raises(SystemExit): + tyro.cli(C, args=["-ab"]) + + +def test_cluster_with_unrelated_short() -> None: + """If only some chars are registered shorts, the cluster as a whole + fails (we don't partial-expand).""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + # -b NOT registered. + + with pytest.raises(SystemExit): + tyro.cli(C, args=["-ab"]) + + +def test_help_short_unaffected() -> None: + """``-h`` still triggers help and is not interpreted as a cluster.""" + + @dataclasses.dataclass + class C: + a: Annotated[bool, arg(aliases=["-a"])] = False + + with pytest.raises(SystemExit) as exc_info: + tyro.cli(C, args=["-h"]) + assert exc_info.value.code == 0