Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
bfb1dd0
feat: add Guppy type aliases
ss2165 Apr 7, 2026
e3be9ac
fix: Improve cycle detection for recursive type aliases
ss2165 Jun 16, 2026
8aaef5d
feat: Require explicit name argument in type_alias()
ss2165 Jun 16, 2026
0d80d14
feat: Support generic type aliases with params= argument
ss2165 Jun 16, 2026
d89d84b
test: Add error tests for invalid type aliases
ss2165 Jun 16, 2026
42154f3
test: Extend integration tests for type aliases
ss2165 Jun 16, 2026
83b067b
feat: Support Python 3.12+ type statement syntax for aliases
ss2165 Jun 16, 2026
fabcce0
fix: Remove misleading help text from recursive alias error
ss2165 Jun 16, 2026
b006269
fix: Improve cycle error notes to show definition sites Rust-style
ss2165 Jun 16, 2026
647dd00
chore: Apply ruff format
ss2165 Jun 16, 2026
59c1629
test: Add error test for recursive alias through a struct type argument
ss2165 Jun 16, 2026
63b0715
fix: Move type_alias overloads inside py312 guard to avoid single-ove…
ss2165 Jun 16, 2026
c9b4551
fix: Address Copilot review comments
ss2165 Jun 16, 2026
ad3c9b1
test: Remove redundant float alias test
ss2165 Jun 16, 2026
9baaa9b
feat: Remove Python 3.12 type statement syntax for aliases
ss2165 Jun 18, 2026
8b4b169
feat: Support const_var params in type aliases via deferred resolution
ss2165 Jun 18, 2026
3ca3df6
refactor: Use a single TypeParsingCtx when checking type aliases
ss2165 Jun 18, 2026
538bb1a
refactor: Tidy cycle-note construction for recursive aliases
ss2165 Jun 18, 2026
38631d3
refactor: Revert unrelated linecache change in _parse_expr_string
ss2165 Jun 18, 2026
8e384cb
test: Add error tests for direct type_alias argument errors
ss2165 Jun 18, 2026
12a1c8a
fix: Assert note_file is not None rather than silently skipping
ss2165 Jun 19, 2026
0a6f59e
refactor: Drop redundant runtime guards covered by the type signature
ss2165 Jun 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 238 additions & 0 deletions guppylang-internals/src/guppylang_internals/definition/alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import ast
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import ClassVar

from guppylang_internals.ast_util import AstNode, get_file
from guppylang_internals.checker.core import Globals
from guppylang_internals.definition.common import (
CheckableDef,
CompiledDef,
DefId,
ParsableDef,
)
from guppylang_internals.definition.parameter import ParamDef, RawConstVarDef
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.diagnostic import Error, Note
from guppylang_internals.engine import DEF_STORE
from guppylang_internals.error import GuppyError, InternalGuppyError
from guppylang_internals.span import SourceMap, to_span
from guppylang_internals.tys.arg import Argument
from guppylang_internals.tys.param import Parameter, check_all_args
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
from guppylang_internals.tys.subst import Instantiator
from guppylang_internals.tys.ty import Type

_active_alias_checks: ContextVar[tuple["ParsedTypeAliasDef", ...]] = ContextVar(
"_active_alias_checks", default=()
)


@dataclass(frozen=True)
class RecursiveTypeAliasError(Error):
title: ClassVar[str] = "Recursive type alias"
cycle: tuple[str, ...]

@property
def rendered_span_label(self) -> str:
if len(self.cycle) == 2 and self.cycle[0] == self.cycle[1]:
return f"Type alias `{self.cycle[0]}` expands to itself"
return "Type alias cycle detected:\n" + " -> ".join(
f"`{alias}`" for alias in self.cycle
)

@dataclass(frozen=True)
class AliasNote(Note):
alias_name: str
defn_id: DefId
span_label: ClassVar[str] = "`{alias_name}` defined here"


