Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example_script.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import decimal
import typing as t
from typing import Optional

from targ import CLI

Expand Down Expand Up @@ -49,7 +49,7 @@ def say_hello(name: str, greeting: str = "hello"):
# print_address --number=1 --street="Royal Avenue" --postcode="XYZ 123"
# --city=London
def print_address(
number: int, street: str, postcode: str, city: t.Optional[str] = None
number: int, street: str, postcode: str, city: Optional[str] = None
):
"""
Print out the full address.
Expand Down
55 changes: 25 additions & 30 deletions targ/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@
import json
import sys
import traceback
import typing as t
from collections.abc import Callable
from dataclasses import dataclass, field

try:
from typing import get_args, get_origin # type: ignore
except ImportError:
# For Python 3.7 support
from typing_extensions import get_args, get_origin # type: ignore
from typing import Any, Optional, Union, get_args, get_origin, get_type_hints

from docstring_parser import Docstring, DocstringParam, parse # type: ignore

from .format import Color, format_text, get_underline

__VERSION__ = "0.5.0"
__VERSION__ = "0.6.0"


# If an annotation is one of these values, we will convert the string value
Expand All @@ -29,8 +24,8 @@

@dataclass
class Arguments:
args: t.List[str] = field(default_factory=list)
kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
args: list[str] = field(default_factory=list)
kwargs: dict[str, Any] = field(default_factory=dict)


@dataclass
Expand All @@ -54,14 +49,14 @@ class Command:

"""

command: t.Callable
group_name: t.Optional[str] = None
command_name: t.Optional[str] = None
aliases: t.List[str] = field(default_factory=list)
command: Callable
group_name: Optional[str] = None
command_name: Optional[str] = None
aliases: list[str] = field(default_factory=list)

def __post_init__(self) -> None:
self.command_docstring: Docstring = parse(self.command.__doc__ or "")
self.annotations = t.get_type_hints(self.command)
self.annotations = get_type_hints(self.command)
self.signature = inspect.signature(self.command)
self.solo = False
if not self.command_name:
Expand All @@ -85,7 +80,7 @@ def description(self) -> str:
]
)

def _get_docstring_param(self, arg_name) -> t.Optional[DocstringParam]:
def _get_docstring_param(self, arg_name) -> Optional[DocstringParam]:
for param in self.command_docstring.params:
if param.arg_name == arg_name:
return param
Expand Down Expand Up @@ -200,7 +195,7 @@ def call_with(self, arg_class: Arguments):
self.print_help()
return

annotations = t.get_type_hints(self.command)
annotations = get_type_hints(self.command)

kwargs = arg_class.kwargs.copy()
for index, value in enumerate(arg_class.args):
Expand All @@ -215,8 +210,8 @@ def call_with(self, arg_class: Arguments):

if annotation in CONVERTABLE_TYPES:
value = annotation(value)
elif get_origin(annotation) is t.Union: # type: ignore
# t.Union is used to detect t.Optional
elif get_origin(annotation) is Union: # type: ignore
# Union is used to detect Optional
inner_annotations = get_args(annotation)
filtered = [i for i in inner_annotations if i is not None]
if len(filtered) == 1:
Expand Down Expand Up @@ -250,7 +245,7 @@ class CLI:
"""

description: str = "Targ CLI"
commands: t.List[Command] = field(default_factory=list, init=False)
commands: list[Command] = field(default_factory=list, init=False)

def command_exists(self, group_name: str, command_name: str) -> bool:
"""
Expand All @@ -277,10 +272,10 @@ def _validate_name(self, name: str) -> bool:

def register(
self,
command: t.Callable,
group_name: t.Optional[str] = None,
command_name: t.Optional[str] = None,
aliases: t.List[str] = [],
command: Callable,
group_name: Optional[str] = None,
command_name: Optional[str] = None,
aliases: list[str] = [],
):
"""
Register a function or coroutine as a CLI command.
Expand Down Expand Up @@ -337,15 +332,15 @@ def get_help_text(self) -> str:

return "\n".join(lines)

def _get_cleaned_args(self) -> t.List[str]:
def _get_cleaned_args(self) -> list[str]:
"""
Remove any redundant arguments.
"""
return sys.argv[1:]

def _get_command(
self, command_name: str, group_name: t.Optional[str] = None
) -> t.Optional[Command]:
self, command_name: str, group_name: Optional[str] = None
) -> Optional[Command]:
for command in self.commands:
if (
command.command_name == command_name
Expand All @@ -356,14 +351,14 @@ def _get_command(
return command
return None

def _clean_cli_argument(self, value: str) -> t.Any:
def _clean_cli_argument(self, value: str) -> Any:
if value in ["True", "true", "t"]:
return True
elif value in ["False", "false", "f"]:
return False
return value

def _get_arg_class(self, args: t.List[str]) -> Arguments:
def _get_arg_class(self, args: list[str]) -> Arguments:
arguments = Arguments()
for arg_str in args:
if arg_str.startswith("--"):
Expand Down Expand Up @@ -401,7 +396,7 @@ def run(self, solo: bool = False):

"""
cleaned_args = self._get_cleaned_args()
command: t.Optional[Command] = None
command: Optional[Command] = None

# Work out if to enable tracebacks
try:
Expand Down
24 changes: 12 additions & 12 deletions tests/test_command.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import decimal
import sys
import typing as t
from typing import Any, Optional
from unittest import TestCase
from unittest.mock import MagicMock, patch

Expand All @@ -15,7 +15,7 @@ def add(a: int, b: int):
print(a + b)


def print_(value: t.Any, *args, **kwargs):
def print_(value: Any, *args, **kwargs):
"""
When patching the builtin print statement, this is used as a side effect,
so we can still use debug statements.
Expand All @@ -26,7 +26,7 @@ def print_(value: t.Any, *args, **kwargs):

@dataclasses.dataclass
class Config:
params: t.List[str]
params: list[str]
output: str


Expand Down Expand Up @@ -158,7 +158,7 @@ def test_command(arg1: bool = False):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(params=["test_command"], output="arg1 is False"),
Config(params=["test_command", "f"], output="arg1 is False"),
Config(
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_optional_bool_arg(self, _get_cleaned_args: MagicMock):
Test command arguments which are of type Optional[bool].
"""

def test_command(arg1: t.Optional[bool] = None):
def test_command(arg1: Optional[bool] = None):
"""
A command for testing optional boolean arguments.
"""
Expand All @@ -227,7 +227,7 @@ def test_command(arg1: t.Optional[bool] = None):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(
params=["test_command", "--arg1"],
output="arg1 is True",
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_command(arg1: decimal.Decimal):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(
params=["test_command", "1"],
output="arg1 is int",
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_command(arg1: decimal.Decimal):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(
params=["test_command", "1.11"],
output="arg1 is Decimal",
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_command(arg1: float):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(
params=["test_command", "1.11"],
output="arg1 is float",
Expand Down Expand Up @@ -396,7 +396,7 @@ def test_command(arg1: float, arg2: bool):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(
params=["test_command", "1.11", "true"],
output="arg1 is float, arg2 is bool",
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_command():

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(params=["test_command"], output="Command called"),
Config(params=["tc"], output="Command called"),
]
Expand All @@ -461,7 +461,7 @@ def test_command(name):

with patch("builtins.print", side_effect=print_) as print_mock:

configs: t.List[Config] = [
configs: list[Config] = [
Config(params=["test_command", "hello"], output="hello"),
]

Expand Down