diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a0664a..d88b18d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v2 with: @@ -21,7 +21,6 @@ jobs: - name: Install dependencies run: | sudo apt-get install gfortran - # JAX isn't yet NumPy 2 compatible. pip install --upgrade pip setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive numpy Cython python setup.py --version LAB_BUILD=1 pip install --no-cache-dir -U -r requirements.txt | cat diff --git a/lab/types.py b/lab/types.py index 5c5232b..7c75524 100644 --- a/lab/types.py +++ b/lab/types.py @@ -103,6 +103,7 @@ def _jax_version(): "jaxlib._jax", "ArrayImpl", condition=lambda: _jax_version() >= (0, 6, 0), + faithful=True, ), ] _jax_tracer = ModuleType("jax.core", "Tracer") diff --git a/lab/util.py b/lab/util.py index 7622435..a605e38 100644 --- a/lab/util.py +++ b/lab/util.py @@ -2,8 +2,7 @@ import numpy as np import plum -import plum.signature -import plum.type +from plum._method import MethodList from . import B @@ -210,8 +209,8 @@ def wrapper(*args, **kw_args): # means that an implementation is not available. types_after = tuple(type(arg) for arg in args) if types_before == types_after: - signature = plum.signature.Signature(*types_after) - raise plum.NotFoundLookupError(f.__name__, signature, []) + signature = plum.Signature(*types_after) + raise plum.NotFoundLookupError(f.__name__, signature, MethodList()) # Retry call. return getattr(B, f.__name__)(*args, **kw_args) diff --git a/setup.py b/setup.py index d45f4fd..8d52a6d 100755 --- a/setup.py +++ b/setup.py @@ -112,13 +112,13 @@ requirements = [ "numpy>=1.16", "scipy>=1.3", - "plum-dispatch>=2.6.0", + "plum-dispatch>=2.7.1", "opt-einsum", ] setup( packages=find_packages(exclude=["docs"]), - python_requires=">=3.9", + python_requires=">=3.10", install_requires=requirements, cmdclass={"build_ext": build_ext}, ext_modules=ext_modules, diff --git a/tests/test_types.py b/tests/test_types.py index 9f4e12a..8c531b1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,13 +5,12 @@ import tensorflow as tf import torch from autograd import grad -from plum import isinstance -from plum.promotion import _promotion_rule, convert +from plum import convert, isinstance +from plum._promotion import _promotion_rule import lab as B -# noinspection PyUnresolvedReferences -from .util import autograd_box, check_lazy_shapes +from .util import autograd_box, check_lazy_shapes # noqa: F401 def test_numeric(check_lazy_shapes):