@dataclass(frozen=True)
class RawTypeAliasDef(TypeDef, ParsableDef):
"""A raw type alias definition that has not been parsed yet."""

type_ast: ast.expr
explicit_params: Sequence[ParamDef] | None = None
params: None = field(default=None, init=False)
description: str = field(default="type alias", init=False)

def parse(self, globals: Globals, sources: SourceMap) -> "ParsedTypeAliasDef":
return ParsedTypeAliasDef(
self.id,
self.name,
self.defined_at,
self.explicit_params,
self.type_ast,
)

def check_instantiate(
self, args: Sequence[Argument], loc: AstNode | None = None
) -> Type:
raise InternalGuppyError("Tried to instantiate raw type alias definition")


def _resolve_param(defn: ParamDef, idx: int, globals: Globals) -> Parameter:
"""Convert a parameter definition to a positional :class:`Parameter`.

``const_var`` definitions arrive unparsed (their type is still a raw AST), so we
parse them here, where the ``globals`` needed to resolve the type are available.
"""
if isinstance(defn, RawConstVarDef):
defn = defn.parse(globals, DEF_STORE.sources)
return defn.to_param(idx)


@dataclass(frozen=True)
class ParsedTypeAliasDef(TypeDef, CheckableDef):
"""A type alias definition whose target type has not been checked yet."""

param_defs: Sequence[ParamDef] | None
type_ast: ast.expr
params: None = field(default=None, init=False)
description: str = field(default="type alias", init=False)

def check(self, globals: Globals) -> "CheckedTypeAliasDef":
if self.param_defs is not None:
# Explicit params: resolve each definition to a parameter (parsing
# `const_var` types now that globals are available) and pre-load them
# into the context so that variables in the body bind to these params
# in order.
resolved = [
_resolve_param(p, i, globals) for i, p in enumerate(self.param_defs)
]
ctx = TypeParsingCtx(
globals, param_var_mapping={p.name: p for p in resolved}
)
check_not_recursive(self, ctx)
ty = type_from_ast(self.type_ast, ctx)
params = tuple(resolved)
else:
# Implicit: collect free type vars from the body in order of appearance.
ctx = TypeParsingCtx(globals, allow_free_vars=True)
check_not_recursive(self, ctx)
ty = type_from_ast(self.type_ast, ctx)
params = tuple(ctx.param_var_mapping.values())
return CheckedTypeAliasDef(
self.id,
self.name,
self.defined_at,
params,
ty,
)

def check_instantiate(
self, args: Sequence[Argument], loc: AstNode | None = None
) -> Type:
globals = Globals(DEF_STORE.frames[self.id])
checked_def = self.check(globals)
return checked_def.check_instantiate(args, loc)


@dataclass(frozen=True)
class CheckedTypeAliasDef(TypeDef, CompiledDef):
"""A fully checked type alias definition."""

params: Sequence[Parameter]
ty: Type
description: str = field(default="type alias", init=False)

def check_instantiate(
self, args: Sequence[Argument], loc: AstNode | None = None
) -> Type:
check_all_args(self.params, args, self.name, loc)
return self.ty.transform(Instantiator(args))


@contextmanager
def _patched_check_instantiate(
defn: ParsedTypeAliasDef,
replacement: Callable[[Sequence[Argument], AstNode | None], Type],
) -> Iterator[None]:
"""Temporarily override `check_instantiate` for recursive-alias detection."""
object.__setattr__(defn, "check_instantiate", replacement)
try:
yield
finally:
# Remove the instance attribute so method resolution falls back to the
# class descriptor, restoring the original behaviour cleanly.
object.__delattr__(defn, "check_instantiate")


