From 23ba4a6e50b852812947f228233c68f69ee86c68 Mon Sep 17 00:00:00 2001 From: sinisaos Date: Thu, 12 Jun 2025 07:06:23 +0200 Subject: [PATCH] type hint update for Python 3.9 and later --- example_script.py | 4 ++-- targ/__init__.py | 55 ++++++++++++++++++++----------------------- tests/test_command.py | 24 +++++++++---------- 3 files changed, 39 insertions(+), 44 deletions(-) diff --git a/example_script.py b/example_script.py index 85392ad..1f4267e 100644 --- a/example_script.py +++ b/example_script.py @@ -1,6 +1,6 @@ import asyncio import decimal -import typing as t +from typing import Optional from targ import CLI @@ -49,7 +49,7 @@ def say_hello(name: str, greeting: str = "hello"): # print_address --number=1 --street="Royal Avenue" --postcode="XYZ 123" # --city=London def print_address( - number: int, street: str, postcode: str, city: t.Optional[str] = None + number: int, street: str, postcode: str, city: Optional[str] = None ): """ Print out the full address. diff --git a/targ/__init__.py b/targ/__init__.py index 8c8f04e..c9ec7fb 100644 --- a/targ/__init__.py +++ b/targ/__init__.py @@ -6,20 +6,15 @@ import json import sys import traceback -import typing as t +from collections.abc import Callable from dataclasses import dataclass, field - -try: - from typing import get_args, get_origin # type: ignore -except ImportError: - # For Python 3.7 support - from typing_extensions import get_args, get_origin # type: ignore +from typing import Any, Optional, Union, get_args, get_origin, get_type_hints from docstring_parser import Docstring, DocstringParam, parse # type: ignore from .format import Color, format_text, get_underline -__VERSION__ = "0.5.0" +__VERSION__ = "0.6.0" # If an annotation is one of these values, we will convert the string value @@ -29,8 +24,8 @@ @dataclass class Arguments: - args: t.List[str] = field(default_factory=list) - kwargs: t.Dict[str, t.Any] = field(default_factory=dict) + args: list[str] = field(default_factory=list) + kwargs: dict[str, Any] = field(default_factory=dict) @dataclass @@ -54,14 +49,14 @@ class Command: """ - command: t.Callable - group_name: t.Optional[str] = None - command_name: t.Optional[str] = None - aliases: t.List[str] = field(default_factory=list) + command: Callable + group_name: Optional[str] = None + command_name: Optional[str] = None + aliases: list[str] = field(default_factory=list) def __post_init__(self) -> None: self.command_docstring: Docstring = parse(self.command.__doc__ or "") - self.annotations = t.get_type_hints(self.command) + self.annotations = get_type_hints(self.command) self.signature = inspect.signature(self.command) self.solo = False if not self.command_name: @@ -85,7 +80,7 @@ def description(self) -> str: ] ) - def _get_docstring_param(self, arg_name) -> t.Optional[DocstringParam]: + def _get_docstring_param(self, arg_name) -> Optional[DocstringParam]: for param in self.command_docstring.params: if param.arg_name == arg_name: return param @@ -200,7 +195,7 @@ def call_with(self, arg_class: Arguments): self.print_help() return - annotations = t.get_type_hints(self.command) + annotations = get_type_hints(self.command) kwargs = arg_class.kwargs.copy() for index, value in enumerate(arg_class.args): @@ -215,8 +210,8 @@ def call_with(self, arg_class: Arguments): if annotation in CONVERTABLE_TYPES: value = annotation(value) - elif get_origin(annotation) is t.Union: # type: ignore - # t.Union is used to detect t.Optional + elif get_origin(annotation) is Union: # type: ignore + # Union is used to detect Optional inner_annotations = get_args(annotation) filtered = [i for i in inner_annotations if i is not None] if len(filtered) == 1: @@ -250,7 +245,7 @@ class CLI: """ description: str = "Targ CLI" - commands: t.List[Command] = field(default_factory=list, init=False) + commands: list[Command] = field(default_factory=list, init=False) def command_exists(self, group_name: str, command_name: str) -> bool: """ @@ -277,10 +272,10 @@ def _validate_name(self, name: str) -> bool: def register( self, - command: t.Callable, - group_name: t.Optional[str] = None, - command_name: t.Optional[str] = None, - aliases: t.List[str] = [], + command: Callable, + group_name: Optional[str] = None, + command_name: Optional[str] = None, + aliases: list[str] = [], ): """ Register a function or coroutine as a CLI command. @@ -337,15 +332,15 @@ def get_help_text(self) -> str: return "\n".join(lines) - def _get_cleaned_args(self) -> t.List[str]: + def _get_cleaned_args(self) -> list[str]: """ Remove any redundant arguments. """ return sys.argv[1:] def _get_command( - self, command_name: str, group_name: t.Optional[str] = None - ) -> t.Optional[Command]: + self, command_name: str, group_name: Optional[str] = None + ) -> Optional[Command]: for command in self.commands: if ( command.command_name == command_name @@ -356,14 +351,14 @@ def _get_command( return command return None - def _clean_cli_argument(self, value: str) -> t.Any: + def _clean_cli_argument(self, value: str) -> Any: if value in ["True", "true", "t"]: return True elif value in ["False", "false", "f"]: return False return value - def _get_arg_class(self, args: t.List[str]) -> Arguments: + def _get_arg_class(self, args: list[str]) -> Arguments: arguments = Arguments() for arg_str in args: if arg_str.startswith("--"): @@ -401,7 +396,7 @@ def run(self, solo: bool = False): """ cleaned_args = self._get_cleaned_args() - command: t.Optional[Command] = None + command: Optional[Command] = None # Work out if to enable tracebacks try: diff --git a/tests/test_command.py b/tests/test_command.py index 1343d09..a55265c 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -1,7 +1,7 @@ import dataclasses import decimal import sys -import typing as t +from typing import Any, Optional from unittest import TestCase from unittest.mock import MagicMock, patch @@ -15,7 +15,7 @@ def add(a: int, b: int): print(a + b) -def print_(value: t.Any, *args, **kwargs): +def print_(value: Any, *args, **kwargs): """ When patching the builtin print statement, this is used as a side effect, so we can still use debug statements. @@ -26,7 +26,7 @@ def print_(value: t.Any, *args, **kwargs): @dataclasses.dataclass class Config: - params: t.List[str] + params: list[str] output: str @@ -158,7 +158,7 @@ def test_command(arg1: bool = False): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config(params=["test_command"], output="arg1 is False"), Config(params=["test_command", "f"], output="arg1 is False"), Config( @@ -209,7 +209,7 @@ def test_optional_bool_arg(self, _get_cleaned_args: MagicMock): Test command arguments which are of type Optional[bool]. """ - def test_command(arg1: t.Optional[bool] = None): + def test_command(arg1: Optional[bool] = None): """ A command for testing optional boolean arguments. """ @@ -227,7 +227,7 @@ def test_command(arg1: t.Optional[bool] = None): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config( params=["test_command", "--arg1"], output="arg1 is True", @@ -285,7 +285,7 @@ def test_command(arg1: decimal.Decimal): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config( params=["test_command", "1"], output="arg1 is int", @@ -322,7 +322,7 @@ def test_command(arg1: decimal.Decimal): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config( params=["test_command", "1.11"], output="arg1 is Decimal", @@ -359,7 +359,7 @@ def test_command(arg1: float): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config( params=["test_command", "1.11"], output="arg1 is float", @@ -396,7 +396,7 @@ def test_command(arg1: float, arg2: bool): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config( params=["test_command", "1.11", "true"], output="arg1 is float, arg2 is bool", @@ -435,7 +435,7 @@ def test_command(): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config(params=["test_command"], output="Command called"), Config(params=["tc"], output="Command called"), ] @@ -461,7 +461,7 @@ def test_command(name): with patch("builtins.print", side_effect=print_) as print_mock: - configs: t.List[Config] = [ + configs: list[Config] = [ Config(params=["test_command", "hello"], output="hello"), ]