Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/experimaestro/core/objects/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import sys
import experimaestro
from experimaestro.utils import logger
from experimaestro.core.types import DeprecatedAttribute, ObjectType
from experimaestro.core.types import DeprecatedAttribute, ObjectType, TypeVarType
from ..context import SerializationContext, SerializedPath, SerializedPathLoader

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +145,10 @@ def __init__(self, pyobject: "ConfigMixin"):
# Explicitely added dependencies
self.dependencies = []

# Concrete type variables resolutions
# This is used to check typevars coherence
self.concrete_typevars: Dict[TypeVar, type] = {}

# Lightweight tasks
self.pre_tasks: List["LightweightTask"] = []

Expand Down Expand Up @@ -199,6 +203,13 @@ def set(self, k, v, bypass=False):
raise AttributeError("Property %s is read-only" % (k))
if v is not None:
self.values[k] = argument.validate(v)
# Check for type variables
if type(argument.type) is TypeVarType:
self.check_typevar(argument.type.typevar, type(v))
if isinstance(v, Config):
# If the value is a Config, fuse type variables
v.__xpm__.fuse_concrete_typevars(self.concrete_typevars)
self.fuse_concrete_typevars(v.__xpm__.concrete_typevars)
elif argument.required:
raise AttributeError("Cannot set required attribute to None")
else:
Expand All @@ -211,6 +222,43 @@ def set(self, k, v, bypass=False):
logger.error("Error while setting value %s in %s", k, self.xpmtype)
raise

def fuse_concrete_typevars(self, typevars: Dict[TypeVar, type]):
"""Fuses concrete type variables with the current ones"""
for typevar, v in typevars.items():
self.check_typevar(typevar, v)

def check_typevar(self, typevar: TypeVar, v: type):
"""Check if a type variable is coherent with the current typevars bindings,
updates the bindings if necessary"""
if typevar not in self.concrete_typevars:
self.concrete_typevars[typevar] = v
return

concrete_typevar = self.concrete_typevars[typevar]
bound = typevar.__bound__
# Check that v is a subclass of the typevar OR that typevar is a subclass of v
# Then set the concrete type variable to the most generic type

# First, limiting to the specified bound
if bound is not None:
if not issubclass(v, bound):
raise TypeError(
f"Type variable {typevar} is bound to {bound}, but tried to set it to {v}"
)

if issubclass(v, concrete_typevar):
# v is a subclass of the typevar, keep the typevar
return
if issubclass(concrete_typevar, v):
# typevar is a subclass of v, keep v
self.concrete_typevars[typevar] = v
return
raise TypeError(
f"Type variable {typevar} is already set to {self.concrete_typevars[typevar]}, "
f"but tried to set it to {v}"
f" (current typevars bindings: {self.concrete_typevars})"
)

def addtag(self, name, value):
self._tags[name] = value

Expand Down
28 changes: 26 additions & 2 deletions src/experimaestro/core/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import inspect
import sys
from typing import Set, Union, Dict, Iterator, List, get_args, get_origin
from typing import Set, TypeVar, Union, Dict, Iterator, List, get_args, get_origin
from collections import ChainMap
from pathlib import Path
import typing
Expand Down Expand Up @@ -130,10 +130,13 @@ def fromType(key):
if union_t := typingutils.get_union(key):
return UnionType([Type.fromType(t) for t in union_t])

# Takes care of generics
# Takes care of generics, like List[int], not List
if get_origin(key):
return GenericType(key)

if isinstance(key, TypeVar):
return TypeVarType(key)

raise Exception("No type found for %s", key)


Expand Down Expand Up @@ -597,6 +600,23 @@ def validate(self, value):
return value


class TypeVarType(Type):
def __init__(self, typevar: TypeVar):
self.typevar = typevar

def name(self):
return str(self.typevar)

def validate(self, value):
return value

def __str__(self):
return f"TypeVar({self.typevar})"

def __repr__(self):
return f"TypeVar({self.typevar})"


Any = AnyType()


Expand Down Expand Up @@ -698,6 +718,10 @@ def name(self):
def __repr__(self):
return repr(self.type)

def identifier(self):
"""Returns the identifier of the type"""
return Identifier(f"{self.origin}.{self.type}")

