diff --git a/src/experimaestro/core/objects/config.py b/src/experimaestro/core/objects/config.py index 9921a9d7..c658c968 100644 --- a/src/experimaestro/core/objects/config.py +++ b/src/experimaestro/core/objects/config.py @@ -34,7 +34,7 @@ import sys import experimaestro from experimaestro.utils import logger -from experimaestro.core.types import DeprecatedAttribute, ObjectType +from experimaestro.core.types import DeprecatedAttribute, ObjectType, TypeVarType from ..context import SerializationContext, SerializedPath, SerializedPathLoader if TYPE_CHECKING: @@ -145,6 +145,10 @@ def __init__(self, pyobject: "ConfigMixin"): # Explicitely added dependencies self.dependencies = [] + # Concrete type variables resolutions + # This is used to check typevars coherence + self.concrete_typevars: Dict[TypeVar, type] = {} + # Lightweight tasks self.pre_tasks: List["LightweightTask"] = [] @@ -199,6 +203,13 @@ def set(self, k, v, bypass=False): raise AttributeError("Property %s is read-only" % (k)) if v is not None: self.values[k] = argument.validate(v) + # Check for type variables + if type(argument.type) is TypeVarType: + self.check_typevar(argument.type.typevar, type(v)) + if isinstance(v, Config): + # If the value is a Config, fuse type variables + v.__xpm__.fuse_concrete_typevars(self.concrete_typevars) + self.fuse_concrete_typevars(v.__xpm__.concrete_typevars) elif argument.required: raise AttributeError("Cannot set required attribute to None") else: @@ -211,6 +222,43 @@ def set(self, k, v, bypass=False): logger.error("Error while setting value %s in %s", k, self.xpmtype) raise + def fuse_concrete_typevars(self, typevars: Dict[TypeVar, type]): + """Fuses concrete type variables with the current ones""" + for typevar, v in typevars.items(): + self.check_typevar(typevar, v) + + def check_typevar(self, typevar: TypeVar, v: type): + """Check if a type variable is coherent with the current typevars bindings, + updates the bindings if necessary""" + if typevar not in self.concrete_typevars: + self.concrete_typevars[typevar] = v + return + + concrete_typevar = self.concrete_typevars[typevar] + bound = typevar.__bound__ + # Check that v is a subclass of the typevar OR that typevar is a subclass of v + # Then set the concrete type variable to the most generic type + + # First, limiting to the specified bound + if bound is not None: + if not issubclass(v, bound): + raise TypeError( + f"Type variable {typevar} is bound to {bound}, but tried to set it to {v}" + ) + + if issubclass(v, concrete_typevar): + # v is a subclass of the typevar, keep the typevar + return + if issubclass(concrete_typevar, v): + # typevar is a subclass of v, keep v + self.concrete_typevars[typevar] = v + return + raise TypeError( + f"Type variable {typevar} is already set to {self.concrete_typevars[typevar]}, " + f"but tried to set it to {v}" + f" (current typevars bindings: {self.concrete_typevars})" + ) + def addtag(self, name, value): self._tags[name] = value diff --git a/src/experimaestro/core/types.py b/src/experimaestro/core/types.py index 3fc0fcf5..041f9b8e 100644 --- a/src/experimaestro/core/types.py +++ b/src/experimaestro/core/types.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import inspect import sys -from typing import Set, Union, Dict, Iterator, List, get_args, get_origin +from typing import Set, TypeVar, Union, Dict, Iterator, List, get_args, get_origin from collections import ChainMap from pathlib import Path import typing @@ -130,10 +130,13 @@ def fromType(key): if union_t := typingutils.get_union(key): return UnionType([Type.fromType(t) for t in union_t]) - # Takes care of generics + # Takes care of generics, like List[int], not List if get_origin(key): return GenericType(key) + if isinstance(key, TypeVar): + return TypeVarType(key) + raise Exception("No type found for %s", key) @@ -597,6 +600,23 @@ def validate(self, value): return value +class TypeVarType(Type): + def __init__(self, typevar: TypeVar): + self.typevar = typevar + + def name(self): + return str(self.typevar) + + def validate(self, value): + return value + + def __str__(self): + return f"TypeVar({self.typevar})" + + def __repr__(self): + return f"TypeVar({self.typevar})" + + Any = AnyType() @@ -698,6 +718,10 @@ def name(self): def __repr__(self): return repr(self.type) + def identifier(self): + """Returns the identifier of the type""" + return Identifier(f"{self.origin}.{self.type}") + def validate(self, value): # Now, let's check generics... mros = typingutils.generic_mro(type(value)) diff --git a/src/experimaestro/tests/core/__init__.py b/src/experimaestro/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/experimaestro/tests/core/test_generics.py b/src/experimaestro/tests/core/test_generics.py new file mode 100644 index 00000000..902a741a --- /dev/null +++ b/src/experimaestro/tests/core/test_generics.py @@ -0,0 +1,206 @@ +"""Tests for the use of generics in configurations""" + +from typing import Generic, Optional, TypeVar + +import pytest +from experimaestro import Config, Param +from experimaestro.core.arguments import Argument +from experimaestro.core.types import TypeVarType + +T = TypeVar("T") + + +class SimpleConfig(Config): + pass + + +class SimpleConfigChild(SimpleConfig): + pass + + +class SimpleGenericConfig(Config, Generic[T]): + x: Param[T] + + +class SimpleGenericConfigChild(SimpleGenericConfig, Generic[T]): + """A child class of SimpleGenericConfig that also uses generics""" + + pass + + +def test_core_generics_typevar(): + a = SimpleGenericConfig.C(x=1) + + x_arg = a.__xpmtype__.arguments["x"] + + # Check correct interpretation of typevar + assert type(x_arg) is Argument + assert isinstance(x_arg.type, TypeVarType) + assert x_arg.type.typevar == T + + assert isinstance(a.x, int) + + +def test_core_generics_simple(): + a = SimpleGenericConfig.C(x=2) + + # OK + a.x = 3 + + # Fails: changing generics is not allowed + with pytest.raises(TypeError): + a.x = "a string" + + # typevar bindings are local to the instance, + # so we can create a new instance with a different type + SimpleGenericConfig.C(x="a string") + + +class DoubleGenericConfig(Config, Generic[T]): + x: Param[T] + y: Param[T] + + +def test_core_generics_double(): + # OK + DoubleGenericConfig.C(x=1, y=1) + + # Fails + with pytest.raises(TypeError): + DoubleGenericConfig.C(x=1, y="a") + + a = DoubleGenericConfig.C(x=1, y=1) + a.y = 2 + with pytest.raises(TypeError): + a.x = "b" + + +def test_core_generics_double_rebind(): + a = DoubleGenericConfig.C(x=1, y=1) + # Rebinding to a different type should not work + with pytest.raises(TypeError): + a.x, a.y = "some", "string" + + +def test_core_generics_double_plus(): + # Testing with inheritance + # We allow subclasses of the typevar binding + # We also allow generalizing up the typevar binding + # This means that we can use a super class of the typevar binding + + # Works + a = DoubleGenericConfig.C(x=SimpleConfigChild.C()) + a.y = SimpleConfig.C() + + # Works also + b = DoubleGenericConfig.C(x=SimpleConfig.C()) + b.y = SimpleConfigChild.C() + + a.x = SimpleConfigChild.C() + + with pytest.raises(TypeError): + a.x = "a string" + + +def test_core_generics_double_type_escalation(): + a = DoubleGenericConfig.C(x=SimpleConfigChild.C()) + a.y = SimpleConfigChild.C() + # T is now bound to SimpleConfigChild + + a.y = SimpleConfig.C() + # T is now bound to SimpleConfig + + a.y = object() + # T is now bound to object, which is a super class of SimpleConfigChild + + # This is allowed, since we are not changing the typevar binding + a.x = "a string" + + a.y = dict() + # This is allowed, since we are not changing the typevar binding + + +def test_core_generics_double_deep_bind(): + # Since we are deep binding the typevar T to a specific type, + # we should not be able to have coherent *local-only* type bindings + # The type bindings are transient + + with pytest.raises(TypeError): + DoubleGenericConfig.C( + x=DoubleGenericConfig.C(x=1, y=2), y=DoubleGenericConfig.C(x=3, y=4) + ) + + +class NestedConfig(Config, Generic[T]): + x: Param[DoubleGenericConfig[T]] + y: Param[SimpleGenericConfig[T]] + + +def test_core_generics_nested(): + # OK + NestedConfig.C(x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x=2)) + + # Not OK + with pytest.raises(TypeError): + NestedConfig.C( + x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x="b") + ) + + with pytest.raises(TypeError): + a = NestedConfig.C( + x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x=1) + ) + a.x.x = "a string" + + +class TreeGenericConfig(Config, Generic[T]): + x: Param[T] + left: Optional["TreeGenericConfig[T]"] = None + right: Optional["TreeGenericConfig[T]"] = None + + +class TagTreeGenericConfig(TreeGenericConfig[T], Generic[T]): + """A tagged version of TreeGenericConfig to test recursive generics""" + + tag: Param[str] = "default" + + +def test_core_generics_recursive(): + a = TreeGenericConfig.C(x=1) + a.left = TreeGenericConfig.C(x=2) + a.right = TreeGenericConfig.C(x=3) + + with pytest.raises(TypeError): + a.left.x = "a string" + + # OK to use a child class + a.left = TagTreeGenericConfig.C(x=4, tag="left") + + with pytest.raises(TypeError): + a.left.x = "a string" + + +def test_core_generics_recursive_child(): + # Testing with a child class on the generic value + a = TreeGenericConfig.C(x=SimpleConfig.C()) + a.left = TreeGenericConfig.C(x=SimpleConfig.C()) + a.right = TreeGenericConfig.C(x=SimpleConfig.C()) + + a.left.x = SimpleConfigChild.C() + + with pytest.raises(TypeError): + a.left.x = "a string" + + +U = TypeVar("U", bound=SimpleConfigChild) + + +class BoundGenericConfig(Config, Generic[U]): + x: Param[U] + + +def test_core_generics_bound_typevar(): + a = BoundGenericConfig.C(x=SimpleConfigChild.C()) + assert isinstance(a.x, SimpleConfigChild) + with pytest.raises(TypeError): + a.x = SimpleConfig.C()