def check_not_recursive(defn: ParsedTypeAliasDef, ctx: TypeParsingCtx) -> None:
"""Throws a user error if the given type alias is recursive.

We do not have a separate alias-expansion pass, so we detect recursion by
temporarily swapping out this alias's `check_instantiate` method while parsing its
target type. If parsing the alias body reaches this same alias again, the patched
method fires and turns that recursive re-entry into a user-facing cycle diagnostic.

All cycle notes are attached at once inside `dummy_check_instantiate` so that only
aliases that are actually part of the cycle receive notes (not outer aliases that
merely lead to a cycle).
"""
token = _active_alias_checks.set((*_active_alias_checks.get(), defn))

def dummy_check_instantiate(
args: Sequence[Argument],
loc: AstNode | None = None,
) -> Type:
active = _active_alias_checks.get()
start = next(
i for i, active_defn in enumerate(active) if active_defn.id == defn.id
)
cycle_defs = (*active[start:], defn)
cycle = tuple(d.name for d in cycle_defs)
err = RecursiveTypeAliasError(loc, cycle)
_add_alias_notes_for_cycle(err, cycle_defs)
raise GuppyError(err)

try:
with _patched_check_instantiate(defn, dummy_check_instantiate):
type_from_ast(defn.type_ast, ctx)
finally:
_active_alias_checks.reset(token)


def _add_alias_notes_for_cycle(
err: RecursiveTypeAliasError,
cycle_defs: tuple["ParsedTypeAliasDef", ...],
) -> None:
"""Attach notes for every alias in the cycle in a single pass.

`cycle_defs` is `(A, B, ..., A)` where the first and last element are identical.
We skip self-cycles (only one unique member) since the span label on the error
already says the alias "expands to itself".

Notes are only emitted when the alias definition has a valid, same-file span — i.e.
when the AST node was annotated with file information by `_parse_expr_string`.
Cross-file or un-annotated spans are silently skipped; the cycle chain in the main
error's span label is still fully informative on its own.
"""
unique_defs = cycle_defs[:-1] # drop the repeated last element
if len(unique_defs) <= 1:
return

# File the main error is anchored to. `add_sub_diagnostic` requires every note to
# share this file, so definitions from a different file are skipped below.
err_file = to_span(err.span).file if err.span is not None else None

# The last element of `unique_defs` is the alias whose definition the error span
# already underlines — skip it to avoid a redundant note on the same line. Use
# DefId for deduplication so that aliases with identical names don't collide.
seen_ids: set[DefId] = set()
for defn in unique_defs[:-1]:
if defn.id in seen_ids or defn.defined_at is None:
continue
note_file = get_file(defn.defined_at)
assert note_file is not None, (
f"defined_at node for alias `{defn.name}` has no file annotation"
)
if note_file != err_file:
continue
seen_ids.add(defn.id)
err.add_sub_diagnostic(
RecursiveTypeAliasError.AliasNote(defn.defined_at, defn.name, defn.id)
)
80 changes: 78 additions & 2 deletions guppylang/src/guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import ast
import builtins
import inspect
from collections.abc import Callable
from collections.abc import Callable, Sequence
from types import FrameType
from typing import Any, NamedTuple, ParamSpec, TypedDict, TypeVar, cast, overload
from typing import (
Any,
NamedTuple,
ParamSpec,
TypedDict,
TypeVar,
cast,
overload,
)

from guppylang_internals.ast_util import annotate_location
from guppylang_internals.definition.alias import RawTypeAliasDef
from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.const import RawConstDef
from guppylang_internals.definition.custom import RawCustomFunctionDef
Expand All @@ -16,6 +25,7 @@
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.parameter import (
ConstVarDef,
ParamDef,
RawConstVarDef,
TypeVarDef,
)
Expand Down Expand Up @@ -365,6 +375,50 @@ def const_var(self, name: str, ty: str) -> TypeVar:
# `GuppyDefinition` that pretends to be a TypeVar at runtime
return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]