def validate(self, value):
# Now, let's check generics...
mros = typingutils.generic_mro(type(value))
Expand Down
Empty file.
206 changes: 206 additions & 0 deletions src/experimaestro/tests/core/test_generics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""Tests for the use of generics in configurations"""

from typing import Generic, Optional, TypeVar

import pytest
from experimaestro import Config, Param
from experimaestro.core.arguments import Argument
from experimaestro.core.types import TypeVarType

T = TypeVar("T")


class SimpleConfig(Config):
pass


class SimpleConfigChild(SimpleConfig):
pass


class SimpleGenericConfig(Config, Generic[T]):
x: Param[T]


class SimpleGenericConfigChild(SimpleGenericConfig, Generic[T]):
"""A child class of SimpleGenericConfig that also uses generics"""

pass


def test_core_generics_typevar():
a = SimpleGenericConfig.C(x=1)

x_arg = a.__xpmtype__.arguments["x"]

# Check correct interpretation of typevar
assert type(x_arg) is Argument
assert isinstance(x_arg.type, TypeVarType)
assert x_arg.type.typevar == T

assert isinstance(a.x, int)


def test_core_generics_simple():
a = SimpleGenericConfig.C(x=2)

# OK
a.x = 3

# Fails: changing generics is not allowed
with pytest.raises(TypeError):
a.x = "a string"

# typevar bindings are local to the instance,
# so we can create a new instance with a different type
SimpleGenericConfig.C(x="a string")


class DoubleGenericConfig(Config, Generic[T]):
x: Param[T]
y: Param[T]


def test_core_generics_double():
# OK
DoubleGenericConfig.C(x=1, y=1)

# Fails
with pytest.raises(TypeError):
DoubleGenericConfig.C(x=1, y="a")

a = DoubleGenericConfig.C(x=1, y=1)
a.y = 2
with pytest.raises(TypeError):
a.x = "b"


def test_core_generics_double_rebind():
a = DoubleGenericConfig.C(x=1, y=1)
# Rebinding to a different type should not work
with pytest.raises(TypeError):
a.x, a.y = "some", "string"


def test_core_generics_double_plus():
# Testing with inheritance
# We allow subclasses of the typevar binding
# We also allow generalizing up the typevar binding
# This means that we can use a super class of the typevar binding

# Works
a = DoubleGenericConfig.C(x=SimpleConfigChild.C())
a.y = SimpleConfig.C()

# Works also
b = DoubleGenericConfig.C(x=SimpleConfig.C())
b.y = SimpleConfigChild.C()

a.x = SimpleConfigChild.C()

with pytest.raises(TypeError):
a.x = "a string"


def test_core_generics_double_type_escalation():
a = DoubleGenericConfig.C(x=SimpleConfigChild.C())
a.y = SimpleConfigChild.C()
# T is now bound to SimpleConfigChild

a.y = SimpleConfig.C()
# T is now bound to SimpleConfig

a.y = object()
# T is now bound to object, which is a super class of SimpleConfigChild

# This is allowed, since we are not changing the typevar binding
a.x = "a string"

a.y = dict()
# This is allowed, since we are not changing the typevar binding


def test_core_generics_double_deep_bind():
# Since we are deep binding the typevar T to a specific type,
# we should not be able to have coherent *local-only* type bindings
# The type bindings are transient

with pytest.raises(TypeError):
DoubleGenericConfig.C(
x=DoubleGenericConfig.C(x=1, y=2), y=DoubleGenericConfig.C(x=3, y=4)
)


class NestedConfig(Config, Generic[T]):
x: Param[DoubleGenericConfig[T]]
y: Param[SimpleGenericConfig[T]]


def test_core_generics_nested():
# OK
NestedConfig.C(x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x=2))

# Not OK
with pytest.raises(TypeError):
NestedConfig.C(
x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x="b")
)

with pytest.raises(TypeError):
a = NestedConfig.C(
x=DoubleGenericConfig.C(x=1, y=1), y=SimpleGenericConfig.C(x=1)
)
a.x.x = "a string"


class TreeGenericConfig(Config, Generic[T]):
x: Param[T]
left: Optional["TreeGenericConfig[T]"] = None
right: Optional["TreeGenericConfig[T]"] = None


class TagTreeGenericConfig(TreeGenericConfig[T], Generic[T]):
"""A tagged version of TreeGenericConfig to test recursive generics"""

tag: Param[str] = "default"


def test_core_generics_recursive():
a = TreeGenericConfig.C(x=1)
a.left = TreeGenericConfig.C(x=2)
a.right = TreeGenericConfig.C(x=3)

with pytest.raises(TypeError):
a.left.x = "a string"

# OK to use a child class
a.left = TagTreeGenericConfig.C(x=4, tag="left")

with pytest.raises(TypeError):
a.left.x = "a string"


def test_core_generics_recursive_child():
# Testing with a child class on the generic value
a = TreeGenericConfig.C(x=SimpleConfig.C())
a.left = TreeGenericConfig.C(x=SimpleConfig.C())
a.right = TreeGenericConfig.C(x=SimpleConfig.C())

a.left.x = SimpleConfigChild.C()

with pytest.raises(TypeError):
a.left.x = "a string"


U = TypeVar("U", bound=SimpleConfigChild)


class BoundGenericConfig(Config, Generic[U]):
x: Param[U]


def test_core_generics_bound_typevar():
a = BoundGenericConfig.C(x=SimpleConfigChild.C())
assert isinstance(a.x, SimpleConfigChild)
with pytest.raises(TypeError):
a.x = SimpleConfig.C()