From 74e509d01e8b54a76f97f1e11cfc4d08a780b3c7 Mon Sep 17 00:00:00 2001 From: Dan Bennett Date: Tue, 23 Apr 2024 14:26:26 -0700 Subject: [PATCH] Decentralized module loading and argument definitions --- .pre-commit-config.yaml | 2 +- devenv/commands/__init__.py | 0 devenv/{ => commands}/bootstrap.py | 43 ++-- devenv/{ => commands}/doctor.py | 13 +- devenv/{ => commands}/fetch.py | 28 +-- devenv/{ => commands}/pin_gha.py | 8 +- devenv/{ => commands}/sync.py | 8 +- devenv/lib/config.py | 8 +- devenv/lib/context.py | 2 + devenv/lib/modules.py | 219 ++++++++++++++++++++- devenv/main.py | 131 +++++++++--- devenv/usercommands/__init__.py | 0 setup.cfg | 2 +- tests/doctor/test_attempt_fix.py | 2 +- tests/doctor/test_filter_failing_checks.py | 2 +- tests/doctor/test_load_checks.py | 2 +- tests/doctor/test_prompt_for_fix.py | 2 +- tests/doctor/test_run_checks.py | 2 +- 18 files changed, 386 insertions(+), 88 deletions(-) create mode 100644 devenv/commands/__init__.py rename devenv/{ => commands}/bootstrap.py (81%) rename devenv/{ => commands}/doctor.py (97%) rename devenv/{ => commands}/fetch.py (83%) rename devenv/{ => commands}/pin_gha.py (90%) rename devenv/{ => commands}/sync.py (79%) create mode 100644 devenv/usercommands/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6f995951..9bef2ba1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: v3.12.0 hooks: - id: reorder-python-imports - args: [--py311-plus, --add-import, "from __future__ import annotations"] + args: [--py310-plus, --add-import, "from __future__ import annotations"] - repo: https://github.com/psf/black rev: 23.10.0 hooks: diff --git a/devenv/commands/__init__.py b/devenv/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/devenv/bootstrap.py b/devenv/commands/bootstrap.py similarity index 81% rename from devenv/bootstrap.py rename to devenv/commands/bootstrap.py index 6c0188fd..3ba2a4c0 100644 --- a/devenv/bootstrap.py +++ b/devenv/commands/bootstrap.py @@ -1,9 +1,9 @@ from __future__ import annotations -import argparse import os import shutil from collections.abc import Sequence +from typing import cast from devenv.constants import CI from devenv.constants import EXTERNAL_CONTRIBUTOR @@ -14,25 +14,33 @@ from devenv.lib.config import Config from devenv.lib.config import initialize_config from devenv.lib.context import Context -from devenv.lib.modules import DevModuleInfo +from devenv.lib.modules import argument_fn +from devenv.lib.modules import command from devenv.lib.modules import ExitCode - - -def main(context: Context, argv: Sequence[str] | None = None) -> ExitCode: - parser = argparse.ArgumentParser() - parser.add_argument( - "-d", - "--default-config", - action="append", - help="Provide a default config value. e.g., -d coderoot:path/to/root", +from devenv.lib.modules import ModuleDef +from devenv.lib.modules import ParserFn + + +@command("bootstrap", "Bootstraps the development environment.") +@argument_fn( + cast( + ParserFn, + lambda x: x.add_argument( + "-d", + "--default-config", + metavar="config:value", + required=False, + action="append", + help="Provide a default config value. e.g., -d coderoot:path/to/root", + ), ) - - args = parser.parse_args(argv) +) +def main(context: Context, argv: Sequence[str] | None = None) -> ExitCode: + args = context["args"] configs = { k: v for k, v in [i.split(":", 1) for i in args.default_config or []] } - if "coderoot" not in configs and "code_root" in context: configs["coderoot"] = context["code_root"] @@ -103,9 +111,8 @@ def main(context: Context, argv: Sequence[str] | None = None) -> ExitCode: return 0 -module_info = DevModuleInfo( - action=main, - name=__name__, - command="bootstrap", +module_info = ModuleDef( + module_name=__name__, + name="bootstrap", help="Bootstraps the development environment.", ) diff --git a/devenv/doctor.py b/devenv/commands/doctor.py similarity index 97% rename from devenv/doctor.py rename to devenv/commands/doctor.py index 356aba80..fccc4d0d 100644 --- a/devenv/doctor.py +++ b/devenv/commands/doctor.py @@ -14,7 +14,8 @@ from typing import List from devenv.lib.context import Context -from devenv.lib.modules import DevModuleInfo +from devenv.lib.modules import command +from devenv.lib.modules import ModuleDef from devenv.lib.modules import require_repo from devenv.lib.repository import Repository from devenv.lib_check.types import checker @@ -161,6 +162,7 @@ def attempt_fix(check: Check, executor: ThreadPoolExecutor) -> tuple[bool, str]: return False, f"Fix threw a runtime exception: {e}" +@command("doctor", "Diagnose common issues, and optionally try to fix them.") @require_repo def main(context: Context, argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() @@ -232,9 +234,8 @@ def main(context: Context, argv: Sequence[str] | None = None) -> int: return 1 -module_info = DevModuleInfo( - action=main, - name=__name__, - command="doctor", - help="Diagnose common issues, and optionally try to fix them.", +module_info = ModuleDef( + module_name=__name__, + name="doctor", + help="Diagnose common issues, and optionally try to fix them", ) diff --git a/devenv/fetch.py b/devenv/commands/fetch.py similarity index 83% rename from devenv/fetch.py rename to devenv/commands/fetch.py index 371d15bd..beecab27 100644 --- a/devenv/fetch.py +++ b/devenv/commands/fetch.py @@ -1,6 +1,5 @@ from __future__ import annotations -import argparse import os import sys from collections.abc import Sequence @@ -11,18 +10,16 @@ from devenv.constants import homebrew_bin from devenv.lib import proc from devenv.lib.context import Context -from devenv.lib.modules import DevModuleInfo +from devenv.lib.modules import argument +from devenv.lib.modules import command from devenv.lib.modules import ExitCode +from devenv.lib.modules import ModuleDef +@command("fetch", "Fetches a repository") +@argument("repo", help="the repository to fetch") def main(context: Context, argv: Sequence[str] | None = None) -> ExitCode: - parser = argparse.ArgumentParser() - - parser.add_argument( - "repo", type=str, help="the repository to fetch e.g., getsentry/sentry" - ) - - args = parser.parse_args(argv) + args = context["args"] code_root = context["code_root"] if args.repo in ["ops", "getsentry/ops"]: @@ -78,17 +75,20 @@ def main(context: Context, argv: Sequence[str] | None = None) -> ExitCode: ) else: + if "/" not in args.repo: + print("Repository names must be in the form of /") + return 1 fetch(code_root, args.repo) return 0 def fetch( - coderoot: str, repo: str, auth: bool = True, sync: bool = True + code_root: str, repo: str, auth: bool = True, sync: bool = True ) -> None: org, slug = repo.split("/") - codepath = f"{coderoot}/{slug}" + codepath = f"{code_root}/{slug}" if os.path.exists(codepath): print(f"{codepath} already exists") @@ -112,7 +112,7 @@ def fetch( ( "git", "-C", - coderoot, + code_root, "clone", "--filter=blob:none", *additional_args, @@ -124,6 +124,6 @@ def fetch( proc.run((sys.executable, "-P", "-m", "devenv", "sync"), cwd=codepath) -module_info = DevModuleInfo( - action=main, name=__name__, command="fetch", help="Fetches a respository" +module_info = ModuleDef( + module_name=__name__, name="fetch", help="Fetches a repository" ) diff --git a/devenv/pin_gha.py b/devenv/commands/pin_gha.py similarity index 90% rename from devenv/pin_gha.py rename to devenv/commands/pin_gha.py index 8d0966aa..12457a0b 100644 --- a/devenv/pin_gha.py +++ b/devenv/commands/pin_gha.py @@ -7,7 +7,8 @@ from functools import lru_cache from devenv.lib.context import Context -from devenv.lib.modules import DevModuleInfo +from devenv.lib.modules import command +from devenv.lib.modules import ModuleDef @lru_cache(maxsize=None) @@ -38,6 +39,7 @@ def extract_repo(action: str) -> str: return f"{parts[0]}/{parts[1]}" +@command("pin-gha", "Pins github actions.") def main(context: Context, argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument( @@ -70,6 +72,6 @@ def main(context: Context, argv: Sequence[str] | None = None) -> int: return 0 -module_info = DevModuleInfo( - action=main, name=__name__, command="pin_gha", help="Pins github actions." +module_info = ModuleDef( + module_name=__name__, name="pin_gha", help="Pins github actions" ) diff --git a/devenv/sync.py b/devenv/commands/sync.py similarity index 79% rename from devenv/sync.py rename to devenv/commands/sync.py index e03ac639..203a8dff 100644 --- a/devenv/sync.py +++ b/devenv/commands/sync.py @@ -5,10 +5,12 @@ from collections.abc import Sequence from devenv.lib.context import Context -from devenv.lib.modules import DevModuleInfo +from devenv.lib.modules import command +from devenv.lib.modules import ModuleDef from devenv.lib.modules import require_repo +@command("sync", "Resyncs the current project") @require_repo def main(context: Context, argv: Sequence[str] | None = None) -> int: repo = context["repo"] @@ -33,6 +35,6 @@ def main(context: Context, argv: Sequence[str] | None = None) -> int: return module.main(context_compat) # type: ignore -module_info = DevModuleInfo( - action=main, name=__name__, command="sync", help="Resyncs the environment." +module_info = ModuleDef( + module_name=__name__, name="sync", help="Resyncs the current project" ) diff --git a/devenv/lib/config.py b/devenv/lib/config.py index 625ff699..b5370622 100644 --- a/devenv/lib/config.py +++ b/devenv/lib/config.py @@ -60,13 +60,9 @@ def initialize_config(config_path: str, defaults: Config) -> None: else _val, ) - if not CI: + if not CI and opts: try: - if opts: - print(opts.prompt) - else: - print(f"{var}?") - + print(opts.prompt) val = input(f" [{val}]: ") or val except EOFError: # noninterative, use the defaults diff --git a/devenv/lib/context.py b/devenv/lib/context.py index 9377ca46..9f40dfaf 100644 --- a/devenv/lib/context.py +++ b/devenv/lib/context.py @@ -1,5 +1,6 @@ from __future__ import annotations +from argparse import Namespace from typing import TypedDict from devenv.lib.repository import Repository @@ -9,3 +10,4 @@ class Context(TypedDict): config_path: str code_root: str repo: Repository | None + args: Namespace diff --git a/devenv/lib/modules.py b/devenv/lib/modules.py index 1f294dbe..aae402f4 100644 --- a/devenv/lib/modules.py +++ b/devenv/lib/modules.py @@ -1,26 +1,191 @@ from __future__ import annotations +import argparse +import importlib +import inspect +import logging +import sys from collections.abc import Callable from collections.abc import Sequence from dataclasses import dataclass +from pkgutil import walk_packages +from types import ModuleType +from typing import List +from typing import Tuple from typing import TypeAlias +from typing import TypedDict + +import sentry_sdk +from typing_extensions import NotRequired from devenv.lib.context import Context ExitCode: TypeAlias = "str | int | None" - Action: TypeAlias = "Callable[[Context, Sequence[str] | None], ExitCode]" +ParserFn: TypeAlias = "Callable[[argparse.ArgumentParser], None]" @dataclass(frozen=True) -class DevModuleInfo: +class CommandInfo: name: str - command: str - help: str action: Action + help: str + arguments: Sequence[ParserFn] + + +@dataclass(frozen=True) +class ModuleDef: + module_name: str + name: str + help: str + + +@dataclass(frozen=True) +class DevModuleInfo: + module_def: ModuleDef + commands: Sequence[CommandInfo] + + +class ModuleAction: + def __init__(self, action: Action): + self.name: str + self.help: str + + self.action = action + self.argument_parsers: List[ParserFn] = [] + + def __call__( + self, context: Context, args: Sequence[str] | None + ) -> ExitCode: + return self.action(context, args) + + def add_argparser(self, fn: ParserFn) -> None: + self.argument_parsers.append(fn) + + def command_info(self) -> CommandInfo: + return CommandInfo( + self.name, self.action, self.help, arguments=self.argument_parsers + ) + + +def command(name: str, help: str) -> Callable[[Action], Action]: + """ + Marks a function as being a CLI command. + @commmand("commandname", "This command makes cookies") + """ + + def wrap(main: Action | ModuleAction) -> Action: + if isinstance(main, ModuleAction): + module_action = main + else: + module_action = ModuleAction(main) + + module_action.name = name + module_action.help = help + + return module_action + + return wrap + + +class ArgArgs(TypedDict): + required: NotRequired[bool] + metavar: NotRequired[str] + help: NotRequired[str] + choices: NotRequired[Sequence[str]] + action: NotRequired[str] + + +def _convert_argument_params( + *names: str, + help: str | None = None, + required: bool = True, + var: str | None = None, + choices: Sequence[str] | None = None, +) -> Tuple[Tuple[str, ...], ArgArgs]: + if len(names) == 0: + print("Argument must have a name") + raise SystemExit(1) + if names[0].startswith("-"): + argtype = "option" + else: + argtype = "positional" + + kwargs: ArgArgs = {} + + if argtype == "option": + kwargs["required"] = required + + if var: + kwargs["metavar"] = var + if choices: + kwargs["choices"] = choices + + if not (var or choices): + kwargs["action"] = "store_true" + + if help: + kwargs["help"] = help + else: + if choices: + kwargs["help"] = f"must be one of {choices}" + else: + name = var or names[0] + kwargs["metavar"] = name + + if choices: + kwargs["choices"] = choices + if help: + kwargs["help"] = help + else: + if choices: + kwargs["help"] = f"{name} must be one of {choices}" + return names, kwargs + + +def argument( + *names: str, + help: str | None = None, + required: bool = True, + var: str | None = None, + choices: Sequence[str] | None = None, +) -> Callable[[Action], Action]: + """ + Provides arguments to a command function. + + @command(name, help) + @argument("-v", help="Verbose") + """ + ak = _convert_argument_params( + *names, help=help, required=required, var=var, choices=choices + ) + + def add_args(argparse: argparse.ArgumentParser) -> None: + argparse.add_argument(*ak[0], **ak[1]) + + return argument_fn(add_args) + + +def argument_fn(fn: ParserFn) -> Callable[[Action], Action]: + def wrap(main: Action) -> Action: + if isinstance(main, ModuleAction): + module_action = main + else: + module_action = ModuleAction(main) + + module_action.add_argparser(fn) + return module_action + + return wrap def require(var: str, message: str) -> Callable[[Action], Action]: + """ + Indicates that a Context var is required for this command function + + @require("repo", "You need to be in a repository to use this command") + """ + def outer(main: Action) -> Action: def inner(context: Context, args: Sequence[str] | None) -> ExitCode: if context.get(var) is None: @@ -33,3 +198,49 @@ def inner(context: Context, args: Sequence[str] | None) -> ExitCode: require_repo = require("repo", "This command requires a repository") + + +def command_info(module: ModuleType) -> Sequence[CommandInfo]: + return [ + action.command_info() + for name, action in inspect.getmembers( + module, lambda m: isinstance(m, ModuleAction) + ) + ] + + +def module_info(module: ModuleType) -> DevModuleInfo: + info = module.module_info + return DevModuleInfo(module_def=info, commands=command_info(module)) + + +def load_modules(path: str, package: str) -> Sequence[ModuleType]: + all_modules = [] + for module_finder, module_name, _ in walk_packages( + (path,), prefix=f"{package}." + ): + module_spec = module_finder.find_spec(module_name, None) + + # it "should be" impossible to fail these: + assert module_spec is not None, module_name + assert module_spec.loader is not None, module_name + + module = importlib.util.module_from_spec(module_spec) + + if module_name not in sys.modules: + # load if not already loaded + sys.modules[module_name] = module + try: + module.__loader__.exec_module(module) # type: ignore + except Exception as e: + logger = logging.getLogger(__name__) + logger.warning( + f"Failed to load module {module_name}", exc_info=e + ) + sentry_sdk.capture_exception(e) + continue + + if hasattr(module, "module_info"): + all_modules.append(module) + + return all_modules diff --git a/devenv/main.py b/devenv/main.py index 32db0da5..2720c7af 100644 --- a/devenv/main.py +++ b/devenv/main.py @@ -3,21 +3,85 @@ import argparse import os from collections.abc import Sequence +from types import ModuleType +from typing import List + +import sentry_sdk -from devenv import bootstrap -from devenv import doctor -from devenv import fetch -from devenv import pin_gha -from devenv import sync from devenv.constants import home +from devenv.lib import modules from devenv.lib.config import read_config from devenv.lib.context import Context from devenv.lib.fs import gitroot +from devenv.lib.modules import CommandInfo from devenv.lib.modules import DevModuleInfo from devenv.lib.modules import ExitCode +from devenv.lib.modules import module_info from devenv.lib.repository import Repository +def generate_parser( + modinfo_list: Sequence[DevModuleInfo], +) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + subparser = parser.add_subparsers( + title=argparse.SUPPRESS, + metavar="command", + dest="command", + required=True, + ) + + for info in modinfo_list: + # don't show modules with no actions defined + if not info.commands: + continue + + module_def = info.module_def + child = subparser.add_parser(module_def.name, help=module_def.help) + + if len(info.commands) == 1 and ( + info.commands[0].name == module_def.name + ): + # single command, matching name case; i.e., module == command + command = info.commands[0] + for fn in command.arguments: + fn(child) + continue + + subsubparser = child.add_subparsers( + title="subcommands", + metavar="subcommand", + dest="subcommand", + required=True, + ) + + for command in info.commands: + if command.name == module_def.name: + # module has a default command, subcommand not required + subsubparser.required = False + for fn in command.arguments: + fn(child) + else: + grandchild = subsubparser.add_parser( + command.name, help=command.help + ) + for fn in command.arguments: + fn(grandchild) + return parser + + +def load_modules(path: str, package: ModuleType) -> List[DevModuleInfo]: + if not os.path.exists(path): + return [] + if path not in package.__path__: + package.__path__.append(path) + return [ + module_info(module) + for module in modules.load_modules(path, package.__name__) + if hasattr(module, "module_info") + ] + + def devenv(argv: Sequence[str], config_path: str) -> ExitCode: # determine current repo, if applicable fake_reporoot = os.getenv("CI_DEVENV_INTEGRATION_FAKE_REPOROOT") @@ -39,26 +103,25 @@ def devenv(argv: Sequence[str], config_path: str) -> ExitCode: else os.path.expanduser("~/code") ) - modinfo_list: Sequence[DevModuleInfo] = [ - module.module_info - for module in [bootstrap, fetch, doctor, pin_gha, sync] - if hasattr(module, "module_info") - ] + import devenv.commands - # TODO: Search for modules in work repo - - parser = argparse.ArgumentParser() - subparser = parser.add_subparsers( - title=argparse.SUPPRESS, - metavar="command", - dest="command", - required=True, + # load local commands from $code_root/.devenv/commands -- allows a user to override defaults + modinfo_list = load_modules( + f"{code_root}/.devenv/commands", devenv.commands + ) + # load default commands from installed devenv + modinfo_list.extend( + load_modules(devenv.commands.__path__[0], devenv.commands) ) - for info in modinfo_list: - # Argparse stuff - subparser.add_parser(info.command, help=info.help) + # load repo-specific commands + import devenv.usercommands + + if current_root: + user_path = f"{Repository(current_root).config_path}/usercommands" + modinfo_list.extend(load_modules(user_path, devenv.usercommands)) + parser = generate_parser(modinfo_list) args, remainder = parser.parse_known_args(argv[1:]) # context for subcommands @@ -66,17 +129,28 @@ def devenv(argv: Sequence[str], config_path: str) -> ExitCode: "config_path": config_path, "code_root": code_root, "repo": Repository(current_root) if current_root else None, + "args": args, } - command_actions = {info.command: info.action for info in modinfo_list} - action = command_actions.get(args.command) - assert action is not None - return action(context, remainder) + modinfo = next( + module + for module in modinfo_list + if module.module_def.name == args.command + ) + + commands: dict[str, CommandInfo] = { + command.name: command for command in modinfo.commands + } + command_name = getattr(args, "subcommand", None) or args.command + command = commands.get(command_name) + + assert command is not None + + return command.action(context, remainder) def main() -> ExitCode: import sys - import sentry_sdk sentry_sdk.init( # https://sentry.sentry.io/settings/projects/sentry-dev-env/keys/ @@ -85,7 +159,10 @@ def main() -> ExitCode: enable_tracing=True, ) - return devenv(sys.argv, f"{home}/.config/sentry-devenv/config.ini") + try: + return devenv(sys.argv, f"{home}/.config/sentry-devenv/config.ini") + except KeyboardInterrupt: + return -1 if __name__ == "__main__": diff --git a/devenv/usercommands/__init__.py b/devenv/usercommands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/setup.cfg b/setup.cfg index 487746c7..a2fada9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,4 +3,4 @@ max-line-length = 100 extend-ignore = ## black takes care of that: # line too long (X > Y characters) - E501 + E501,E203,E701 diff --git a/tests/doctor/test_attempt_fix.py b/tests/doctor/test_attempt_fix.py index 8cf749cd..b021c89b 100644 --- a/tests/doctor/test_attempt_fix.py +++ b/tests/doctor/test_attempt_fix.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor -from devenv import doctor +from devenv.commands import doctor from tests.doctor.devenv.checks import broken_fix from tests.doctor.devenv.checks import failing_check from tests.doctor.devenv.checks import failing_check_with_msg diff --git a/tests/doctor/test_filter_failing_checks.py b/tests/doctor/test_filter_failing_checks.py index 92a68d38..c38bf9c6 100644 --- a/tests/doctor/test_filter_failing_checks.py +++ b/tests/doctor/test_filter_failing_checks.py @@ -1,6 +1,6 @@ from __future__ import annotations -from devenv import doctor +from devenv.commands import doctor from tests.doctor.devenv.checks import failing_check from tests.doctor.devenv.checks import passing_check diff --git a/tests/doctor/test_load_checks.py b/tests/doctor/test_load_checks.py index 60b5017f..b8b7e194 100644 --- a/tests/doctor/test_load_checks.py +++ b/tests/doctor/test_load_checks.py @@ -4,7 +4,7 @@ import pytest -from devenv import doctor +from devenv.commands import doctor from devenv.lib.repository import Repository diff --git a/tests/doctor/test_prompt_for_fix.py b/tests/doctor/test_prompt_for_fix.py index 73e84970..c794b513 100644 --- a/tests/doctor/test_prompt_for_fix.py +++ b/tests/doctor/test_prompt_for_fix.py @@ -5,7 +5,7 @@ import pytest -from devenv import doctor +from devenv.commands import doctor from tests.doctor.devenv.checks import passing_check diff --git a/tests/doctor/test_run_checks.py b/tests/doctor/test_run_checks.py index 8e8f4c97..05fb4ca7 100644 --- a/tests/doctor/test_run_checks.py +++ b/tests/doctor/test_run_checks.py @@ -4,7 +4,7 @@ import pytest -from devenv import doctor +from devenv.commands import doctor from tests.doctor.devenv.checks import broken_check from tests.doctor.devenv.checks import failing_check from tests.doctor.devenv.checks import failing_check_with_msg