def type_alias(self, name: str, ty: str, params: list[Any] | None = None) -> Any:
"""Creates a new type alias.

.. code-block:: python

from guppylang import guppy, array

Row = guppy.type_alias("Row", "array[int, 4]")

@guppy
def sum_row(row: Row) -> int:
return row[0] + row[1] + row[2] + row[3]

Generic aliases are supported by passing a list of type variables as ``params``.
The order determines how the alias is instantiated (e.g. ``Alias[int, bool]``
binds the first param to ``int`` and the second to ``bool``):

.. code-block:: python

T = guppy.type_var("T")
U = guppy.type_var("U")
Pair = guppy.type_alias("Pair", "tuple[T, U]", params=[T, U])

When ``params`` is omitted, free type variables are collected from the body
in order of first appearance.
"""
frame = get_calling_frame()

type_ast = _parse_expr_string(
ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
)
explicit_params: Sequence[ParamDef] | None = (
_params_from_list(params) if params is not None else None
)
defn = RawTypeAliasDef(
DefId.fresh(),
name,
type_ast,
type_ast,
explicit_params,
)
DEF_STORE.register_def(defn, frame)
return GuppyDefinition(defn)

@overload
def declare(
self, /, **kwargs: Unpack[GuppyKwargs]
Expand Down Expand Up @@ -789,3 +843,25 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs:


guppy = cast("_Guppy", _DummyGuppy()) if sphinx_running() else _Guppy()


def _params_from_list(params: list[Any]) -> list[ParamDef]:
"""Validate a list of Guppy type-variable definitions for use as alias params.

Each entry must be a type variable created with :func:`guppy.type_var`,
:func:`guppy.nat_var`, or :func:`guppy.const_var`. The underlying
:class:`~guppylang_internals.definition.parameter.ParamDef`\\ s are returned in
order; they are converted to :class:`~guppylang_internals.tys.param.Parameter`\\ s
later (in :meth:`ParsedTypeAliasDef.check`) where the globals needed to resolve
``const_var`` types are available.
"""
result: list[ParamDef] = []
for p in params:
defn = p.wrapped if isinstance(p, GuppyDefinition) else None
if not isinstance(defn, ParamDef):
raise TypeError(
"type_alias params must be type variables created with "
f"guppy.type_var(), guppy.nat_var(), or guppy.const_var(), got {p!r}"
)
result.append(defn)
return result
1 change: 1 addition & 0 deletions tests/error/alias_errors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

15 changes: 15 additions & 0 deletions tests/error/alias_errors/mutual_recursive.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Error: Recursive type alias (at $FILE:5:0)
|
3 |
4 | Alias1 = guppy.type_alias("Alias1", "Alias2")
5 | Alias2 = guppy.type_alias("Alias2", "Alias1")
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias cycle detected:
| `Alias1` -> `Alias2` -> `Alias1`

Note:
|
3 |
4 | Alias1 = guppy.type_alias("Alias1", "Alias2")
| --------------------------------------------- `Alias1` defined here

Guppy compilation failed due to 1 previous error
13 changes: 13 additions & 0 deletions tests/error/alias_errors/mutual_recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from guppylang import guppy


Alias1 = guppy.type_alias("Alias1", "Alias2")
Alias2 = guppy.type_alias("Alias2", "Alias1")


@guppy
def main(x: Alias1) -> Alias2:
return x


main.compile_function()
15 changes: 15 additions & 0 deletions tests/error/alias_errors/partial_cycle.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Error: Recursive type alias (at $FILE:6:0)
|
4 | Alias1 = guppy.type_alias("Alias1", "Alias2")
5 | Alias2 = guppy.type_alias("Alias2", "Alias3")
6 | Alias3 = guppy.type_alias("Alias3", "Alias2")
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Type alias cycle detected:
| `Alias2` -> `Alias3` -> `Alias2`

Note:
|
4 | Alias1 = guppy.type_alias("Alias1", "Alias2")
5 | Alias2 = guppy.type_alias("Alias2", "Alias3")
| --------------------------------------------- `Alias2` defined here

Guppy compilation failed due to 1 previous error
Loading
Loading