diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bf01667..f084159 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,11 +46,11 @@ jobs: run: uv run --dev ruff check . continue-on-error: true - - name: Type check - run: | - uv run --dev pyrefly check . - uv run --dev mypy . - continue-on-error: true + - name: Pyrefly type check + run: uv run --dev pyrefly check . + + - name: Mypy type check + run: uv run --dev mypy . - name: Run tests run: | @@ -61,7 +61,7 @@ jobs: uv run --dev pytest -vv --cov=alternative --cov-report=xml --junitxml=test-results.xml - name: Upload coverage to Codecov - if: always() && env.CODECOV_TOKEN != '' + if: (success() || hashFiles('coverage.xml') != '') && env.CODECOV_TOKEN != '' uses: codecov/codecov-action@v5 with: token: ${{ env.CODECOV_TOKEN }} @@ -69,7 +69,7 @@ jobs: fail_ci_if_error: true - name: Upload test results to Codecov - if: always() && env.CODECOV_TOKEN != '' + if: (success() || hashFiles('test-results.xml') != '') && env.CODECOV_TOKEN != '' uses: codecov/codecov-action@v5 with: token: ${{ env.CODECOV_TOKEN }} diff --git a/alternative.py b/alternative.py index 1029269..0bfbd71 100644 --- a/alternative.py +++ b/alternative.py @@ -4,7 +4,7 @@ import inspect import os from functools import wraps, lru_cache -from typing import Callable +from typing import Callable, Protocol from typing import cast, overload @@ -31,11 +31,19 @@ class _UNDEFINED: ... +class _SupportsLessThan(Protocol): + def __lt__(self, other: object, /) -> bool: ... + + _UNDEFINED_VALUE = _UNDEFINED() type ImplementationSig[**P, R] = Callable[P, R] | Implementation[P, R] -type AlternativesWrapper[**P, R] = Callable[[ImplementationSig], Alternatives[P, R]] -type ImplementationWrapper[**P, R] = Callable[[ImplementationSig], Implementation[P, R]] +type AlternativesWrapper[**P, R] = Callable[ + [ImplementationSig[P, R]], Alternatives[P, R] +] +type ImplementationWrapper[**P, R] = Callable[ + [ImplementationSig[P, R]], Implementation[P, R] +] class AlternativeError(Exception): @@ -104,7 +112,6 @@ def __init__(self, implementation: Callable[P, R], *, default: bool = False): self._debug_invoked_site: str | None = None # tracks the use of the set should be self._enumerated = False - self._debug_invoked_site: str | None = None self._callable: Callable[P, R] | None = None self._debug_callable_used: str | None = None @@ -187,7 +194,7 @@ def callable(self) -> Callable[P, R]: else: self._callable = self.reference self._debug_callable_used = _maybe_get_caller_path() - self.__call__ = self._callable + setattr(self, "__call__", self._callable) # access the list of implementations to freeze them assert self.implementations return self._callable @@ -219,7 +226,15 @@ def measure[M]( } try: # try to sort the dictionary by the measurements - return dict(sorted(result.items(), key=lambda x: x[1])) + return dict( + sorted( + result.items(), + key=cast( + Callable[[tuple[Implementation[P, R], M]], _SupportsLessThan], + lambda x: cast(_SupportsLessThan, x[1]), + ), + ) + ) except TypeError: return result @@ -252,10 +267,10 @@ def pytest_parametrize( if isinstance(test, _UNDEFINED): - def inner(f: Callable): + def decorator(f: Callable): return self.pytest_parametrize(f, only_default=only_default) - return inner + return decorator implementations = self._select_parametrize_implementations( only_default=only_default @@ -309,7 +324,7 @@ def pytest_parametrize_pairs( if isinstance(test, _UNDEFINED): - def inner(f: Callable): + def decorator(f: Callable): return self.pytest_parametrize_pairs( f, n_cache=n_cache, @@ -317,7 +332,7 @@ def inner(f: Callable): only_default=only_default, ) - return inner + return decorator reference_implementation = lru_cache(maxsize=n_cache)( self.reference.implementation @@ -390,7 +405,7 @@ def __repr__(self) -> str: return f"Implementation({implementation_name})" def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - self.__call__ = self.implementation + setattr(self, "__call__", self.implementation) return self.__call__(*args, **kwargs) @overload @@ -412,18 +427,20 @@ def add( @overload def reference[**P, R]( implementation: _UNDEFINED = _UNDEFINED_VALUE, *, default: bool = False -) -> AlternativesWrapper[P, R]: ... +) -> Callable[[Callable[P, R]], Alternatives[P, R]]: ... @overload -def reference[**P, R]( - implementation: ImplementationSig[P, R], *, default: bool = False +def reference[**P, R]( # pyrefly: ignore[inconsistent-overload] + implementation: Callable[P, R], *, default: bool = False ) -> Alternatives[P, R]: ... def reference[**P, R]( - implementation=_UNDEFINED_VALUE, *, default=False -) -> Alternatives[P, R] | AlternativesWrapper[P, R]: + implementation=_UNDEFINED_VALUE, + *, + default=False, +) -> Alternatives[P, R] | Callable[[Callable[P, R]], Alternatives[P, R]]: if isinstance(implementation, _UNDEFINED): def inner(f: Callable[P, R]) -> Alternatives[P, R]: diff --git a/pyproject.toml b/pyproject.toml index c5dae07..fa64e2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,11 @@ build-backend = "hatchling.build" [tool.pytest.ini_options] # do 5 rounds of 0.01 benchmarks, as the benchmarks are examples or very fast addopts = "--cov=alternative --cov-report=html --benchmark-max-time=0.01" + + +[tool.mypy] +python_version = "3.12" + +[[tool.mypy.overrides]] +module = ["pytest"] +ignore_missing_imports = true