diff --git a/guppylang-internals/src/guppylang_internals/definition/alias.py b/guppylang-internals/src/guppylang_internals/definition/alias.py new file mode 100644 index 000000000..098f323ef --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/definition/alias.py @@ -0,0 +1,238 @@ +import ast +from collections.abc import Callable, Iterator, Sequence +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import ClassVar + +from guppylang_internals.ast_util import AstNode, get_file +from guppylang_internals.checker.core import Globals +from guppylang_internals.definition.common import ( + CheckableDef, + CompiledDef, + DefId, + ParsableDef, +) +from guppylang_internals.definition.parameter import ParamDef, RawConstVarDef +from guppylang_internals.definition.ty import TypeDef +from guppylang_internals.diagnostic import Error, Note +from guppylang_internals.engine import DEF_STORE +from guppylang_internals.error import GuppyError, InternalGuppyError +from guppylang_internals.span import SourceMap, to_span +from guppylang_internals.tys.arg import Argument +from guppylang_internals.tys.param import Parameter, check_all_args +from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast +from guppylang_internals.tys.subst import Instantiator +from guppylang_internals.tys.ty import Type + +_active_alias_checks: ContextVar[tuple["ParsedTypeAliasDef", ...]] = ContextVar( + "_active_alias_checks", default=() +) + + +@dataclass(frozen=True) +class RecursiveTypeAliasError(Error): + title: ClassVar[str] = "Recursive type alias" + cycle: tuple[str, ...] + + @property + def rendered_span_label(self) -> str: + if len(self.cycle) == 2 and self.cycle[0] == self.cycle[1]: + return f"Type alias `{self.cycle[0]}` expands to itself" + return "Type alias cycle detected:\n" + " -> ".join( + f"`{alias}`" for alias in self.cycle + ) + + @dataclass(frozen=True) + class AliasNote(Note): + alias_name: str + defn_id: DefId + span_label: ClassVar[str] = "`{alias_name}` defined here" + + +@dataclass(frozen=True) +class RawTypeAliasDef(TypeDef, ParsableDef): + """A raw type alias definition that has not been parsed yet.""" + + type_ast: ast.expr + explicit_params: Sequence[ParamDef] | None = None + params: None = field(default=None, init=False) + description: str = field(default="type alias", init=False) + + def parse(self, globals: Globals, sources: SourceMap) -> "ParsedTypeAliasDef": + return ParsedTypeAliasDef( + self.id, + self.name, + self.defined_at, + self.explicit_params, + self.type_ast, + ) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> Type: + raise InternalGuppyError("Tried to instantiate raw type alias definition") + + +def _resolve_param(defn: ParamDef, idx: int, globals: Globals) -> Parameter: + """Convert a parameter definition to a positional :class:`Parameter`. + + ``const_var`` definitions arrive unparsed (their type is still a raw AST), so we + parse them here, where the ``globals`` needed to resolve the type are available. + """ + if isinstance(defn, RawConstVarDef): + defn = defn.parse(globals, DEF_STORE.sources) + return defn.to_param(idx) + + +@dataclass(frozen=True) +class ParsedTypeAliasDef(TypeDef, CheckableDef): + """A type alias definition whose target type has not been checked yet.""" + + param_defs: Sequence[ParamDef] | None + type_ast: ast.expr + params: None = field(default=None, init=False) + description: str = field(default="type alias", init=False) + + def check(self, globals: Globals) -> "CheckedTypeAliasDef": + if self.param_defs is not None: + # Explicit params: resolve each definition to a parameter (parsing + # `const_var` types now that globals are available) and pre-load them + # into the context so that variables in the body bind to these params + # in order. + resolved = [ + _resolve_param(p, i, globals) for i, p in enumerate(self.param_defs) + ] + ctx = TypeParsingCtx( + globals, param_var_mapping={p.name: p for p in resolved} + ) + check_not_recursive(self, ctx) + ty = type_from_ast(self.type_ast, ctx) + params = tuple(resolved) + else: + # Implicit: collect free type vars from the body in order of appearance. + ctx = TypeParsingCtx(globals, allow_free_vars=True) + check_not_recursive(self, ctx) + ty = type_from_ast(self.type_ast, ctx) + params = tuple(ctx.param_var_mapping.values()) + return CheckedTypeAliasDef( + self.id, + self.name, + self.defined_at, + params, + ty, + ) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> Type: + globals = Globals(DEF_STORE.frames[self.id]) + checked_def = self.check(globals) + return checked_def.check_instantiate(args, loc) + + +@dataclass(frozen=True) +class CheckedTypeAliasDef(TypeDef, CompiledDef): + """A fully checked type alias definition.""" + + params: Sequence[Parameter] + ty: Type + description: str = field(default="type alias", init=False) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> Type: + check_all_args(self.params, args, self.name, loc) + return self.ty.transform(Instantiator(args)) + + +@contextmanager +def _patched_check_instantiate( + defn: ParsedTypeAliasDef, + replacement: Callable[[Sequence[Argument], AstNode | None], Type], +) -> Iterator[None]: + """Temporarily override `check_instantiate` for recursive-alias detection.""" + object.__setattr__(defn, "check_instantiate", replacement) + try: + yield + finally: + # Remove the instance attribute so method resolution falls back to the + # class descriptor, restoring the original behaviour cleanly. + object.__delattr__(defn, "check_instantiate") + + +def check_not_recursive(defn: ParsedTypeAliasDef, ctx: TypeParsingCtx) -> None: + """Throws a user error if the given type alias is recursive. + + We do not have a separate alias-expansion pass, so we detect recursion by + temporarily swapping out this alias's `check_instantiate` method while parsing its + target type. If parsing the alias body reaches this same alias again, the patched + method fires and turns that recursive re-entry into a user-facing cycle diagnostic. + + All cycle notes are attached at once inside `dummy_check_instantiate` so that only + aliases that are actually part of the cycle receive notes (not outer aliases that + merely lead to a cycle). + """ + token = _active_alias_checks.set((*_active_alias_checks.get(), defn)) + + def dummy_check_instantiate( + args: Sequence[Argument], + loc: AstNode | None = None, + ) -> Type: + active = _active_alias_checks.get() + start = next( + i for i, active_defn in enumerate(active) if active_defn.id == defn.id + ) + cycle_defs = (*active[start:], defn) + cycle = tuple(d.name for d in cycle_defs) + err = RecursiveTypeAliasError(loc, cycle) + _add_alias_notes_for_cycle(err, cycle_defs) + raise GuppyError(err) + + try: + with _patched_check_instantiate(defn, dummy_check_instantiate): + type_from_ast(defn.type_ast, ctx) + finally: + _active_alias_checks.reset(token) + + +def _add_alias_notes_for_cycle( + err: RecursiveTypeAliasError, + cycle_defs: tuple["ParsedTypeAliasDef", ...], +) -> None: + """Attach notes for every alias in the cycle in a single pass. + + `cycle_defs` is `(A, B, ..., A)` where the first and last element are identical. + We skip self-cycles (only one unique member) since the span label on the error + already says the alias "expands to itself". + + Notes are only emitted when the alias definition has a valid, same-file span — i.e. + when the AST node was annotated with file information by `_parse_expr_string`. + Cross-file or un-annotated spans are silently skipped; the cycle chain in the main + error's span label is still fully informative on its own. + """ + unique_defs = cycle_defs[:-1] # drop the repeated last element + if len(unique_defs) <= 1: + return + + # File the main error is anchored to. `add_sub_diagnostic` requires every note to + # share this file, so definitions from a different file are skipped below. + err_file = to_span(err.span).file if err.span is not None else None + + # The last element of `unique_defs` is the alias whose definition the error span + # already underlines — skip it to avoid a redundant note on the same line. Use + # DefId for deduplication so that aliases with identical names don't collide. + seen_ids: set[DefId] = set() + for defn in unique_defs[:-1]: + if defn.id in seen_ids or defn.defined_at is None: + continue + note_file = get_file(defn.defined_at) + assert note_file is not None, ( + f"defined_at node for alias `{defn.name}` has no file annotation" + ) + if note_file != err_file: + continue + seen_ids.add(defn.id) + err.add_sub_diagnostic( + RecursiveTypeAliasError.AliasNote(defn.defined_at, defn.name, defn.id) + ) diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 89a08578d..61f60ad20 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -1,11 +1,20 @@ import ast import builtins import inspect -from collections.abc import Callable +from collections.abc import Callable, Sequence from types import FrameType -from typing import Any, NamedTuple, ParamSpec, TypedDict, TypeVar, cast, overload +from typing import ( + Any, + NamedTuple, + ParamSpec, + TypedDict, + TypeVar, + cast, + overload, +) from guppylang_internals.ast_util import annotate_location +from guppylang_internals.definition.alias import RawTypeAliasDef from guppylang_internals.definition.common import DefId from guppylang_internals.definition.const import RawConstDef from guppylang_internals.definition.custom import RawCustomFunctionDef @@ -16,6 +25,7 @@ from guppylang_internals.definition.overloaded import OverloadedFunctionDef from guppylang_internals.definition.parameter import ( ConstVarDef, + ParamDef, RawConstVarDef, TypeVarDef, ) @@ -365,6 +375,50 @@ def const_var(self, name: str, ty: str) -> TypeVar: # `GuppyDefinition` that pretends to be a TypeVar at runtime return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value] + def type_alias(self, name: str, ty: str, params: list[Any] | None = None) -> Any: + """Creates a new type alias. + + .. code-block:: python + + from guppylang import guppy, array + + Row = guppy.type_alias("Row", "array[int, 4]") + + @guppy + def sum_row(row: Row) -> int: + return row[0] + row[1] + row[2] + row[3] + + Generic aliases are supported by passing a list of type variables as ``params``. + The order determines how the alias is instantiated (e.g. ``Alias[int, bool]`` + binds the first param to ``int`` and the second to ``bool``): + + .. code-block:: python + + T = guppy.type_var("T") + U = guppy.type_var("U") + Pair = guppy.type_alias("Pair", "tuple[T, U]", params=[T, U]) + + When ``params`` is omitted, free type variables are collected from the body + in order of first appearance. + """ + frame = get_calling_frame() + + type_ast = _parse_expr_string( + ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources + ) + explicit_params: Sequence[ParamDef] | None = ( + _params_from_list(params) if params is not None else None + ) + defn = RawTypeAliasDef( + DefId.fresh(), + name, + type_ast, + type_ast, + explicit_params, + ) + DEF_STORE.register_def(defn, frame) + return GuppyDefinition(defn) + @overload def declare( self, /, **kwargs: Unpack[GuppyKwargs] @@ -789,3 +843,25 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs: guppy = cast("_Guppy", _DummyGuppy()) if sphinx_running() else _Guppy() + + +def _params_from_list(params: list[Any]) -> list[ParamDef]: + """Validate a list of Guppy type-variable definitions for use as alias params. + + Each entry must be a type variable created with :func:`guppy.type_var`, + :func:`guppy.nat_var`, or :func:`guppy.const_var`. The underlying + :class:`~guppylang_internals.definition.parameter.ParamDef`\\ s are returned in + order; they are converted to :class:`~guppylang_internals.tys.param.Parameter`\\ s + later (in :meth:`ParsedTypeAliasDef.check`) where the globals needed to resolve + ``const_var`` types are available. + """ + result: list[ParamDef] = [] + for p in params: + defn = p.wrapped if isinstance(p, GuppyDefinition) else None + if not isinstance(defn, ParamDef): + raise TypeError( + "type_alias params must be type variables created with " + f"guppy.type_var(), guppy.nat_var(), or guppy.const_var(), got {p!r}" + ) + result.append(defn) + return result diff --git a/tests/error/alias_errors/__init__.py b/tests/error/alias_errors/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/error/alias_errors/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/error/alias_errors/mutual_recursive.err b/tests/error/alias_errors/mutual_recursive.err new file mode 100644 index 000000000..83cc6d729 --- /dev/null +++ b/tests/error/alias_errors/mutual_recursive.err @@ -0,0 +1,15 @@ +Error: Recursive type alias (at $FILE:5:0) + | +3 | +4 | Alias1 = guppy.type_alias("Alias1", "Alias2") +5 | Alias2 = guppy.type_alias("Alias2", "Alias1") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias cycle detected: + | `Alias1` -> `Alias2` -> `Alias1` + +Note: + | +3 | +4 | Alias1 = guppy.type_alias("Alias1", "Alias2") + | --------------------------------------------- `Alias1` defined here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/mutual_recursive.py b/tests/error/alias_errors/mutual_recursive.py new file mode 100644 index 000000000..c1bba8bb9 --- /dev/null +++ b/tests/error/alias_errors/mutual_recursive.py @@ -0,0 +1,13 @@ +from guppylang import guppy + + +Alias1 = guppy.type_alias("Alias1", "Alias2") +Alias2 = guppy.type_alias("Alias2", "Alias1") + + +@guppy +def main(x: Alias1) -> Alias2: + return x + + +main.compile_function() diff --git a/tests/error/alias_errors/partial_cycle.err b/tests/error/alias_errors/partial_cycle.err new file mode 100644 index 000000000..ff4e2a441 --- /dev/null +++ b/tests/error/alias_errors/partial_cycle.err @@ -0,0 +1,15 @@ +Error: Recursive type alias (at $FILE:6:0) + | +4 | Alias1 = guppy.type_alias("Alias1", "Alias2") +5 | Alias2 = guppy.type_alias("Alias2", "Alias3") +6 | Alias3 = guppy.type_alias("Alias3", "Alias2") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias cycle detected: + | `Alias2` -> `Alias3` -> `Alias2` + +Note: + | +4 | Alias1 = guppy.type_alias("Alias1", "Alias2") +5 | Alias2 = guppy.type_alias("Alias2", "Alias3") + | --------------------------------------------- `Alias2` defined here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/partial_cycle.py b/tests/error/alias_errors/partial_cycle.py new file mode 100644 index 000000000..714daef9c --- /dev/null +++ b/tests/error/alias_errors/partial_cycle.py @@ -0,0 +1,14 @@ +from guppylang import guppy + +# Alias1 is outside the cycle, but leads into it (Alias2 <-> Alias3) +Alias1 = guppy.type_alias("Alias1", "Alias2") +Alias2 = guppy.type_alias("Alias2", "Alias3") +Alias3 = guppy.type_alias("Alias3", "Alias2") + + +@guppy +def main(x: Alias1) -> Alias1: + return x + + +main.compile_function() diff --git a/tests/error/alias_errors/recursive.err b/tests/error/alias_errors/recursive.err new file mode 100644 index 000000000..2bfa9eacd --- /dev/null +++ b/tests/error/alias_errors/recursive.err @@ -0,0 +1,8 @@ +Error: Recursive type alias (at $FILE:4:0) + | +2 | +3 | +4 | MyAlias = guppy.type_alias("MyAlias", "MyAlias") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias `MyAlias` expands to itself + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/recursive.py b/tests/error/alias_errors/recursive.py new file mode 100644 index 000000000..2c04a4e48 --- /dev/null +++ b/tests/error/alias_errors/recursive.py @@ -0,0 +1,12 @@ +from guppylang import guppy + + +MyAlias = guppy.type_alias("MyAlias", "MyAlias") + + +@guppy +def main(x: MyAlias) -> MyAlias: + return x + + +main.compile_function() diff --git a/tests/error/alias_errors/struct_cycle.err b/tests/error/alias_errors/struct_cycle.err new file mode 100644 index 000000000..aec6a1a50 --- /dev/null +++ b/tests/error/alias_errors/struct_cycle.err @@ -0,0 +1,8 @@ +Error: Recursive type alias (at $FILE:14:0) + | +12 | +13 | # Alias whose body refers to itself via a struct type argument +14 | MyAlias = guppy.type_alias("MyAlias", "Box[MyAlias]") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias `MyAlias` expands to itself + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/struct_cycle.py b/tests/error/alias_errors/struct_cycle.py new file mode 100644 index 000000000..2d301a245 --- /dev/null +++ b/tests/error/alias_errors/struct_cycle.py @@ -0,0 +1,22 @@ +from typing import Generic + +from guppylang import guppy + +T = guppy.type_var("T") + + +@guppy.struct +class Box(Generic[T]): + value: T + + +# Alias whose body refers to itself via a struct type argument +MyAlias = guppy.type_alias("MyAlias", "Box[MyAlias]") + + +@guppy +def f(x: MyAlias) -> MyAlias: + return x + + +f.compile_function() diff --git a/tests/error/alias_errors/too_many_args.err b/tests/error/alias_errors/too_many_args.err new file mode 100644 index 000000000..a5e17ad00 --- /dev/null +++ b/tests/error/alias_errors/too_many_args.err @@ -0,0 +1,9 @@ +Error: Too many type arguments (at $FILE:19:12) + | +17 | +18 | @guppy +19 | def main(b: BoxAlias[int, bool]) -> int: + | ^^^^^^^^^^^^^^^^^^^ Unexpected type argument for type `BoxAlias` (expected 1, + | got 2) + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/too_many_args.py b/tests/error/alias_errors/too_many_args.py new file mode 100644 index 000000000..e3014e09f --- /dev/null +++ b/tests/error/alias_errors/too_many_args.py @@ -0,0 +1,23 @@ +from typing import Generic + +from guppylang import guppy + + +T = guppy.type_var("T") + + +@guppy.struct +class Box(Generic[T]): + value: T + + +# Too many type args for generic alias (Box[T] takes 1, given 2) +BoxAlias = guppy.type_alias("BoxAlias", "Box[T]", params=[T]) + + +@guppy +def main(b: BoxAlias[int, bool]) -> int: + return b.value + + +main.compile_function() diff --git a/tests/error/alias_errors/undefined_type.err b/tests/error/alias_errors/undefined_type.err new file mode 100644 index 000000000..0f114e4e1 --- /dev/null +++ b/tests/error/alias_errors/undefined_type.err @@ -0,0 +1,8 @@ +Error: Variable not defined (at $FILE:5:0) + | +3 | +4 | # Reference to a type that doesn't exist +5 | BadAlias = guppy.type_alias("BadAlias", "NonExistentType") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `NonExistentType` is not defined + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/alias_errors/undefined_type.py b/tests/error/alias_errors/undefined_type.py new file mode 100644 index 000000000..eec2d12b1 --- /dev/null +++ b/tests/error/alias_errors/undefined_type.py @@ -0,0 +1,13 @@ +from guppylang import guppy + + +# Reference to a type that doesn't exist +BadAlias = guppy.type_alias("BadAlias", "NonExistentType") + + +@guppy +def main(x: BadAlias) -> BadAlias: + return x + + +main.compile_function() diff --git a/tests/error/test_alias_errors.py b/tests/error/test_alias_errors.py new file mode 100644 index 000000000..f62eff754 --- /dev/null +++ b/tests/error/test_alias_errors.py @@ -0,0 +1,47 @@ +import pathlib + +import pytest +from guppylang import guppy + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "alias_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() and x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_alias_errors(file, capsys, snapshot): + run_error_test(file, capsys, snapshot) + + +def test_type_alias_bad_type_syntax(): + with pytest.raises(SyntaxError, match="Not a valid Guppy type: `foo bar`"): + guppy.type_alias("MyAlias", "foo bar") + + +def test_type_alias_invalid_param(): + with pytest.raises( + TypeError, + match="type_alias params must be type variables created with", + ): + guppy.type_alias("MyAlias", "int", params=["not a type var"]) + + +def test_type_alias_param_not_a_param_def(): + # A `GuppyDefinition` that isn't a type variable (e.g. a struct) is rejected. + @guppy.struct + class SomeStruct: + x: int + + with pytest.raises( + TypeError, + match="type_alias params must be type variables created with", + ): + guppy.type_alias("MyAlias", "int", params=[SomeStruct]) diff --git a/tests/integration/test_type_alias.py b/tests/integration/test_type_alias.py new file mode 100644 index 000000000..a52069060 --- /dev/null +++ b/tests/integration/test_type_alias.py @@ -0,0 +1,265 @@ +from typing import Generic + +from guppylang import array, guppy, qubit +from guppylang.std.builtins import owned +from guppylang.std.quantum import discard, measure, x + + +def test_alias_chain(run_int_fn): + """Type aliases can chain through other aliases for scalar types.""" + MyInt = guppy.type_alias("MyInt", "int") + MyOtherInt = guppy.type_alias("MyOtherInt", "MyInt") + + @guppy + def main(x: MyOtherInt) -> MyInt: + return x + 1 + + run_int_fn(main, expected=42, args=[41]) + + +def test_array_alias(validate): + """Type aliases can name nested concrete array types.""" + Row = guppy.type_alias("Row", "array[int, 2]") + Matrix = guppy.type_alias("Matrix", "array[Row, 2]") + + @guppy + def main(xs: Matrix) -> int: + return xs[0][0] + xs[0][1] + xs[1][0] + xs[1][1] + + validate(main.compile_function()) + + +def test_qubit_array_alias(run_int_fn): + """Type aliases preserve owned linear array semantics for qubits.""" + QubitArray = guppy.type_alias("QubitArray", "array[qubit, 2]") + + @guppy + def use_qubits(qs: QubitArray @ owned) -> int: + q1, q2 = qs + discard(q1) + x(q2) + return 1 if measure(q2) else 0 + + @guppy + def main() -> int: + qs: QubitArray = array(qubit(), qubit()) + return use_qubits(qs) + + run_int_fn(main, expected=1, num_qubits=2) + + +def test_generic_struct_alias(run_int_fn): + """Type aliases can refer to concrete instantiations of generic structs.""" + T = guppy.type_var("T") + + @guppy.struct + class Box(Generic[T]): + value: T + + IntBox = guppy.type_alias("IntBox", "Box[int]") + + @guppy + def increment(box: IntBox) -> IntBox: + return Box(box.value + 1) + + @guppy + def main() -> int: + box = increment(Box(41)) + return box.value + + run_int_fn(main, expected=42) + + +def test_explicit_generic_alias_single_param(run_int_fn): + """Generic alias with a single explicit type param can be instantiated.""" + T = guppy.type_var("T") + + @guppy.struct + class Wrapper(Generic[T]): + item: T + + MyWrapper = guppy.type_alias("MyWrapper", "Wrapper[T]", params=[T]) + + @guppy + def make_int_wrapper(v: int) -> MyWrapper[int]: + return Wrapper(v) + + @guppy + def main() -> int: + w = make_int_wrapper(7) + return w.item + + run_int_fn(main, expected=7) + + +def test_explicit_generic_alias_two_params(run_int_fn): + """Generic alias with two explicit params respects given param order.""" + A = guppy.type_var("A") + B = guppy.type_var("B") + + @guppy.struct + class Pair(Generic[A, B]): + first: A + second: B + + # Explicitly reverse the param order: Swap[X, Y] = Pair[Y, X] + Swap = guppy.type_alias("Swap", "Pair[B, A]", params=[A, B]) + + @guppy + def main() -> int: + # Swap[int, bool] → Pair[bool, int] so first is bool, second is int + s: Swap[int, bool] = Pair(True, 42) + return s.second + + run_int_fn(main, expected=42) + + +def test_implicit_generic_alias(run_int_fn): + """When params is omitted, free vars are collected from body in appearance order.""" + T = guppy.type_var("T") + + @guppy.struct + class Box(Generic[T]): + value: T + + # No params= → T is a free var, collected automatically + BoxAlias = guppy.type_alias("BoxAlias", "Box[T]") + + @guppy + def get_value(b: BoxAlias[int]) -> int: + return b.value + + @guppy + def main() -> int: + return get_value(Box(99)) + + run_int_fn(main, expected=99) + + +def test_const_var_alias(run_int_fn): + """Generic aliases can be parameterised by const variables.""" + B = guppy.const_var("B", "bool") + + @guppy.struct + class Flagged(Generic[B]): + value: int + + # Alias parameterised by a const var; resolved lazily when the alias is checked. + MyFlagged = guppy.type_alias("MyFlagged", "Flagged[B]", params=[B]) + + @guppy + def get_value(f: MyFlagged[True]) -> int: + return f.value + + @guppy + def main() -> int: + return get_value(Flagged(7)) + + run_int_fn(main, expected=7) + + +# --------------------------------------------------------------------------- +# Struct / enum interaction tests +# --------------------------------------------------------------------------- + + +def test_alias_in_struct_field(run_int_fn): + """A struct field can be typed with a concrete alias.""" + IntAlias = guppy.type_alias("IntAlias", "int") + + @guppy.struct + class Point: + x: IntAlias + y: IntAlias + + @guppy + def main() -> int: + p = Point(3, 4) + return p.x + p.y + + run_int_fn(main, expected=7) + + +def test_alias_of_struct(run_int_fn): + """An alias can name a concrete struct type and be used transparently.""" + + @guppy.struct + class Vec2: + x: int + y: int + + VecAlias = guppy.type_alias("VecAlias", "Vec2") + + @guppy + def dot(a: VecAlias, b: VecAlias) -> int: + return a.x * b.x + a.y * b.y + + @guppy + def main() -> int: + return dot(Vec2(3, 4), Vec2(1, 2)) + + run_int_fn(main, expected=11) + + +def test_generic_alias_in_struct_field(run_int_fn): + """A generic alias used in a struct field is correctly expanded.""" + T = guppy.type_var("T") + + @guppy.struct + class Box(Generic[T]): + value: T + + Boxed = guppy.type_alias("Boxed", "Box[T]", params=[T]) + + @guppy.struct + class Outer: + inner: Boxed[int] + + @guppy + def main() -> int: + o = Outer(Box(42)) + return o.inner.value + + run_int_fn(main, expected=42) + + +def test_alias_of_enum(validate): + """An alias can name an enum type and be used in function signatures.""" + + @guppy.enum + class Color: + Red = {} + Green = {} + Blue = {} + + @guppy + def tag(self: "Color") -> int: + return 0 + + ColorAlias = guppy.type_alias("ColorAlias", "Color") + + @guppy + def use_color(c: ColorAlias) -> int: + return c.tag() + + @guppy + def main() -> int: + return use_color(Color.Red()) + + validate(main.compile_function()) + + +def test_alias_in_enum_variant_field(validate): + """An enum variant field can be typed with an alias.""" + IntAlias = guppy.type_alias("IntAlias", "int") + + @guppy.enum + class Msg: + Value = {"n": IntAlias} + Empty = {} + + @guppy + def make_value(n: int) -> Msg: + return Msg.Value(n) + + validate(make_value.compile_function())