Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ jobs:
run: uv run --dev ruff check .
continue-on-error: true

- name: Type check
run: |
uv run --dev pyrefly check .
uv run --dev mypy .
continue-on-error: true
- name: Pyrefly type check
run: uv run --dev pyrefly check .

- name: Mypy type check
run: uv run --dev mypy .

- name: Run tests
run: |
Expand All @@ -61,15 +61,15 @@ jobs:
uv run --dev pytest -vv --cov=alternative --cov-report=xml --junitxml=test-results.xml

- name: Upload coverage to Codecov
if: always() && env.CODECOV_TOKEN != ''
if: (success() || hashFiles('coverage.xml') != '') && env.CODECOV_TOKEN != ''
uses: codecov/codecov-action@v5
with:
token: ${{ env.CODECOV_TOKEN }}
files: ./coverage.xml
fail_ci_if_error: true

- name: Upload test results to Codecov
if: always() && env.CODECOV_TOKEN != ''
if: (success() || hashFiles('test-results.xml') != '') && env.CODECOV_TOKEN != ''
uses: codecov/codecov-action@v5
with:
token: ${{ env.CODECOV_TOKEN }}
Expand Down
49 changes: 33 additions & 16 deletions alternative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
from functools import wraps, lru_cache
from typing import Callable
from typing import Callable, Protocol
from typing import cast, overload


Expand All @@ -31,11 +31,19 @@
class _UNDEFINED: ...


class _SupportsLessThan(Protocol):
def __lt__(self, other: object, /) -> bool: ...


_UNDEFINED_VALUE = _UNDEFINED()

type ImplementationSig[**P, R] = Callable[P, R] | Implementation[P, R]
type AlternativesWrapper[**P, R] = Callable[[ImplementationSig], Alternatives[P, R]]
type ImplementationWrapper[**P, R] = Callable[[ImplementationSig], Implementation[P, R]]
type AlternativesWrapper[**P, R] = Callable[
[ImplementationSig[P, R]], Alternatives[P, R]
]
type ImplementationWrapper[**P, R] = Callable[
[ImplementationSig[P, R]], Implementation[P, R]
]


class AlternativeError(Exception):
Expand Down Expand Up @@ -104,7 +112,6 @@ def __init__(self, implementation: Callable[P, R], *, default: bool = False):
self._debug_invoked_site: str | None = None
# tracks the use of the set should be
self._enumerated = False
self._debug_invoked_site: str | None = None

self._callable: Callable[P, R] | None = None
self._debug_callable_used: str | None = None
Expand Down Expand Up @@ -187,7 +194,7 @@ def callable(self) -> Callable[P, R]:
else:
self._callable = self.reference
self._debug_callable_used = _maybe_get_caller_path()
self.__call__ = self._callable
setattr(self, "__call__", self._callable)
# access the list of implementations to freeze them
assert self.implementations
return self._callable
Expand Down Expand Up @@ -219,7 +226,15 @@ def measure[M](
}
try:
# try to sort the dictionary by the measurements
return dict(sorted(result.items(), key=lambda x: x[1]))
return dict(
sorted(
result.items(),
key=cast(
Callable[[tuple[Implementation[P, R], M]], _SupportsLessThan],
lambda x: cast(_SupportsLessThan, x[1]),
),
)
)
except TypeError:
return result

Expand Down Expand Up @@ -252,10 +267,10 @@ def pytest_parametrize(

if isinstance(test, _UNDEFINED):

def inner(f: Callable):
def decorator(f: Callable):
return self.pytest_parametrize(f, only_default=only_default)

return inner
return decorator

implementations = self._select_parametrize_implementations(
only_default=only_default
Expand Down Expand Up @@ -309,15 +324,15 @@ def pytest_parametrize_pairs(

if isinstance(test, _UNDEFINED):

def inner(f: Callable):
def decorator(f: Callable):
return self.pytest_parametrize_pairs(
f,
n_cache=n_cache,
double_reference=double_reference,
only_default=only_default,
)

return inner
return decorator

reference_implementation = lru_cache(maxsize=n_cache)(
self.reference.implementation
Expand Down Expand Up @@ -390,7 +405,7 @@ def __repr__(self) -> str:
return f"Implementation({implementation_name})"

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
self.__call__ = self.implementation
setattr(self, "__call__", self.implementation)
return self.__call__(*args, **kwargs)

@overload
Expand All @@ -412,18 +427,20 @@ def add(
@overload
def reference[**P, R](
implementation: _UNDEFINED = _UNDEFINED_VALUE, *, default: bool = False
) -> AlternativesWrapper[P, R]: ...
) -> Callable[[Callable[P, R]], Alternatives[P, R]]: ...


@overload
def reference[**P, R](
implementation: ImplementationSig[P, R], *, default: bool = False
def reference[**P, R]( # pyrefly: ignore[inconsistent-overload]
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 this is just a temporary hack to get the pyrefly into the CI pipeline. mypy already passes so that is probably good enough. I'll get this looked into outside of Codex cloud, as that doesn't have pyrefly in the current environment

implementation: Callable[P, R], *, default: bool = False
) -> Alternatives[P, R]: ...


def reference[**P, R](
implementation=_UNDEFINED_VALUE, *, default=False
) -> Alternatives[P, R] | AlternativesWrapper[P, R]:
implementation=_UNDEFINED_VALUE,
*,
default=False,
) -> Alternatives[P, R] | Callable[[Callable[P, R]], Alternatives[P, R]]:
if isinstance(implementation, _UNDEFINED):

def inner(f: Callable[P, R]) -> Alternatives[P, R]:
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ build-backend = "hatchling.build"
[tool.pytest.ini_options]
# do 5 rounds of 0.01 benchmarks, as the benchmarks are examples or very fast
addopts = "--cov=alternative --cov-report=html --benchmark-max-time=0.01"


[tool.mypy]
python_version = "3.12"

[[tool.mypy.overrides]]
module = ["pytest"]
ignore_missing_imports = true
Loading