diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index 1796d4464..ca82e058e 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -15,6 +15,7 @@ from guppylang_internals.cfg.cfg import CFG, BaseCFG from guppylang_internals.checker.core import ( Context, + EffectLimitDecl, Globals, Locals, Place, @@ -76,6 +77,7 @@ def check_cfg( generic_args: dict[str, Argument], func_name: str, globals: Globals, + max_effects_from: EffectLimitDecl | None, first_modifier_node: ast.expr | None = None, ) -> CheckedCFG[Place]: """Instantiates a control-flow graph with the given `generic_args` and then type @@ -100,7 +102,13 @@ def check_cfg( # We start by compiling the entry BB checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty) checked_cfg.entry_bb = check_bb( - cfg.entry_bb, checked_cfg, inputs, return_ty, generic_args, globals + cfg.entry_bb, + checked_cfg, + inputs, + return_ty, + generic_args, + globals, + max_effects_from=max_effects_from, ) compiled = {cfg.entry_bb: checked_cfg.entry_bb} @@ -127,7 +135,13 @@ def check_cfg( else: # Otherwise, check the BB and enqueue its successors checked_bb = check_bb( - bb, checked_cfg, input_row, return_ty, generic_args, globals + bb, + checked_cfg, + input_row, + return_ty, + generic_args, + globals, + max_effects_from=max_effects_from, ) queue += [ # We enumerate the successor starting from the back, so we start with @@ -237,6 +251,7 @@ def check_bb( return_ty: Type, generic_args: dict[str, Argument], globals: Globals, + max_effects_from: EffectLimitDecl | None, ) -> CheckedBB[Variable]: cfg = bb.containing_cfg @@ -261,7 +276,9 @@ def check_bb( raise GuppyError(_assigned_in_modifier_error(x, use, assignment)) # Check the basic block - ctx = Context(globals, Locals({v.name: v for v in inputs}), generic_args) + ctx = Context( + globals, Locals({v.name: v for v in inputs}), generic_args, max_effects_from + ) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppylang-internals/src/guppylang_internals/checker/core.py b/guppylang-internals/src/guppylang_internals/checker/core.py index 8b43c027d..c6d71738b 100644 --- a/guppylang-internals/src/guppylang_internals/checker/core.py +++ b/guppylang-internals/src/guppylang_internals/checker/core.py @@ -3,7 +3,7 @@ import itertools from collections.abc import Iterable, Iterator from dataclasses import dataclass, field, replace -from functools import cache, cached_property +from functools import cache, cached_property, reduce from types import FrameType from typing import ( TYPE_CHECKING, @@ -26,9 +26,12 @@ ) from guppylang_internals.engine import BUILTIN_DEFS, DEF_STORE, ENGINE from guppylang_internals.error import InternalGuppyError, RequiresMonomorphizationError +from guppylang_internals.span import Span, to_span +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg from guppylang_internals.tys.const import BoundConstVar, ConstValue, ExistentialConstVar from guppylang_internals.tys.ty import ( + FunctionType, InputFlags, StructType, Type, @@ -441,6 +444,56 @@ def items(self) -> Iterable[tuple[VId, V]]: return itertools.chain(self.vars.items(), parent_items) +@dataclass(frozen=True) +class EffectLimitDecl: + """Records a declaration limiting the effects that may be used in a Context""" + + effects: list[Effect] + decl: ast.expr | Span + decl_name: str + + @classmethod + def for_def( + cls, ty: FunctionType, func_def: ast.FunctionDef + ) -> "EffectLimitDecl | None": + if ty.declared_effects is None: + return None + if (deco := _find_guppy_decorator(func_def.decorator_list)) is not None: + decl = deco + else: + # Could not identify decorator, so include all in context; union with + # returns will include name etc. inbetween but avoid the function body. + elems = func_def.decorator_list + if func_def.returns is not None: + elems += [func_def.returns] + + def union(s1: Span, s2: Span) -> Span: + r = s1 | s2 + assert r is not None # Function def should not cross file boundary + return r + + decl = reduce(union, (to_span(e) for e in elems)) + + return EffectLimitDecl( + ty.declared_effects, + decl, + func_def.name, + ) + + +def _find_guppy_decorator(decorators: list[ast.expr]) -> ast.expr | None: + for d in decorators: + if ( + isinstance(d, ast.Call) + and isinstance(d.func, ast.Name) + and d.func.id == "guppy" + ): + return d + if isinstance(d, ast.Name) and d.id == "guppy": + return d + return None + + class Context(NamedTuple): """The type checking context.""" @@ -448,6 +501,10 @@ class Context(NamedTuple): locals: Locals[str, Variable] generic_param_inst: dict[str, Argument] + """If not None, the effect constraints that function calls in this context must + respect, together with the AST node that gives rise to said constraint""" + max_effects_from: EffectLimitDecl | None = None + @property def parsing_ctx(self) -> "TypeParsingCtx": """A type parsing context derived from this checking context.""" diff --git a/guppylang-internals/src/guppylang_internals/checker/errors/type_errors.py b/guppylang-internals/src/guppylang_internals/checker/errors/type_errors.py index 5b7ebc0fe..8b819b95e 100644 --- a/guppylang-internals/src/guppylang_internals/checker/errors/type_errors.py +++ b/guppylang-internals/src/guppylang_internals/checker/errors/type_errors.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, ClassVar from guppylang_internals.diagnostic import Error, Help, Note +from guppylang_internals.tys import Effect if TYPE_CHECKING: from guppylang_internals.definition.util import CheckedField @@ -68,6 +69,39 @@ class ConstMismatchError(Error): actual: Const +@dataclass(frozen=True) +class TooManyEffectsError(Error): + title: ClassVar[str] = "Too many effects" + span_label: ClassVar[str] = "{target} not allowed inside `{in_func}`" + callee: str | FunctionType + effects: list[Effect] + in_func: str + + @property + def target(self) -> str: + if isinstance(self.callee, str): + return f"Call to `{self.callee}`" + self.note_effects() + msg = f"Callee of type `{self.callee}`" + if self.callee.declared_effects is None: + # FunctionType that will not display any effects, so list separately + msg += self.note_effects() + return msg + + def note_effects(self) -> str: + return f" has effects `{Effect.format_list(self.effects)}`" + + @dataclass(frozen=True) + class MaxFromDecl(Note): + span_label: ClassVar[str] = "Allowed effects {allowed_effects_str}declared here" + allowed_effects: list[Effect] | None + + @property + def allowed_effects_str(self) -> str: + if self.allowed_effects is None: + return "" + return "`" + Effect.format_list(self.allowed_effects) + "` " + + @dataclass(frozen=True) class AssignFieldTypeMismatchError(Error): title: ClassVar[str] = "Type mismatch" diff --git a/guppylang-internals/src/guppylang_internals/checker/expr_checker.py b/guppylang-internals/src/guppylang_internals/checker/expr_checker.py index 5ee4934cf..38cb581ec 100644 --- a/guppylang-internals/src/guppylang_internals/checker/expr_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/expr_checker.py @@ -83,6 +83,7 @@ NonLinearInstantiateError, NotCallableError, ParameterInferenceError, + TooManyEffectsError, TupleIndexOutOfBoundsError, TypeApplyNotGenericError, TypeInferenceError, @@ -1327,6 +1328,32 @@ def check_comptime_arg( return subst +def _check_effects(func_ty: FunctionType, ctx: Context, node: AstNode) -> None: + """Checks that a function call (AST provided) to a specified FunctionType + respects the effect constraints in the context.""" + if (mf := ctx.max_effects_from) is None: + return + surplus_effects = [e for e in func_ty.effects if e not in mf.effects] + if surplus_effects: + loc_node = node.func if isinstance(node, ast.Call) else node + show_effects_allowed = mf.effects + if isinstance(mf.decl, ast.expr): + # We found the decorator that is the source of the effect constraint, + # which will contain the allowed effects as an explicit argument + show_effects_allowed = None + # Otherwise, the error message points at all decorators, which may or may not + # list the allowed effects, so list them explicitly + + callee = loc_node.id if isinstance(loc_node, ast.Name) else func_ty + raise GuppyTypeError( + TooManyEffectsError( + loc_node, callee, surplus_effects, mf.decl_name + ).add_sub_diagnostic( + TooManyEffectsError.MaxFromDecl(mf.decl, show_effects_allowed) + ) + ) + + def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], Type, Inst]: @@ -1357,7 +1384,9 @@ def synthesize_call( inst = check_all_solved(subst, free_vars, func_ty, node) # Finally, check that the instantiation respects the linearity requirements + # and the effects allowed in the context. check_inst(func_ty, inst, node) + _check_effects(func_ty, ctx, node) return args, unquantified.output.substitute(subst), inst @@ -1443,7 +1472,9 @@ def check_call( subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements + # and the effects allowed in the context. check_inst(func_ty, inst, node) + _check_effects(func_ty, ctx, node) return inputs, subst, inst @@ -1565,7 +1596,12 @@ def check_generator( # The rest is checked in a new nested context to ensure that variables don't escape # their scope inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals) - inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_param_inst) + inner_ctx = Context( + ctx.globals, + inner_locals, + ctx.generic_param_inst, + ctx.max_effects_from, + ) expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) gen.iter, iter_ty = expr_sth.visit(gen.iter) gen.iter = with_type(iter_ty, gen.iter) diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index fa0e7d31c..1937d06c9 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -15,7 +15,13 @@ from guppylang_internals.cfg.bb import BB from guppylang_internals.cfg.builder import CFGBuilder from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg -from guppylang_internals.checker.core import Context, Globals, Place, Variable +from guppylang_internals.checker.core import ( + Context, + EffectLimitDecl, + Globals, + Place, + Variable, +) from guppylang_internals.checker.errors.generic import UnsupportedError from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger from guppylang_internals.definition.common import DefId @@ -159,7 +165,16 @@ def check_global_func_def( generic_args = { param.name: arg for param, arg in zip(generic_ty.params, type_args, strict=True) } - return check_cfg(cfg, inputs, ty.output, generic_args, func_def.name, globals) + max_effects_from = EffectLimitDecl.for_def(ty, func_def) + return check_cfg( + cfg, + inputs, + ty.output, + generic_args, + func_def.name, + globals, + max_effects_from=max_effects_from, + ) def check_nested_func_def( @@ -168,7 +183,13 @@ def check_nested_func_def( ctx: Context, ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" - func_ty = check_signature(func_def, ctx.globals) + # For now we assume the nested function has the same effects as that enclosing. + # We could do better by allowing a separate annotation (rather than a parameter + # to @guppy), but we will wait for callgraph analysis to compute precisely: + # nested functions are not part of any public API, so changes are not breaking. + func_ty = check_signature(func_def, ctx.globals).with_effects( + None if ctx.max_effects_from is None else ctx.max_effects_from.effects + ) assert func_ty.input_names is not None if func_ty.parametrized: @@ -247,7 +268,17 @@ def check_nested_func_def( # Otherwise, we treat it like a local name inputs.append(Variable(func_def.name, func_def.ty, func_def)) - checked_cfg = check_cfg(cfg, inputs, func_ty.output, {}, func_def.name, globals) + checked_cfg = check_cfg( + cfg, + inputs, + func_ty.output, + {}, + func_def.name, + globals, + # As comment above, assume nested func has same effects as enclosing + # (hence the decl giving the effects is that of the enclosing func too). + max_effects_from=ctx.max_effects_from, + ) checked_def = CheckedNestedFunctionDef( def_id, checked_cfg, diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index 9e7afe3ba..e452b3f08 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -5,7 +5,7 @@ from guppylang_internals.ast_util import with_loc from guppylang_internals.cfg.bb import BB from guppylang_internals.checker.cfg_checker import check_cfg -from guppylang_internals.checker.core import Context, Variable +from guppylang_internals.checker.core import Context, EffectLimitDecl, Variable from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger from guppylang_internals.definition.common import DefId from guppylang_internals.nodes import CheckedModifiedBlock, ModifiedBlock @@ -19,7 +19,10 @@ def check_modified_block( - modified_block: ModifiedBlock, bb: BB, ctx: Context + modified_block: ModifiedBlock, + bb: BB, + ctx: Context, + max_effects_from: EffectLimitDecl | None, ) -> CheckedModifiedBlock: """Type checks a modifier definition.""" cfg = modified_block.cfg @@ -53,6 +56,7 @@ def check_modified_block( {}, "__modified__()", globals, + max_effects_from=max_effects_from, # We pass the first modifier node for better error messages in the cfg checker first_modifier_node=modified_block.first_modifier_node, ) diff --git a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py index 1069b3a4e..2b340351c 100644 --- a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py @@ -417,7 +417,9 @@ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt: raise InternalGuppyError("BB required to check with block!") # check the body of the modified block - checked_modified_block = check_modified_block(node, self.bb, self.ctx) + checked_modified_block = check_modified_block( + node, self.bb, self.ctx, max_effects_from=self.ctx.max_effects_from + ) # check the arguments of the control and power. for control in checked_modified_block.control: diff --git a/guppylang-internals/src/guppylang_internals/decorator.py b/guppylang-internals/src/guppylang_internals/decorator.py index fccc406a4..f59460e6c 100644 --- a/guppylang-internals/src/guppylang_internals/decorator.py +++ b/guppylang-internals/src/guppylang_internals/decorator.py @@ -59,6 +59,7 @@ from collections.abc import Callable, Sequence from types import FrameType + from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst @@ -88,6 +89,7 @@ def custom_function( signature: FunctionType | None = None, unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, has_var_args: bool = False, + effects: list[Effect] | None = None, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Decorator to add custom typing or compilation behaviour to function decls. @@ -112,6 +114,7 @@ def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: signature, unitary_flags, has_var_args, + effects=effects, ) DEF_STORE.register_def(func, get_calling_frame()) return GuppyFunctionDefinition(func) @@ -126,6 +129,7 @@ def hugr_op( name: str = "", signature: FunctionType | None = None, unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, + effects: list[Effect] | None = None, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Decorator to annotate function declarations as HUGR ops. @@ -144,6 +148,7 @@ def hugr_op( name, signature, unitary_flags=unitary_flags, + effects=effects, ) @@ -344,18 +349,17 @@ def dec(cls: builtins.type[T]) -> GuppyDefinition: ) # Add a constructor to the class - if init_arg: - init_fn_ty = FunctionType( - [ - FuncInput( - NumericType(NumericType.Kind.Nat), - flags=InputFlags.Owned, - ) - ], - ext_module_ty, - ) - else: - init_fn_ty = FunctionType([], ext_module_ty) + init_fn_ty = FunctionType( + [ + FuncInput( + NumericType(NumericType.Kind.Nat), + flags=InputFlags.Owned, + ) + ] + if init_arg + else [], + ext_module_ty, + ) call_method = CustomFunctionDef( DefId.fresh(), diff --git a/guppylang-internals/src/guppylang_internals/definition/custom.py b/guppylang-internals/src/guppylang_internals/definition/custom.py index afa0bd55b..8fcf7640d 100644 --- a/guppylang-internals/src/guppylang_internals/definition/custom.py +++ b/guppylang-internals/src/guppylang_internals/definition/custom.py @@ -41,6 +41,7 @@ make_opaque, read_bool, ) +from guppylang_internals.tys import Effect from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst, Subst from guppylang_internals.tys.ty import ( @@ -126,6 +127,8 @@ class RawCustomFunctionDef(ParsableDef): # in Guppy functions in general but some custom functions make use of them). has_var_args: bool = field(default=False) + effects: list[Effect] | None = field(default=None, kw_only=True) + description: str = field(default="function", init=False) @override @@ -149,7 +152,7 @@ def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef": raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name)) sig = self.signature or self._get_signature(func_ast, globals) ty = sig or FunctionType([], NoneType()) - ty = ty.with_unitary_flags(self.unitary_flags) + ty = ty.with_unitary_flags(self.unitary_flags).with_effects(self.effects) return CustomFunctionDef( self.id, self.name, diff --git a/guppylang-internals/src/guppylang_internals/definition/declaration.py b/guppylang-internals/src/guppylang_internals/definition/declaration.py index f2402f0d9..06a992bb9 100644 --- a/guppylang-internals/src/guppylang_internals/definition/declaration.py +++ b/guppylang-internals/src/guppylang_internals/definition/declaration.py @@ -47,6 +47,7 @@ from guppylang_internals.metadata.common import FunctionMetadata, add_metadata from guppylang_internals.nodes import GlobalCall from guppylang_internals.span import SourceMap +from guppylang_internals.tys import Effect from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst, Subst from guppylang_internals.tys.ty import Type, UnitaryFlags @@ -93,13 +94,15 @@ class RawFunctionDecl(ParsableDef, UserProvidedLinkName): metadata: FunctionMetadata | None = field(default=None, kw_only=True) + effects: list[Effect] | None = field(default=None, kw_only=True) + @override def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDecl": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) ty = check_signature( func_ast, globals, self.id, unitary_flags=self.unitary_flags - ) + ).with_effects(self.effects) link_name = self._user_set_link_name or default_func_link_name(self) # TODO: For the guppylang 1.0 break, we should consider disallowing generic diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index eaf970291..f6c641709 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -54,6 +54,7 @@ from guppylang_internals.metadata.common import FunctionMetadata, add_metadata from guppylang_internals.nodes import GlobalCall from guppylang_internals.span import SourceMap, to_span +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import ConstArg, TypeArg from guppylang_internals.tys.const import ConstValue from guppylang_internals.tys.subst import Inst, Subst @@ -118,13 +119,15 @@ class RawFunctionDef(ParsableDef, UserProvidedLinkName): metadata: FunctionMetadata | None = field(default=None, kw_only=True) + effects: list[Effect] | None = field(default=None, kw_only=True) + @override def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) ty = check_signature( func_ast, globals, self.id, unitary_flags=self.unitary_flags - ) + ).with_effects(self.effects) link_name = self._user_set_link_name or default_func_link_name(self) return ParsedFunctionDef( @@ -172,7 +175,12 @@ def params(self) -> "Sequence[Parameter]": @override def check(self, type_args: Inst, globals: Globals) -> "CheckedFunctionDef": """Type checks the body of the function.""" - cfg = check_global_func_def(self.defined_at, self.ty, type_args, globals) + cfg = check_global_func_def( + self.defined_at, + self.ty, + type_args, + globals, + ) mono_ty = self.ty.instantiate_partial(type_args) mono_link_name = monomorphized_link_name(self.link_name, type_args) return CheckedFunctionDef( diff --git a/guppylang-internals/src/guppylang_internals/definition/overloaded.py b/guppylang-internals/src/guppylang_internals/definition/overloaded.py index 71783c8ad..84ae721b9 100644 --- a/guppylang-internals/src/guppylang_internals/definition/overloaded.py +++ b/guppylang-internals/src/guppylang_internals/definition/overloaded.py @@ -8,7 +8,7 @@ from typing_extensions import override from guppylang_internals.ast_util import AstNode -from guppylang_internals.checker.core import Context +from guppylang_internals.checker.core import Context, EffectLimitDecl from guppylang_internals.checker.expr_checker import ExprSynthesizer from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.definition.common import ( @@ -23,6 +23,7 @@ from guppylang_internals.diagnostic import Error, Note from guppylang_internals.error import GuppyError, InternalGuppyError from guppylang_internals.span import Span, to_span +from guppylang_internals.tys import Effect from guppylang_internals.tys.printing import signature_to_str from guppylang_internals.tys.subst import Subst from guppylang_internals.tys.ty import FunctionType, Type @@ -39,6 +40,7 @@ class OverloadNoMatchError(Error): func: str arg_tys: list[Type] return_ty: Type | None + max_effects_from: EffectLimitDecl | None @property def rendered_span_label(self) -> str: @@ -53,6 +55,10 @@ def rendered_span_label(self) -> str: stem += f"takes arguments {args}" if self.return_ty: stem += f" and returns `{self.return_ty}`" + if self.max_effects_from: + effects = self.max_effects_from.effects + if Effect.ANY not in effects: + stem += f" with effects no more than `{effects}`" return stem @@ -114,7 +120,7 @@ def check_call( @override def synthesize_call( - self, args: list[ast.expr], node: AstNode, ctx: "Context" + self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[ast.expr, Type]: available_sigs: list[OverloadVariant] = [] for def_id in self.func_ids: @@ -134,7 +140,7 @@ def _call_error( self, args: list[ast.expr], node: AstNode, - ctx: "Context", + ctx: Context, available_sigs: list[OverloadVariant], return_ty: Type | None = None, ) -> NoReturn: @@ -147,7 +153,9 @@ def _call_error( synth = ExprSynthesizer(ctx) arg_tys = [synth.synthesize(arg)[1] for arg in args] - err = OverloadNoMatchError(span, self.name, arg_tys, return_ty) + err = OverloadNoMatchError( + span, self.name, arg_tys, return_ty, ctx.max_effects_from + ) err.add_sub_diagnostic(AvailableOverloadsHint(None, self.name, available_sigs)) raise GuppyError(err) diff --git a/guppylang-internals/src/guppylang_internals/definition/traced.py b/guppylang-internals/src/guppylang_internals/definition/traced.py index f485ab9ac..cffbb25c6 100644 --- a/guppylang-internals/src/guppylang_internals/definition/traced.py +++ b/guppylang-internals/src/guppylang_internals/definition/traced.py @@ -41,6 +41,7 @@ from guppylang_internals.metadata.common import FunctionMetadata, add_metadata from guppylang_internals.nodes import GlobalCall from guppylang_internals.span import SourceMap +from guppylang_internals.tys import Effect from guppylang_internals.tys.subst import Inst, Subst from guppylang_internals.tys.ty import Type, UnitaryFlags, type_to_row @@ -57,12 +58,14 @@ class RawTracedFunctionDef(ParsableDef): metadata: FunctionMetadata | None = field(default=None, kw_only=True) + effects: list[Effect] | None = field(default=None, kw_only=True) + def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, _docstring = parse_py_func(self.python_func, sources) ty = check_signature( func_ast, globals, self.id, unitary_flags=self.unitary_flags - ) + ).with_effects(self.effects) if ty.parametrized: raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions")) return TracedFunctionDef( @@ -73,6 +76,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef": self.python_func, unitary_flags=self.unitary_flags, metadata=self.metadata, + effects=self.effects, ) @@ -132,6 +136,7 @@ def compile_outer( func_def, unitary_flags=self.unitary_flags, metadata=self.metadata, + effects=self.effects, ) diff --git a/guppylang-internals/src/guppylang_internals/span.py b/guppylang-internals/src/guppylang_internals/span.py index b35366d9d..d9e1d77ba 100644 --- a/guppylang-internals/src/guppylang_internals/span.py +++ b/guppylang-internals/src/guppylang_internals/span.py @@ -71,6 +71,14 @@ def __and__(self, other: "Span") -> "Span | None": return None return Span(max(self.start, other.start), min(self.end, other.end)) + def __or__(self, other: "Span") -> "Span | None": + """Returns the union with the given span, including any gaps, but `None` + if they are in different files.""" + if self.file != other.file: + return None + r = Span(min(self.start, other.start), max(self.end, other.end)) + return r + def __len__(self) -> int: """Returns the length of a single-line span in columns. @@ -80,6 +88,12 @@ def __len__(self) -> int: raise InternalGuppyError("Span: Tried to compute length of multi-line span") return self.end.column - self.start.column + def __bool__(self) -> bool: + """A span is considered false if it has zero length.""" + # Avoid calling __len__: a multi-line span is considered True + # even though its length is not computable + return self.start != self.end + @property def file(self) -> str: """The file containing this span.""" diff --git a/guppylang-internals/src/guppylang_internals/tracing/function.py b/guppylang-internals/src/guppylang_internals/tracing/function.py index f21cd8a4a..c21c25928 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/function.py +++ b/guppylang-internals/src/guppylang_internals/tracing/function.py @@ -9,6 +9,7 @@ from guppylang_internals.checker.core import ( ComptimeVariable, Context, + EffectLimitDecl, Globals, Locals, Variable, @@ -72,7 +73,10 @@ def trace_function( Invokes the passed Python callable and constructs the corresponding Hugr using the passed builder. """ - state = TracingState(ctx, DFContainer(builder, ctx, {}), node, func_def) + max_effects = EffectLimitDecl.for_def(ty, func_def.defined_at) + state = TracingState( + ctx, DFContainer(builder, ctx, {}), node, func_def, max_effects=max_effects + ) with set_tracing_state(state): inputs = [ unpack_guppy_object( @@ -179,7 +183,12 @@ def trace_call(func: CallableDef, *args: Any) -> Any: arg_exprs: list[ast.expr] = [ with_loc(state.node, with_type(var.ty, PlaceNode(var))) for var in arg_vars ] - ctx = Context(Globals(DEF_STORE.frames[func.id]), locals, {}) + ctx = Context( + Globals(DEF_STORE.frames[func.id]), + locals, + {}, + max_effects_from=state.max_effects, + ) call_node, ret_ty = func.synthesize_call(arg_exprs, state.node, ctx) # Here we check if unitary constraints are respected in the function body diff --git a/guppylang-internals/src/guppylang_internals/tracing/state.py b/guppylang-internals/src/guppylang_internals/tracing/state.py index e1969bc5c..8d11905c2 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/state.py +++ b/guppylang-internals/src/guppylang_internals/tracing/state.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from guppylang_internals.ast_util import AstNode +from guppylang_internals.checker.core import EffectLimitDecl from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.definition.traced import CompiledTracedFunctionDef from guppylang_internals.error import InternalGuppyError @@ -29,6 +30,9 @@ class TracingState: #: The function definition currently being traced. function_definition: CompiledTracedFunctionDef + #: The maximum effects that the currently traced function is allowed to perform + max_effects: EffectLimitDecl | None + #: Set of all allocated undroppable GuppyObjects where the `used` flag is not set, #: indexed by their id. This is used to detect linearity violations. unused_undroppable_objs: "dict[GuppyObjectId, GuppyObject]" = field( diff --git a/guppylang-internals/src/guppylang_internals/tys/__init__.py b/guppylang-internals/src/guppylang_internals/tys/__init__.py index e69de29bb..37ae36a26 100644 --- a/guppylang-internals/src/guppylang_internals/tys/__init__.py +++ b/guppylang-internals/src/guppylang_internals/tys/__init__.py @@ -0,0 +1,17 @@ +from collections.abc import Iterable +from enum import Enum + + +class Effect(Enum): + ANY = "Any" + + @classmethod + def __from_str__(cls, s: str) -> "Effect": + for effect in cls: + if effect.name == s: + return effect + raise ValueError(f"Invalid effect name: {s}") + + @staticmethod + def format_list(effects: Iterable["Effect"]) -> str: + return f"[{', '.join(e.name for e in effects)}]" diff --git a/guppylang-internals/src/guppylang_internals/tys/errors.py b/guppylang-internals/src/guppylang_internals/tys/errors.py index 7ed321359..7286446cd 100644 --- a/guppylang-internals/src/guppylang_internals/tys/errors.py +++ b/guppylang-internals/src/guppylang_internals/tys/errors.py @@ -172,6 +172,28 @@ class ComptimeArgShadowError(Error): arg: str +@dataclass(frozen=True) +class EffectsNotApplicableError(Error): + title: ClassVar[str] = "Invalid annotation" + span_label: ClassVar[str] = "Effects may be applied only to a `Callable` type" + + +@dataclass(frozen=True) +class EffectsRepeatedError(Error): + title: ClassVar[str] = "Invalid annotation" + span_label: ClassVar[str] = ( + "Effects have already been applied to this `Callable` type" + ) + + +@dataclass(frozen=True) +class InvalidEffectError(Error): + title: ClassVar[str] = "Invalid annotation" + span_label: ClassVar[str] = "Not a valid effect: `{arg}`" + # We could perhaps provide a list of possible effects? + arg: str + + @dataclass(frozen=True) class InvalidFlagError(Error): title: ClassVar[str] = "Invalid annotation" diff --git a/guppylang-internals/src/guppylang_internals/tys/parsing.py b/guppylang-internals/src/guppylang_internals/tys/parsing.py index 4a771462b..d8695e1c4 100644 --- a/guppylang-internals/src/guppylang_internals/tys/parsing.py +++ b/guppylang-internals/src/guppylang_internals/tys/parsing.py @@ -20,6 +20,7 @@ from guppylang_internals.engine import ENGINE from guppylang_internals.error import GuppyError from guppylang_internals.experimental import check_unitary_callable_enabled +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg from guppylang_internals.tys.builtin import ( CallableTypeDef, @@ -30,11 +31,14 @@ from guppylang_internals.tys.errors import ( CallableComptimeError, ComptimeArgShadowError, + EffectsNotApplicableError, + EffectsRepeatedError, FlagNotAllowedError, FreeTypeVarError, HigherKindedTypeVarError, IllegalComptimeTypeArgError, InvalidCallableTypeError, + InvalidEffectError, InvalidFlagError, InvalidTypeArgError, InvalidTypeError, @@ -497,6 +501,29 @@ def type_with_flags_from_ast( flags |= InputFlags.Comptime if not ty.copyable or not ty.droppable: raise GuppyError(LinearComptimeError(node.right, ty)) + case ast.Call(func=ast.Name(id="effects")) as fx: + if not isinstance(ty, FunctionType): + raise GuppyError(EffectsNotApplicableError(node.right)) + if ty.declared_effects is not None: + raise GuppyError(EffectsRepeatedError(node.right)) + effects: list[Effect] = [] + if ( + len(fx.args) == 1 + and isinstance(fx.args[0], ast.Constant) + and fx.args[0].value is None + ): + effects = [] + else: + for e in fx.args: + # We might want to support ast.Attribute with LHS "Effects" + # and look at RHS + if not isinstance(e, ast.Name): + raise GuppyError(InvalidEffectError(node.right, str(e))) + try: + effects.append(Effect.__from_str__(e.id)) + except ValueError: + raise GuppyError(InvalidEffectError(node.right, e.id)) # noqa: B904 + ty = ty.with_effects(effects) case _: raise GuppyError(InvalidFlagError(node.right)) return ty, flags diff --git a/guppylang-internals/src/guppylang_internals/tys/printing.py b/guppylang-internals/src/guppylang_internals/tys/printing.py index b92e06031..1525a66ad 100644 --- a/guppylang-internals/src/guppylang_internals/tys/printing.py +++ b/guppylang-internals/src/guppylang_internals/tys/printing.py @@ -1,6 +1,7 @@ from functools import singledispatchmethod from guppylang_internals.error import InternalGuppyError +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg from guppylang_internals.tys.const import Const, ConstValue from guppylang_internals.tys.param import ConstParam, TypeParam @@ -95,6 +96,11 @@ def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: if len(ty.inputs) != 1: inputs = f"({inputs})" output = self._visit(ty.output, True) + arrow = ( + "->" + if ty.declared_effects is None + else f"-{Effect.format_list(ty.declared_effects)}->" + ) if ty.parametrized: params = [ self._visit(param, False) @@ -104,8 +110,10 @@ def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: ] quantified = ", ".join(params) del self.bound_names[: -len(ty.params)] - return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row) - return _wrap(f"{inputs} -> {output}", inside_row) + desc = f"forall {quantified}. {inputs} {arrow} {output}" + else: + desc = f"{inputs} {arrow} {output}" + return _wrap(desc, inside_row) @_visit.register(OpaqueType) @_visit.register(StructType) @@ -168,4 +176,5 @@ def signature_to_str(name: str, sig: FunctionType, has_var_args: bool = False) - for inp in sig.inputs ) s += ", ..." if has_var_args else "" + # TODO Not clear how to display effects in a Python-like syntax? (skip for now) return s + ") -> " + str(sig.output) diff --git a/guppylang-internals/src/guppylang_internals/tys/ty.py b/guppylang-internals/src/guppylang_internals/tys/ty.py index b40ee28e7..70af1e29d 100644 --- a/guppylang-internals/src/guppylang_internals/tys/ty.py +++ b/guppylang-internals/src/guppylang_internals/tys/ty.py @@ -11,6 +11,7 @@ from typing_extensions import assert_never from guppylang_internals.error import InternalGuppyError +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg from guppylang_internals.tys.common import ( ToHugr, @@ -485,6 +486,17 @@ class FunctionType(ParametrizedTypeBase): unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, init=True) + """ Effects declared in source code, i.e. `Callable[...] @ effects(EFFECTS)`. + None means there was no declaration, which is equivalent to [Effect.ANY] + except for error reporting. Generally use `effects` instead.""" + declared_effects: list[Effect] | None = field(default=None, init=True) + + @property + def effects(self) -> list[Effect]: + return ( + self.declared_effects if self.declared_effects is not None else [Effect.ANY] + ) + def __init__( self, inputs: Sequence[FuncInput], @@ -492,6 +504,7 @@ def __init__( params: Sequence[Parameter] | None = None, comptime_args: Sequence[ConstArg] | None = None, unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, + declared_effects: list[Effect] | None = None, ) -> None: # We need a custom __init__ to set the args args: list[Argument] = [TypeArg(inp.ty) for inp in inputs] @@ -520,6 +533,7 @@ def __init__( object.__setattr__(self, "output", output) object.__setattr__(self, "params", params) object.__setattr__(self, "unitary_flags", unitary_flags) + object.__setattr__(self, "declared_effects", declared_effects) @property def parametrized(self) -> bool: @@ -573,6 +587,8 @@ def _to_hugr_function_type(self, ctx: ToHugrContext) -> ht.FunctionType: The resulting `FunctionType` can then be embedded into a Hugr `Type` or a Hugr `PolyFuncType`. """ + # At some point we may want to represent the effects as input and + # perhaps output "token" types in Hugr, but for now we will use Order edges. ins = [ inp.ty.to_hugr(ctx) for inp in self.inputs @@ -607,6 +623,7 @@ def transform(self, transformer: Transformer) -> "Type": self.params, comptime_args=self.comptime_args, unitary_flags=self.unitary_flags, + declared_effects=self.declared_effects, ) def instantiate_partial(self, args: "PartialInst") -> "FunctionType": @@ -636,6 +653,7 @@ def instantiate_partial(self, args: "PartialInst") -> "FunctionType": cast("ConstArg", arg.transform(inst)) for arg in self.comptime_args ], unitary_flags=self.unitary_flags, + declared_effects=self.declared_effects, ) def instantiate(self, args: "Inst") -> "FunctionType": @@ -666,6 +684,24 @@ def with_unitary_flags(self, flags: UnitaryFlags) -> "FunctionType": self.params, self.comptime_args, flags, + declared_effects=self.declared_effects, + ) + + def with_effects(self, declared_effects: list[Effect] | None) -> "FunctionType": + """Returns a copy of this function type with the specified effects.""" + # N.B. we can't use `dataclasses.replace` here since `FunctionType` has a custom + # constructor + if self.declared_effects is not None: + raise InternalGuppyError( + "Tried to set effects on a FunctionType that already has them" + ) + return FunctionType( + self.inputs, + self.output, + self.params, + self.comptime_args, + self.unitary_flags, + declared_effects=declared_effects, ) @@ -944,6 +980,12 @@ def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | N case FunctionType() as s, FunctionType() as t if s.params == t.params: if len(s.inputs) != len(t.inputs): return None + if set(s.effects) != set(t.effects): + # There are no "effect variables" yet, and we enforce exact matching + # (invariance) as covariance will become difficult when we replace Order + # edges with explicit tokens. (Requiring runtime closures or codegen for + # a statically-predictable function being assigned.) + return None for a, b in zip(s.inputs, t.inputs, strict=True): if a.ty.linear and b.ty.linear and a.flags != b.flags: return None diff --git a/guppylang/src/guppylang/__init__.py b/guppylang/src/guppylang/__init__.py index 104b844ee..d9e839926 100644 --- a/guppylang/src/guppylang/__init__.py +++ b/guppylang/src/guppylang/__init__.py @@ -1,12 +1,13 @@ from guppylang_internals.experimental import enable_experimental_features -from guppylang.decorator import guppy +from guppylang.decorator import Effect, guppy from guppylang.module import GuppyModule from guppylang.std import builtins, debug, quantum from guppylang.std.builtins import array, comptime, py from guppylang.std.quantum import qubit __all__ = ( + "Effect", "GuppyModule", "array", "builtins", diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 51b2bf3cb..ee49b2d5e 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -3,7 +3,15 @@ import inspect from collections.abc import Callable, Sequence from types import FrameType -from typing import Any, NamedTuple, ParamSpec, TypedDict, TypeVar, cast, overload +from typing import ( + Any, + NamedTuple, + ParamSpec, + TypedDict, + TypeVar, + cast, + overload, +) from guppylang_internals.ast_util import annotate_location from guppylang_internals.compiler.core import ( @@ -46,6 +54,7 @@ from guppylang_internals.metadata.common import FunctionMetadata from guppylang_internals.span import Loc, SourceMap, Span from guppylang_internals.tracing.util import hide_trace +from guppylang_internals.tys import Effect # Re-exported from guppylang_internals.tys.arg import Argument from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst @@ -85,7 +94,12 @@ OverloadedFunctionDef, ) -__all__ = ("GuppyKwargs", "custom_guppy_decorator", "guppy") +__all__ = ( + "Effect", # Re-export + "GuppyKwargs", + "custom_guppy_decorator", + "guppy", +) class GuppyKwargs(TypedDict, total=False): @@ -98,6 +112,8 @@ class GuppyKwargs(TypedDict, total=False): daggerable: bool max_qubits: int link_name: str + # effects=None means no effects, distinct from not specifying effects= at all + effects: list[Effect] | Effect | None class GuppyStructKwargs(TypedDict, total=False): @@ -146,6 +162,7 @@ def decorator( unitary_flags=parsed.flags, metadata=parsed.metadata, link_name=parsed.link_name, + effects=parsed.effects, ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) @@ -191,6 +208,7 @@ def decorator( f, unitary_flags=parsed.flags, metadata=parsed.metadata, + effects=parsed.effects, ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) @@ -456,6 +474,7 @@ def decorator( unitary_flags=parsed.flags, link_name=parsed.link_name, metadata=parsed.metadata, + effects=parsed.effects, ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) @@ -826,6 +845,9 @@ def _with_optional_kwargs( class ParsedGuppyKwargs(NamedTuple): flags: UnitaryFlags metadata: FunctionMetadata + # The empty list means no effects, whereas None means unspecified - i.e. assume all + # effects are possible until we can analyse the call-graph to calculate exactly. + effects: list[Effect] | None link_name: str | None @@ -848,6 +870,20 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs: link_name = kwargs.pop("link_name", None) + effects: list[Effect] | None + if "effects" in kwargs: + max_effects_input = kwargs.pop("effects") + effects = ( + [] + if max_effects_input is None + else [max_effects_input] + if isinstance(max_effects_input, Effect) + else max_effects_input + ) + else: + # Not specified + effects = None + if remaining := next(iter(kwargs), None): err = f"Unknown keyword argument: `{remaining}`" raise TypeError(err) @@ -856,6 +892,7 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs: flags=flags, metadata=metadata, link_name=link_name, + effects=effects, ) diff --git a/guppylang/src/guppylang/std/effects.py b/guppylang/src/guppylang/std/effects.py new file mode 100644 index 000000000..98808bc1b --- /dev/null +++ b/guppylang/src/guppylang/std/effects.py @@ -0,0 +1,21 @@ +from typing import Any + +from guppylang.decorator import Effect + + +class Effects: + """Dummy class to support `@effects` annotations.""" + + effects: list[Effect] + + def __init__(self, *effects: Effect) -> None: + self.effects = list(effects) + + def __rmatmul__(self, other: Any) -> Any: + # This method is to make the Python interpreter happy + return other + + +effects = Effects + +ANY = Effect.ANY diff --git a/guppylang/src/guppylang/std/num.py b/guppylang/src/guppylang/std/num.py index 631abd7c7..176f6b211 100644 --- a/guppylang/src/guppylang/std/num.py +++ b/guppylang/src/guppylang/std/num.py @@ -171,7 +171,7 @@ class int: @hugr_op(int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) def __abs__(self: int) -> int: ... - @hugr_op(int_op("iadd")) + @hugr_op(int_op("iadd"), effects=[]) # Annotation done early for use in tests def __add__(self: int, other: int) -> int: ... @hugr_op(int_op("iand")) diff --git a/tests/error/effects_errors/comptime_pure_calls_impure.err b/tests/error/effects_errors/comptime_pure_calls_impure.err new file mode 100644 index 000000000..32b6f4266 --- /dev/null +++ b/tests/error/effects_errors/comptime_pure_calls_impure.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 11, in + main.compile() + File "$FILE", line 9, in main + return impure_func(5) +guppylang_internals.error.GuppyComptimeError: Too many effects: Callee of type `int -> int` has effects `[ANY]` not allowed inside `main` diff --git a/tests/error/effects_errors/comptime_pure_calls_impure.py b/tests/error/effects_errors/comptime_pure_calls_impure.py new file mode 100644 index 000000000..f17ebd731 --- /dev/null +++ b/tests/error/effects_errors/comptime_pure_calls_impure.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy + +@guppy.comptime +def impure_func(x: int) -> int: + return x + 1 + +@guppy.comptime(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/effects_on_int.err b/tests/error/effects_errors/effects_on_int.err new file mode 100644 index 000000000..691f02342 --- /dev/null +++ b/tests/error/effects_errors/effects_on_int.err @@ -0,0 +1,8 @@ +Error: Invalid annotation (at $FILE:8:26) + | +6 | # This says the return type (not the function) has effects +7 | @guppy +8 | def main(x: int) -> int @ effects(): + | ^^^^^^^^^ Effects may be applied only to a `Callable` type + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/effects_on_int.py b/tests/error/effects_errors/effects_on_int.py new file mode 100644 index 000000000..1666c4131 --- /dev/null +++ b/tests/error/effects_errors/effects_on_int.py @@ -0,0 +1,11 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects + +# This says the return type (not the function) has effects +@guppy +def main(x: int) -> int @ effects(): + return x + 1 + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/misnamed_effects.err b/tests/error/effects_errors/misnamed_effects.err new file mode 100644 index 000000000..23ea0ea21 --- /dev/null +++ b/tests/error/effects_errors/misnamed_effects.err @@ -0,0 +1,8 @@ +Error: Invalid annotation (at $FILE:7:36) + | +5 | +6 | @guppy.declare +7 | def foo() -> Callable[[int], int] @ effects(ALL): + | ^^^^^^^^^^^^ Not a valid effect: `ALL` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/misnamed_effects.py b/tests/error/effects_errors/misnamed_effects.py new file mode 100644 index 000000000..5f276c33d --- /dev/null +++ b/tests/error/effects_errors/misnamed_effects.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects, ANY as ALL + +@guppy.declare +def foo() -> Callable[[int], int] @ effects(ALL): + ... + +@guppy +def main() -> int: + return foo()(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/overload.err b/tests/error/effects_errors/overload.err new file mode 100644 index 000000000..9bbd80b13 --- /dev/null +++ b/tests/error/effects_errors/overload.err @@ -0,0 +1,14 @@ +Error: Invalid call of overloaded function (at $FILE:24:11) + | +22 | @guppy(effects=[]) +23 | def bad_pure_func(x: float) -> float: +24 | return only_pure_for_int(x) + | ^^^^^^^^^^^^^^^^^^^^ No variant of overloaded function `only_pure_for_int` takes + | a `float` argument and returns `float` with effects no more + | than `[]` + +Note: Available overloads are: + def only_pure_for_int(x: T) -> T + def only_pure_for_int(x: int) -> int + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/overload.py b/tests/error/effects_errors/overload.py new file mode 100644 index 000000000..4cbb459ae --- /dev/null +++ b/tests/error/effects_errors/overload.py @@ -0,0 +1,32 @@ +from guppylang.decorator import guppy + +T = guppy.type_var("T") + +@guppy.declare +def variant1(x : T) -> T: ... + +@guppy.declare(effects=[]) +def variant2(x : int) -> int: ... + +@guppy.overload(variant1, variant2) +def only_pure_for_int(): ... + +@guppy(effects=[]) +def pure_func(x: int) -> int: + return only_pure_for_int(x + 1) + +@guppy +def impure_func(x: float) -> float: + return only_pure_for_int(x + 1.0) + +@guppy(effects=[]) +def bad_pure_func(x: float) -> float: + return only_pure_for_int(x) + +@guppy +def main() -> None: + pure_func(5) + impure_func(5.0) + bad_pure_func(3.14) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_calls_explicit_callable.err b/tests/error/effects_errors/pure_calls_explicit_callable.err new file mode 100644 index 000000000..9ec6e30e9 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_callable.err @@ -0,0 +1,14 @@ +Error: Too many effects (at $FILE:8:10) + | +6 | @guppy(effects=[]) +7 | def main(impure_f: Callable[[int], int] @ effects(ANY)) -> int: +8 | return impure_f(5) + | ^^^^^^^^ Callee of type `int -[ANY]-> int` not allowed inside `main` + +Note: + | +5 | +6 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_explicit_callable.py b/tests/error/effects_errors/pure_calls_explicit_callable.py new file mode 100644 index 000000000..d3a5a60d1 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_callable.py @@ -0,0 +1,10 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects, ANY + +@guppy(effects=[]) +def main(impure_f: Callable[[int], int] @ effects(ANY)) -> int: + return impure_f(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_calls_explicit_callable_comptime.err b/tests/error/effects_errors/pure_calls_explicit_callable_comptime.err new file mode 100644 index 000000000..9ec6e30e9 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_callable_comptime.err @@ -0,0 +1,14 @@ +Error: Too many effects (at $FILE:8:10) + | +6 | @guppy(effects=[]) +7 | def main(impure_f: Callable[[int], int] @ effects(ANY)) -> int: +8 | return impure_f(5) + | ^^^^^^^^ Callee of type `int -[ANY]-> int` not allowed inside `main` + +Note: + | +5 | +6 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_explicit_callable_comptime.py b/tests/error/effects_errors/pure_calls_explicit_callable_comptime.py new file mode 100644 index 000000000..d3a5a60d1 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_callable_comptime.py @@ -0,0 +1,10 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects, ANY + +@guppy(effects=[]) +def main(impure_f: Callable[[int], int] @ effects(ANY)) -> int: + return impure_f(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_calls_explicit_decl.err b/tests/error/effects_errors/pure_calls_explicit_decl.err new file mode 100644 index 000000000..ae32326b9 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_decl.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:8:10) + | +6 | @guppy(effects=[]) +7 | def main() -> int: +8 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | +5 | +6 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_explicit_decl.py b/tests/error/effects_errors/pure_calls_explicit_decl.py new file mode 100644 index 000000000..29facc00f --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_decl.py @@ -0,0 +1,10 @@ +from guppylang.decorator import guppy, Effect + +@guppy.declare(effects=[Effect.ANY]) +def impure_func(x: int) -> int: ... + +@guppy(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_calls_explicit_def.err b/tests/error/effects_errors/pure_calls_explicit_def.err new file mode 100644 index 000000000..b24b4cced --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_def.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:9:10) + | +7 | @guppy(effects=[]) +8 | def main() -> int: +9 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | +6 | +7 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_explicit_def.py b/tests/error/effects_errors/pure_calls_explicit_def.py new file mode 100644 index 000000000..97478da44 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_explicit_def.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy, Effect + +@guppy(effects=[Effect.ANY]) +def impure_func(x: int) -> int: + return x + 1 + +@guppy(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_calls_impure_callable.err b/tests/error/effects_errors/pure_calls_impure_callable.err new file mode 100644 index 000000000..e1376c5d1 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_callable.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:7:10) + | +5 | @guppy(effects=[]) +6 | def main(impure_f: Callable[[int], int]) -> int: +7 | return impure_f(5) + | ^^^^^^^^ Callee of type `int -> int` has effects `[ANY]` not allowed + | inside `main` + +Note: + | +4 | +5 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_callable.py b/tests/error/effects_errors/pure_calls_impure_callable.py new file mode 100644 index 000000000..eb8e4cb97 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_callable.py @@ -0,0 +1,9 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy + +@guppy(effects=[]) +def main(impure_f: Callable[[int], int]) -> int: + return impure_f(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_calls_impure_callable_comptime.err b/tests/error/effects_errors/pure_calls_impure_callable_comptime.err new file mode 100644 index 000000000..e1376c5d1 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_callable_comptime.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:7:10) + | +5 | @guppy(effects=[]) +6 | def main(impure_f: Callable[[int], int]) -> int: +7 | return impure_f(5) + | ^^^^^^^^ Callee of type `int -> int` has effects `[ANY]` not allowed + | inside `main` + +Note: + | +4 | +5 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_callable_comptime.py b/tests/error/effects_errors/pure_calls_impure_callable_comptime.py new file mode 100644 index 000000000..eb8e4cb97 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_callable_comptime.py @@ -0,0 +1,9 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy + +@guppy(effects=[]) +def main(impure_f: Callable[[int], int]) -> int: + return impure_f(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_calls_impure_comptime.err b/tests/error/effects_errors/pure_calls_impure_comptime.err new file mode 100644 index 000000000..b24b4cced --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_comptime.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:9:10) + | +7 | @guppy(effects=[]) +8 | def main() -> int: +9 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | +6 | +7 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_comptime.py b/tests/error/effects_errors/pure_calls_impure_comptime.py new file mode 100644 index 000000000..46f7b402f --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_comptime.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy + +@guppy.comptime +def impure_func(x: int) -> int: + return x + 1 + +@guppy(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_calls_impure_custom_def.err b/tests/error/effects_errors/pure_calls_impure_custom_def.err new file mode 100644 index 000000000..0771e85a4 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_custom_def.err @@ -0,0 +1,17 @@ +Error: Too many effects (at $FILE:12:10) + | +10 | @custom_pure +11 | def main() -> int: +12 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | + 9 | +10 | @custom_pure + | ----------- +11 | def main() -> int: + | ----------------- Allowed effects `[]` declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_custom_def.py b/tests/error/effects_errors/pure_calls_impure_custom_def.py new file mode 100644 index 000000000..c0dc31092 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_custom_def.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy + +def custom_pure(func): + return guppy(effects=[])(func) + +@guppy +def impure_func(x: int) -> int: + return x + 1 + +@custom_pure +def main() -> int: + return impure_func(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_calls_impure_decl.err b/tests/error/effects_errors/pure_calls_impure_decl.err new file mode 100644 index 000000000..ae32326b9 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_decl.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:8:10) + | +6 | @guppy(effects=[]) +7 | def main() -> int: +8 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | +5 | +6 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_decl.py b/tests/error/effects_errors/pure_calls_impure_decl.py new file mode 100644 index 000000000..2bda26cc7 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_decl.py @@ -0,0 +1,10 @@ +from guppylang.decorator import guppy + +@guppy.declare +def impure_func(x: int) -> int: ... + +@guppy(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_calls_impure_def.err b/tests/error/effects_errors/pure_calls_impure_def.err new file mode 100644 index 000000000..b24b4cced --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_def.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:9:10) + | +7 | @guppy(effects=[]) +8 | def main() -> int: +9 | return impure_func(5) + | ^^^^^^^^^^^ Call to `impure_func` has effects `[ANY]` not allowed inside + | `main` + +Note: + | +6 | +7 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_calls_impure_def.py b/tests/error/effects_errors/pure_calls_impure_def.py new file mode 100644 index 000000000..3bbef1121 --- /dev/null +++ b/tests/error/effects_errors/pure_calls_impure_def.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy + +@guppy +def impure_func(x: int) -> int: + return x + 1 + +@guppy(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/pure_comptime_calls_impure_decl.err b/tests/error/effects_errors/pure_comptime_calls_impure_decl.err new file mode 100644 index 000000000..32b6f4266 --- /dev/null +++ b/tests/error/effects_errors/pure_comptime_calls_impure_decl.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 11, in + main.compile() + File "$FILE", line 9, in main + return impure_func(5) +guppylang_internals.error.GuppyComptimeError: Too many effects: Callee of type `int -> int` has effects `[ANY]` not allowed inside `main` diff --git a/tests/error/effects_errors/pure_comptime_calls_impure_decl.py b/tests/error/effects_errors/pure_comptime_calls_impure_decl.py new file mode 100644 index 000000000..6c2341ae2 --- /dev/null +++ b/tests/error/effects_errors/pure_comptime_calls_impure_decl.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy + +@guppy.declare +def impure_func(x: int) -> int: + ... + +@guppy.comptime(effects=[]) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_comptime_calls_impure_def.err b/tests/error/effects_errors/pure_comptime_calls_impure_def.err new file mode 100644 index 000000000..32b6f4266 --- /dev/null +++ b/tests/error/effects_errors/pure_comptime_calls_impure_def.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 11, in + main.compile() + File "$FILE", line 9, in main + return impure_func(5) +guppylang_internals.error.GuppyComptimeError: Too many effects: Callee of type `int -> int` has effects `[ANY]` not allowed inside `main` diff --git a/tests/error/effects_errors/pure_comptime_calls_impure_def.py b/tests/error/effects_errors/pure_comptime_calls_impure_def.py new file mode 100644 index 000000000..87fa05951 --- /dev/null +++ b/tests/error/effects_errors/pure_comptime_calls_impure_def.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy + +@guppy +def impure_func(x: int) -> int: + return x + 1 + +@guppy.comptime(effects=None) +def main() -> int: + return impure_func(5) + +main.compile() diff --git a/tests/error/effects_errors/pure_result.err b/tests/error/effects_errors/pure_result.err new file mode 100644 index 000000000..4cc73f9ed --- /dev/null +++ b/tests/error/effects_errors/pure_result.err @@ -0,0 +1,19 @@ +Error: Invalid call of overloaded function (at $FILE:6:10) + | +4 | @guppy(effects=[]) +5 | def main() -> int: +6 | result("foo", True) + | ^^^^^^^^^^^ No variant of overloaded function `result` takes arguments + | `str`, `bool` with effects no more than `[]` + +Note: Available overloads are: + def result(tag: str @comptime, value: int) -> None + def result(tag: str @comptime, value: nat) -> None + def result(tag: str @comptime, value: bool) -> None + def result(tag: str @comptime, value: float) -> None + def result(tag: str @comptime, value: array[int, n]) -> None + def result(tag: str @comptime, value: array[nat, n]) -> None + def result(tag: str @comptime, value: array[bool, n]) -> None + def result(tag: str @comptime, value: array[float, n]) -> None + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/pure_result.py b/tests/error/effects_errors/pure_result.py new file mode 100644 index 000000000..3d535bafd --- /dev/null +++ b/tests/error/effects_errors/pure_result.py @@ -0,0 +1,9 @@ +from guppylang.decorator import guppy +from guppylang.std.builtins import result + +@guppy(effects=[]) +def main() -> int: + result("foo", True) + return 3 + +main.compile_function() \ No newline at end of file diff --git a/tests/error/effects_errors/repeated_effects.err b/tests/error/effects_errors/repeated_effects.err new file mode 100644 index 000000000..7f402150b --- /dev/null +++ b/tests/error/effects_errors/repeated_effects.err @@ -0,0 +1,8 @@ +Error: Invalid annotation (at $FILE:7:48) + | +5 | +6 | @guppy.declare +7 | def foo() -> Callable[[int], int] @ effects() @ effects(ANY): + | ^^^^^^^^^^^^ Effects have already been applied to this `Callable` type + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/repeated_effects.py b/tests/error/effects_errors/repeated_effects.py new file mode 100644 index 000000000..8ef0708b6 --- /dev/null +++ b/tests/error/effects_errors/repeated_effects.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects, ANY + +@guppy.declare +def foo() -> Callable[[int], int] @ effects() @ effects(ANY): + ... + +@guppy +def main() -> int: + return foo()(5) + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/return_explicit_callable.err b/tests/error/effects_errors/return_explicit_callable.err new file mode 100644 index 000000000..4d5d950d1 --- /dev/null +++ b/tests/error/effects_errors/return_explicit_callable.err @@ -0,0 +1,9 @@ +Error: Type mismatch (at $FILE:12:10) + | +10 | @guppy +11 | def main() -> Callable[[int], int] @effects(): +12 | return impure_func + | ^^^^^^^^^^^ Expected return value of type `int -[]-> int`, got `int + | -[ANY]-> int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_explicit_callable.py b/tests/error/effects_errors/return_explicit_callable.py new file mode 100644 index 000000000..f3a215aa5 --- /dev/null +++ b/tests/error/effects_errors/return_explicit_callable.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy, Effect +from guppylang.std.effects import effects + +@guppy(effects=[Effect.ANY]) +def impure_func(x: int) -> int: + return x + 1 + +@guppy +def main() -> Callable[[int], int] @effects(): + return impure_func + +main.compile() diff --git a/tests/error/effects_errors/return_explicit_callable_comptime.err b/tests/error/effects_errors/return_explicit_callable_comptime.err new file mode 100644 index 000000000..f3373a7c4 --- /dev/null +++ b/tests/error/effects_errors/return_explicit_callable_comptime.err @@ -0,0 +1,11 @@ +Error: Type mismatch (at $FILE:11:0) + | + 9 | +10 | @guppy.comptime +11 | def main() -> Callable[[int], int] @effects(): + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +12 | return impure_func + | ^^^^^^^^^^^^^^^^^^^^^ Expected return value of type `int -[]-> int`, got `int + | -[ANY]-> int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_explicit_callable_comptime.py b/tests/error/effects_errors/return_explicit_callable_comptime.py new file mode 100644 index 000000000..f050f09d3 --- /dev/null +++ b/tests/error/effects_errors/return_explicit_callable_comptime.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy, Effect +from guppylang.std.effects import effects + +@guppy(effects=[Effect.ANY]) +def impure_func(x: int) -> int: + return x + 1 + +@guppy.comptime +def main() -> Callable[[int], int] @effects(): + return impure_func + +main.compile() diff --git a/tests/error/effects_errors/return_impure_callable.err b/tests/error/effects_errors/return_impure_callable.err new file mode 100644 index 000000000..05d5b87ad --- /dev/null +++ b/tests/error/effects_errors/return_impure_callable.err @@ -0,0 +1,9 @@ +Error: Type mismatch (at $FILE:12:10) + | +10 | @guppy +11 | def main() -> Callable[[int], int] @ effects(): +12 | return impure_func + | ^^^^^^^^^^^ Expected return value of type `int -[]-> int`, got `int -> + | int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_impure_callable.py b/tests/error/effects_errors/return_impure_callable.py new file mode 100644 index 000000000..928e45061 --- /dev/null +++ b/tests/error/effects_errors/return_impure_callable.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects + +@guppy +def impure_func(x: int) -> int: + return x + 1 + +@guppy +def main() -> Callable[[int], int] @ effects(): + return impure_func + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/return_impure_callable_comptime.err b/tests/error/effects_errors/return_impure_callable_comptime.err new file mode 100644 index 000000000..cf9240fe0 --- /dev/null +++ b/tests/error/effects_errors/return_impure_callable_comptime.err @@ -0,0 +1,11 @@ +Error: Type mismatch (at $FILE:11:0) + | + 9 | +10 | @guppy.comptime +11 | def main() -> Callable[[int], int] @ effects(): + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +12 | return impure_func + | ^^^^^^^^^^^^^^^^^^^^^ Expected return value of type `int -[]-> int`, got `int -> + | int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_impure_callable_comptime.py b/tests/error/effects_errors/return_impure_callable_comptime.py new file mode 100644 index 000000000..25570a8c7 --- /dev/null +++ b/tests/error/effects_errors/return_impure_callable_comptime.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.std.effects import effects + +@guppy +def impure_func(x: int) -> int: + return x + 1 + +@guppy.comptime +def main() -> Callable[[int], int] @ effects(): + return impure_func + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/return_pure_callable.err b/tests/error/effects_errors/return_pure_callable.err new file mode 100644 index 000000000..e81b94dbe --- /dev/null +++ b/tests/error/effects_errors/return_pure_callable.err @@ -0,0 +1,9 @@ +Error: Type mismatch (at $FILE:12:10) + | +10 | @guppy +11 | def main() -> Callable[[int], int]: +12 | return pure_func + | ^^^^^^^^^ Expected return value of type `int -> int`, got `int -[]-> + | int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_pure_callable.py b/tests/error/effects_errors/return_pure_callable.py new file mode 100644 index 000000000..4f2552ad3 --- /dev/null +++ b/tests/error/effects_errors/return_pure_callable.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy + +@guppy(effects=[]) +def pure_func(x: int) -> int: + return x + 1 + +# This is an error because we enforce invariance of Callable types. +@guppy +def main() -> Callable[[int], int]: + return pure_func + +main.compile() \ No newline at end of file diff --git a/tests/error/effects_errors/return_pure_callable_comptime.err b/tests/error/effects_errors/return_pure_callable_comptime.err new file mode 100644 index 000000000..3aabc3944 --- /dev/null +++ b/tests/error/effects_errors/return_pure_callable_comptime.err @@ -0,0 +1,11 @@ +Error: Type mismatch (at $FILE:11:0) + | + 9 | # This is an error because we enforce invariance of Callable types. +10 | @guppy.comptime +11 | def main() -> Callable[[int], int]: + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +12 | return pure_func + | ^^^^^^^^^^^^^^^^^^^ Expected return value of type `int -> int`, got `int -[]-> + | int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/return_pure_callable_comptime.py b/tests/error/effects_errors/return_pure_callable_comptime.py new file mode 100644 index 000000000..525a652aa --- /dev/null +++ b/tests/error/effects_errors/return_pure_callable_comptime.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy + +@guppy(effects=[]) +def pure_func(x: int) -> int: + return x + 1 + +# This is an error because we enforce invariance of Callable types. +@guppy.comptime +def main() -> Callable[[int], int]: + return pure_func + +main.compile() \ No newline at end of file diff --git a/tests/error/modifier_errors/higher_order.err b/tests/error/modifier_errors/higher_order.err new file mode 100644 index 000000000..c472124ba --- /dev/null +++ b/tests/error/modifier_errors/higher_order.err @@ -0,0 +1,8 @@ +Error: Dagger constraint violation (at $FILE:10:4) + | + 8 | @guppy(dagger=True) + 9 | def test_ho(f: Callable[[qubit], None], q: qubit) -> None: +10 | f(q) + | ^^^^ This function cannot be called in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/poly_errors/non_linear2.err b/tests/error/poly_errors/non_linear2.err index 0d9219d2e..9b3a8bc58 100644 --- a/tests/error/poly_errors/non_linear2.err +++ b/tests/error/poly_errors/non_linear2.err @@ -1,8 +1,8 @@ -Error: Expected a copyable type (at $FILE:15:8) +Error: Expected a copyable type (at $FILE:18:8) | -13 | @guppy -14 | def main() -> None: -15 | foo(h) +16 | @guppy +17 | def main() -> None: +18 | foo(h) | ^ Expected a copyable type, got type `qubit` which is not | implicitly copyable diff --git a/tests/error/poly_errors/non_linear2.py b/tests/error/poly_errors/non_linear2.py index 7fb3d6e63..9efcf0f5b 100644 --- a/tests/error/poly_errors/non_linear2.py +++ b/tests/error/poly_errors/non_linear2.py @@ -2,10 +2,13 @@ from guppylang.decorator import guppy from guppylang.std.quantum.functional import h +from guppylang.std.effects import effects T = guppy.type_var("T") +# Pending https://github.com/Quantinuum/guppylang/issues/1760 we need to explicitly +# declare the effects of `x` @guppy.declare def foo(x: Callable[[T], T]) -> None: ... diff --git a/tests/error/test_effects_errors.py b/tests/error/test_effects_errors.py new file mode 100644 index 000000000..f97f2ca19 --- /dev/null +++ b/tests/error/test_effects_errors.py @@ -0,0 +1,19 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "effects_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() and x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_effects_errors(file, capsys, snapshot): + run_error_test(file, capsys, snapshot) diff --git a/tests/error/type_errors/fun_ty_mismatch_4.err b/tests/error/type_errors/fun_ty_mismatch_4.err new file mode 100644 index 000000000..eff122445 --- /dev/null +++ b/tests/error/type_errors/fun_ty_mismatch_4.err @@ -0,0 +1,8 @@ +Error: Type mismatch (at $FILE:14:11) + | +12 | return x +13 | +14 | return bar + | ^^^ Expected return value of type `nat -> int`, got `nat -> nat` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/fun_ty_mismatch_4.py b/tests/error/type_errors/fun_ty_mismatch_4.py new file mode 100644 index 000000000..fbd36be49 --- /dev/null +++ b/tests/error/type_errors/fun_ty_mismatch_4.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.std.builtins import nat + +from tests.util import compile_guppy + + +@compile_guppy +def foo() -> Callable[[nat], int]: + # This has a narrower return type, but we enforce invariance of Callable types, so this is still an error. + def bar(x: nat) -> nat: + return x + + return bar diff --git a/tests/integration/test_comptime_effects.py b/tests/integration/test_comptime_effects.py new file mode 100644 index 000000000..0ac1feb93 --- /dev/null +++ b/tests/integration/test_comptime_effects.py @@ -0,0 +1,167 @@ +"""Tests of effects annotation for comptime callers/callees.""" + +import pytest + +from guppylang.decorator import Effect, guppy +from guppylang.std.builtins import result +from guppylang.std.effects import ANY + + +def test_pure_from_impure_comptime(validate): + @guppy(effects=None) + def pure_func(x: int) -> int: + return x + 1 + + @guppy.comptime + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_pure_comptime_from_impure(validate): + @guppy.comptime(effects=None) + def pure_func(x: int) -> int: + return x + 1 + + @guppy + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_comptime_pure_from_impure(validate): + @guppy.comptime(effects=[]) + def pure_func(x: int) -> int: + return x + 1 + + @guppy.comptime + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_pure_from_explicit_comptime_impure(validate): + @guppy(effects=[]) + def pure_func(x: int) -> int: + return x + 1 + + @guppy.comptime(effects=ANY) + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_pure_comptime_from_explicit_impure(validate): + @guppy.comptime(effects=None) + def pure_func(x: int) -> int: + return x + 1 + + @guppy(effects=[Effect.ANY]) + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_comptime_pure_from_explicit_impure(validate): + @guppy.comptime(effects=[]) + def pure_func(x: int) -> int: + return x + 1 + + @guppy.comptime + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +@pytest.mark.parametrize( + ("caller_flags", "callee"), + [ + ({"effects": [Effect.ANY]}, {}), + ({}, {"effects": [Effect.ANY]}), + ({"effects": [Effect.ANY]}, {"effects": [Effect.ANY]}), + ], +) +def test_impure_explicit_from_comptime(caller_flags, callee, validate): + @guppy(**callee) + def impure_func(x: int) -> int: + result("tag", x) + return x + 3 + + @guppy.comptime(**caller_flags) + def caller(x: int) -> int: + return impure_func(x) + 1 + + validate(caller.compile_function()) + + +@pytest.mark.parametrize( + ("caller", "callee"), + [ + ({"effects": [Effect.ANY]}, {}), + ({}, {"effects": [Effect.ANY]}), + ({"effects": [Effect.ANY]}, {"effects": [Effect.ANY]}), + ], +) +@pytest.mark.parametrize( + ("caller_deco", "callee_deco"), + [ + (guppy.comptime, guppy), + (guppy, guppy.comptime), + (guppy.comptime, guppy.comptime), + ], +) +def test_impure_explicit_comptime_callee( + caller, callee, caller_deco, callee_deco, validate +): + @callee_deco(**callee) + def impure_func(x: int) -> int: + result("tag", x) + return x + 3 + + @caller_deco(**caller) + def impure_func2(x: int) -> int: + return impure_func(x) + 1 + + validate(impure_func2.compile_function()) + + +def test_pure_from_pure_comptime(validate): + @guppy(effects=[]) + def pure_func1(x: int) -> int: + return x + 1 + + @guppy.comptime(effects=None) + def pure_func2(x: int) -> int: + return pure_func1(pure_func1(x)) + 1 + + validate(pure_func2.compile_function()) + + +def test_pure_comptime_from_pure(validate): + @guppy.comptime(effects=None) + def pure_func1(x: int) -> int: + return x + 1 + + @guppy(effects=[]) + def pure_func2(x: int) -> int: + return pure_func1(pure_func1(x)) + 1 + + validate(pure_func2.compile_function()) + + +def test_comptime_pure_from_pure(validate): + @guppy.comptime(effects=[]) + def pure_func1(x: int) -> int: + return x + 1 + + @guppy.comptime(effects=[]) + def pure_func2(x: int) -> int: + return pure_func1(pure_func1(x)) + 1 + + validate(pure_func2.compile_function()) diff --git a/tests/integration/test_effects.py b/tests/integration/test_effects.py new file mode 100644 index 000000000..6e5962840 --- /dev/null +++ b/tests/integration/test_effects.py @@ -0,0 +1,165 @@ +"""Tests of effects annotation.""" + +import pytest +from collections.abc import Callable + +from guppylang.decorator import guppy, Effect +from guppylang.std.builtins import result +from guppylang.std.effects import effects, ANY + + +def test_pure_decl_from_impure(validate): + @guppy.declare(effects=[]) + def pure_func(x: int) -> int: ... + + @guppy + def impure_func(x: int) -> int: + return pure_func(x) + 1 + + validate(impure_func.compile_function()) + + +def test_pure_decl_from_explicit_impure(validate): + @guppy.declare(effects=[]) + def pure_func(x: int) -> int: ... + + @guppy(effects=Effect.ANY) + def impure_func(x: int) -> int: + return pure_func(x) + 1 + + validate(impure_func.compile_function()) + + +def test_pure_decl_from_pure(validate): + @guppy.declare(effects=[]) + def pure_func1(x: int) -> int: ... + + @guppy(effects=[]) + def pure_func2(x: int) -> int: + return pure_func1(x) + 2 + + validate(pure_func2.compile_function()) + + +@pytest.mark.parametrize( + ("caller", "callee"), + [ + ({"effects": [Effect.ANY]}, {}), + ({}, {"effects": [Effect.ANY]}), + ({"effects": [Effect.ANY]}, {"effects": [Effect.ANY]}), + ], +) +def test_impure_decl_explicit(caller, callee, validate): + @guppy.declare(**callee) + def impure_func(x: int) -> int: ... + + @guppy(**caller) + def impure_func2(x: int) -> int: + return impure_func(x) + 1 + + validate(impure_func2.compile_function()) + + +def test_pure_from_impure(validate): + @guppy(effects=None) + def pure_func(x: int) -> int: + return x + 1 + + @guppy + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +def test_pure_from_explicit_impure(validate): + @guppy(effects=[]) + def pure_func(x: int) -> int: + return x + 1 + + @guppy(effects=[Effect.ANY]) + def normal_func(x: int) -> int: + return pure_func(x) + 2 + + validate(normal_func.compile_function()) + + +@pytest.mark.parametrize( + ("caller", "callee"), + [ + ({"effects": [Effect.ANY]}, {}), + ({}, {"effects": [Effect.ANY]}), + ({"effects": [Effect.ANY]}, {"effects": [Effect.ANY]}), + ], +) +def test_impure_explicit(caller, callee, validate): + @guppy(**callee) + def impure_func(x: int) -> int: + result("tag", x) + return x + 3 + + @guppy(**caller) + def impure_func2(x: int) -> int: + return impure_func(x) + 1 + + validate(impure_func2.compile_function()) + + +def test_pure_from_pure(validate): + @guppy(effects=[]) + def pure_func1(x: int) -> int: + return x + 1 + + @guppy(effects=[]) + def pure_func2(x: int) -> int: + return pure_func1(pure_func1(x)) + 1 + + validate(pure_func2.compile_function()) + + +def test_pure_callable_from_impure(validate): + @guppy + def impure_func(pure_f: Callable[[int], int] @ effects()) -> int: + return pure_f(5) + 1 + + validate(impure_func.compile_function()) + + +def test_pure_callable_from_pure(validate): + @guppy(effects=[]) + def pure_func(pure_f: Callable[[int], int] @ effects(None)) -> int: + return pure_f(5) + 1 + + validate(pure_func.compile_function()) + + +def test_pure_callable_from_impure_explicit(validate): + @guppy(effects=[Effect.ANY]) + def impure_func(pure_f: Callable[[int], int] @ effects()) -> int: + return pure_f(5) + 1 + + validate(impure_func.compile_function()) + + +def test_return_callable1(validate): + @guppy + def impure_func(x: int) -> int: + return x + 1 + + @guppy(effects=[]) + def higher_order() -> Callable[[int], int] @ effects(ANY): + return impure_func + + validate(higher_order.compile_function()) + + +def test_return_callable2(validate): + @guppy(effects=[Effect.ANY]) + def explicit_impure_func(x: int) -> int: + return x + 1 + + @guppy(effects=[]) + def higher_order() -> Callable[[int], int]: + return explicit_impure_func + + validate(higher_order.compile_function()) diff --git a/tests/integration/test_enum.py b/tests/integration/test_enum.py index 92fedd5b9..b302ba621 100644 --- a/tests/integration/test_enum.py +++ b/tests/integration/test_enum.py @@ -23,11 +23,8 @@ from guppylang import guppy from tests.util import compile_guppy -from typing import Generic, TYPE_CHECKING - - -if TYPE_CHECKING: - from collections.abc import Callable +from typing import Generic +from collections.abc import Callable def test_basic_enum(validate): @@ -250,7 +247,7 @@ class Enum(Generic[T]): # pyright: ignore[reportInvalidTypeForm] VariantA = {"x": T} @guppy - def factory(mk_enum: "Callable[[int], Enum[int]]", x: int) -> Enum[int]: + def factory(mk_enum: Callable[[int], Enum[int]], x: int) -> Enum[int]: return mk_enum(x) @guppy diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index 5c38a2654..10e124f8e 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -1,6 +1,8 @@ +import pytest + from collections.abc import Callable -from guppylang.decorator import guppy +from guppylang.decorator import guppy, Effect from tests.util import compile_guppy @@ -174,3 +176,42 @@ def fac(x: int) -> int: return Y(fac_)(x) validate(fac.compile_function()) + + +# This should be combined with `test_higher_order_effects2` once we have solved +# https://github.com/Quantinuum/guppylang/issues/1760 +# but presently exists to show the part of the test that *does* work +def test_higher_order_effects1(validate): + @guppy(effects=[Effect.ANY]) + def impure_func(x: int) -> int: + return x + 1 + + # Same def as `test_higher_order_effects2` + @guppy + def higher_order(f: Callable[[int], int], x: int) -> int: + return f(x) + + @guppy + def main() -> int: + return higher_order(impure_func, 5) + + validate(main.compile_function()) + + +@pytest.mark.xfail(reason="Pending https://github.com/Quantinuum/guppylang/issues/1760") +def test_higher_order_effects2(validate): + @guppy(effects=[]) + def pure_func(x: int) -> int: + return x + 1 + + @guppy # we'd love this to be "as pure as f is", but no way to do that yet. + # (Alternatively https://github.com/Quantinuum/guppylang/issues/1752 will allow + # explicitly declaring such effect-polymorphism, but that won't parse yet) + def higher_order(f: Callable[[int], int], x: int) -> int: + return f(x) + + @guppy(effects=[]) + def main() -> int: + return higher_order(pure_func, 5) + + validate(main.compile_function()) diff --git a/tests/integration/test_struct.py b/tests/integration/test_struct.py index 56f25b37d..93e880255 100644 --- a/tests/integration/test_struct.py +++ b/tests/integration/test_struct.py @@ -6,7 +6,6 @@ from tests.integration.modules import struct_scope_defs - if TYPE_CHECKING: from collections.abc import Callable