Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/workflows/ty.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: ty

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
ty:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Run ty
# We type-check the test suite (not src/) because tyro's public
# surface is exercised most realistically through tests, and ty
# catches regressions like #460 there.
run: uv run --extra dev-nn --python 3.13 ty check tests/
21 changes: 21 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dev = [
"omegaconf>=2.2.2",
"attrs>=21.4.0",
"pyright>=1.1.349,!=1.1.379",
"ty>=0.0.33",
"mypy>=1.4.1",
"pydantic>=2.5.2,!=2.10.0",
"coverage[toml]>=6.5.0",
Expand Down Expand Up @@ -143,3 +144,23 @@ target-version = "py312"
[tool.pyright]
pythonVersion = "3.13"
ignore = ["**/_argparse.py", "./docs/**/*"]

[tool.ty.environment]
python-version = "3.13"

[[tool.ty.overrides]]
# Generated tests are reformatted from the originals; ruff sometimes moves
# `# ty: ignore` comments off the line where the diagnostic actually fires.
# The originals are checked strictly, so suppressing on the generated copies
# is safe.
include = ["tests/test_py311_generated/**"]
rules = { invalid-type-form = "ignore", invalid-assignment = "ignore", unused-ignore-comment = "ignore" }

[tool.ty.rules]
# `# type: ignore` comments are written for mypy/pyright; ty's narrower set
# of diagnostics often makes them redundant from its POV.
unused-type-ignore-comment = "ignore"
# `get_parser` is intentionally exercised by helptext tests.
deprecated = "ignore"
redundant-cast = "ignore"
possibly-missing-submodule = "ignore"
65 changes: 52 additions & 13 deletions src/tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import sys
import warnings
from contextlib import nullcontext
from typing import Callable, Literal, Sequence, TypeVar, cast, overload
from typing import Callable, Literal, Sequence, Type, TypeVar, cast, overload

from typing_extensions import Annotated, assert_never, deprecated
from typing_extensions import Annotated, TypeForm, assert_never, deprecated

from . import (
_arguments,
Expand All @@ -28,18 +28,57 @@
NonpropagatingMissingType,
PropagatingMissingType,
)
from ._typing import TypeForm
from .constructors import ConstructorRegistry
from .constructors._primitive_spec import UnsupportedTypeAnnotationError

OutT = TypeVar("OutT")


# The overload here is necessary for pyright and pylance due to special-casing
# related to using typing.Type[] as a temporary replacement for
# typing.TypeForm[].
#
# https://github.com/microsoft/pyright/issues/4298
# Two parallel sets of `f` overloads. `Type[OutT]` exists for pyright/pylance
# (see microsoft/pyright#4298 and the comment in `_resolver.py`); `TypeForm`
# exists for ty and any checker implementing PEP 747, and additionally covers
# patterns the `Type[T]` hack misses (e.g. `Annotated[A] | Annotated[B]`).
# Each checker picks the first overload it can match.


@overload
def cli(
f: Type[OutT],
*,
prog: None | str = None,
description: None | str = None,
args: None | Sequence[str] = None,
default: OutT
| NonpropagatingMissingType
| PropagatingMissingType = MISSING_NONPROP,
return_unknown_args: Literal[False] = False,
use_underscores: bool = False,
console_outputs: bool = True,
add_help: bool = True,
compact_help: bool = False,
config: None | Sequence[conf._markers.Marker] = None,
registry: None | ConstructorRegistry = None,
) -> OutT: ...


@overload
def cli(
f: Type[OutT],
*,
prog: None | str = None,
description: None | str = None,
args: None | Sequence[str] = None,
default: OutT
| NonpropagatingMissingType
| PropagatingMissingType = MISSING_NONPROP,
return_unknown_args: Literal[True],
use_underscores: bool = False,
console_outputs: bool = True,
add_help: bool = True,
compact_help: bool = False,
config: None | Sequence[conf._markers.Marker] = None,
registry: None | ConstructorRegistry = None,
) -> tuple[OutT, list[str]]: ...


@overload
Expand Down Expand Up @@ -124,8 +163,8 @@ def cli(
) -> tuple[OutT, list[str]]: ...


def cli(
f: TypeForm[OutT] | Callable[..., OutT],
def cli( # pyright: ignore[reportInconsistentOverload]
f: Type[OutT] | Callable[..., OutT],
*,
prog: None | str = None,
description: None | str = None,
Expand Down Expand Up @@ -288,7 +327,7 @@ class Config:
@overload
@deprecated("get_parser() is deprecated and will be removed in a future version.")
def get_parser(
f: TypeForm[OutT],
f: Type[OutT],
*,
prog: None | str = None,
description: None | str = None,
Expand Down Expand Up @@ -323,7 +362,7 @@ def get_parser(

@deprecated("get_parser() is deprecated and will be removed in a future version.")
def get_parser(
f: TypeForm[OutT] | Callable[..., OutT],
f: Type[OutT] | Callable[..., OutT],
*,
# We have no `args` argument, since this is only used when
# parser.parse_args() is called.
Expand Down Expand Up @@ -388,7 +427,7 @@ def get_parser(


def _cli_impl(
f: TypeForm[OutT] | Callable[..., OutT],
f: Type[OutT] | Callable[..., OutT],
*,
prog: None | str = None,
description: None | str,
Expand Down
17 changes: 8 additions & 9 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import functools
import inspect
import sys
from typing import Any, Callable, Dict, Literal, Tuple
from typing import Any, Callable, Dict, Literal, Tuple, Type

import docstring_parser
from typing_extensions import (
Expand All @@ -26,7 +26,6 @@
from . import _docstrings, _resolver, _strings, _unsafe_cache
from . import _fmtlib as fmt
from ._singleton import MISSING_NONPROP, is_missing
from ._typing import TypeForm
from ._typing_compat import is_typing_annotated, is_typing_unpack
from .conf import _confstruct, _markers
from .constructors._registry import ConstructorRegistry, check_default_instances
Expand All @@ -44,9 +43,9 @@
class FieldDefinition:
intern_name: str
extern_name: str
type: TypeForm[Any] | Callable
type: Type[Any] | Callable
"""Full type, including runtime annotations."""
type_stripped: TypeForm[Any] | Callable
type_stripped: Type[Any] | Callable
default: Any
helptext: str | Callable[[], str | None] | None
markers: set[Any]
Expand Down Expand Up @@ -90,7 +89,7 @@ def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition:
@staticmethod
def make(
name: str,
typ: TypeForm[Any] | Callable,
typ: Type[Any] | Callable,
default: Any,
helptext: str | Callable[[], str | None] | None,
call_argname_override: Any | None = None,
Expand Down Expand Up @@ -187,7 +186,7 @@ def make(
return out

def with_new_type_stripped(
self, new_type_stripped: TypeForm[Any] | Callable
self, new_type_stripped: Type[Any] | Callable
) -> FieldDefinition:
if is_typing_annotated(get_origin(self.type)):
new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] # type: ignore
Expand All @@ -202,7 +201,7 @@ def with_new_type_stripped(

@_unsafe_cache.unsafe_cache(maxsize=1024)
def is_struct_type(
typ: TypeForm[Any] | Callable, default_instance: Any, in_union_context: bool
typ: Type[Any] | Callable, default_instance: Any, in_union_context: bool
) -> bool:
"""Determine whether a type should be treated as a 'struct type', where a single
type can be broken down into multiple fields (eg for nested dataclasses or
Expand Down Expand Up @@ -230,14 +229,14 @@ def is_struct_type(


def field_list_from_type_or_callable(
f: Callable | TypeForm[Any],
f: Callable | Type[Any],
default_instance: Any,
support_single_arg_types: bool,
in_union_context: bool,
) -> (
UnsupportedStructTypeMessage
| InvalidDefaultInstanceError
| tuple[Callable | TypeForm[Any], list[FieldDefinition]]
| tuple[Callable | Type[Any], list[FieldDefinition]]
):
"""Generate a list of generic 'field' objects corresponding to the inputs of some
annotated callable.
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/_fmtlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def render(self, width: int) -> list[str]:
return ["".join(parts) for parts in out_parts]


_FORCE_UTF8_BOXES = False
_FORCE_UTF8_BOXES: bool = False


@final
Expand Down
3 changes: 1 addition & 2 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
_subcommand_matching,
)
from . import _fmtlib as fmt
from ._typing import TypeForm
from ._typing_compat import is_typing_union
from .conf import _confstruct, _markers
from .constructors._primitive_spec import (
Expand Down Expand Up @@ -443,7 +442,7 @@ class SubparsersSpecification:
extern_prefix: str
required: bool
default_instance: Any
options: Tuple[Union[TypeForm[Any], Callable], ...]
options: Tuple[Union[Type[Any], Callable], ...]
prog_suffix: str

@staticmethod
Expand Down
32 changes: 18 additions & 14 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@

from . import _unsafe_cache, conf
from ._singleton import is_missing, is_sentinel
from ._typing import TypeForm
from ._typing_compat import (
is_typing_annotated,
is_typing_classvar,
Expand All @@ -57,7 +56,12 @@
Python. types.UnionType was added in Python 3.10, and is created when the `X |
Y` syntax is used for unions."""

TypeOrCallable = TypeVar("TypeOrCallable", TypeForm[Any], Callable)
# `Type[T]` is used loosely throughout tyro as a stand-in for PEP 747's
# `TypeForm[T]`: it's accepted by pyright for arbitrary type forms (Annotated,
# unions, etc.) via microsoft/pyright#4298, and pragmatically conveys "any
# type expression" for runtime introspection. Switch to `TypeForm` once it
# lands in `typing` and is supported across checkers.
TypeOrCallable = TypeVar("TypeOrCallable", Type[Any], Callable)


@dataclasses.dataclass(frozen=True)
Expand All @@ -81,13 +85,13 @@ def unwrap_origin_strip_extras(typ: TypeOrCallable) -> TypeOrCallable:
return typ


def is_dataclass(cls: Union[TypeForm, Callable]) -> bool:
def is_dataclass(cls: Union[Type, Callable]) -> bool:
"""Same as `dataclasses.is_dataclass`, but also handles generic aliases."""
return dataclasses.is_dataclass(unwrap_origin_strip_extras(cls)) # type: ignore


# @_unsafe_cache.unsafe_cache(maxsize=1024)
def resolved_fields(cls: TypeForm) -> List[dataclasses.Field]:
def resolved_fields(cls: Type) -> List[dataclasses.Field]:
"""Similar to dataclasses.fields(), but includes dataclasses.InitVar types and
resolves forward references."""

Expand Down Expand Up @@ -117,7 +121,7 @@ def resolved_fields(cls: TypeForm) -> List[dataclasses.Field]:
return fields


def is_namedtuple(cls: TypeForm) -> bool:
def is_namedtuple(cls: Type) -> bool:
return (
isinstance(cls, type)
and issubclass(cls, tuple)
Expand All @@ -126,7 +130,7 @@ def is_namedtuple(cls: TypeForm) -> bool:
)


TypeOrCallableOrNone = TypeVar("TypeOrCallableOrNone", Callable, TypeForm[Any], None)
TypeOrCallableOrNone = TypeVar("TypeOrCallableOrNone", Callable, Type[Any], None)


def resolve_newtype_and_aliases(
Expand Down Expand Up @@ -231,14 +235,14 @@ def swap_type_using_confstruct(typ: TypeOrCallable) -> TypeOrCallable:
def narrow_collection_types(
typ: TypeOrCallable, default_instance: Any
) -> TypeOrCallable:
"""TypeForm narrowing for containers. Infers types of container contents."""
"""Type narrowing for containers. Infers types of container contents."""

# Can't narrow if we don't have a default value!
if is_missing(default_instance):
return typ

# We'll recursively narrow contained types too!
def _get_type(val: Any) -> TypeForm:
def _get_type(val: Any) -> Type:
return narrow_collection_types(type(val), val)

args = get_args(typ)
Expand Down Expand Up @@ -291,7 +295,7 @@ def _get_type(val: Any) -> TypeForm:
@overload
def unwrap_annotated(
typ: TypeOrCallable,
search_type: TypeForm[MetadataType],
search_type: Type[MetadataType],
) -> Tuple[TypeOrCallable, Tuple[MetadataType, ...]]: ...


Expand All @@ -311,7 +315,7 @@ def unwrap_annotated(

def unwrap_annotated(
typ: TypeOrCallable,
search_type: Union[TypeForm[MetadataType], Literal["all"], object, None] = None,
search_type: Union[Type[MetadataType], Literal["all"], object, None] = None,
) -> Union[Tuple[TypeOrCallable, Tuple[MetadataType, ...]], TypeOrCallable]:
"""Helper for parsing typing.Annotated types.

Expand Down Expand Up @@ -384,7 +388,7 @@ def unwrap_annotated(


class TypeParamResolver:
param_assignments: List[Dict[TypeVar, TypeForm[Any]]] = []
param_assignments: List[Dict[TypeVar, Type[Any]]] = []

@classmethod
def get_assignment_context(cls, typ: TypeOrCallable) -> TypeParamAssignmentContext:
Expand Down Expand Up @@ -559,7 +563,7 @@ class TypeParamAssignmentContext:
def __init__(
self,
origin_type: TypeOrCallable,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
type_from_typevar: Dict[TypeVar, Type[Any]],
):
# `Any` is needed for mypy...
self.origin_type: Any = origin_type
Expand Down Expand Up @@ -649,7 +653,7 @@ def isinstance_with_fuzzy_numeric_tower(

def resolve_generic_types(
typ: TypeOrCallable,
) -> Tuple[TypeOrCallable, Dict[TypeVar, TypeForm[Any]]]:
) -> Tuple[TypeOrCallable, Dict[TypeVar, Type[Any]]]:
"""If the input is a class: no-op. If it's a generic alias: returns the origin
class, and a mapping from typevars to concrete types."""

Expand All @@ -666,7 +670,7 @@ def resolve_generic_types(

# We'll ignore NewType when getting the origin + args for generics.
origin_cls = get_origin(typ)
type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {}
type_from_typevar: Dict[TypeVar, Type[Any]] = {}

# Support typing.Self.
# We'll do this by pretending that `Self` is a TypeVar...
Expand Down
Loading
Loading