diff --git a/guppylang-internals/src/guppylang_internals/compiler/builder.py b/guppylang-internals/src/guppylang_internals/compiler/builder/__init__.py similarity index 73% rename from guppylang-internals/src/guppylang_internals/compiler/builder.py rename to guppylang-internals/src/guppylang_internals/compiler/builder/__init__.py index 78fa1b137..5bd8122c0 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/builder.py +++ b/guppylang-internals/src/guppylang_internals/compiler/builder/__init__.py @@ -1,24 +1,33 @@ from abc import ABC, abstractmethod, abstractproperty -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass, field from types import TracebackType -from typing import Generic, TypeVar +from typing import Generic, NamedTuple, TypeAlias, TypeVar -from hugr import Node, Wire, ops, val +from hugr import Node, Wire, val from hugr import tys as ht from hugr.build import Block, Case, Cfg, Conditional, TailLoop from hugr.build import function as hf from hugr.hugr.node_port import ToNode from hugr.metadata import HugrDebugInfo +from hugr.ops import DataflowOp, Output from typing_extensions import Self, override from guppylang_internals.ast_util import AstNode -from guppylang_internals.compiler.core import may_have_side_effect from guppylang_internals.metadata.debug_info_util import ( debug_conditions_fulfilled, make_location_record, ) +from guppylang_internals.tys import Effect + + +@dataclass(frozen=True) +class Pure: + op: DataflowOp + + +OpWithEffects: TypeAlias = Pure | tuple[DataflowOp, Sequence[Effect]] @dataclass @@ -33,7 +42,7 @@ class DFBuilder(ABC, ToNode): """ current_ast_node: AstNode | None = field(default=None, kw_only=True) - _last_side_effect: Node | None = field(default=None, init=False) + _last_side_effect: dict[Effect, Node] = field(default_factory=dict, init=False) @abstractproperty def _raw(self) -> hf.Function | Case | TailLoop | Block: @@ -76,13 +85,14 @@ def inputs(self) -> Sequence[Wire]: def set_outputs(self, *outputs: Wire) -> hf.Function | Case | TailLoop | Block: self._raw.set_outputs(*outputs) - if self._last_side_effect is not None: - self._handle_side_effects(self._raw.output_node) + self._handle_side_effects( + self._raw.output_node, list(self._last_side_effect.keys()) + ) return self._raw def add_op( self, - op: ops.DataflowOp, + op: OpWithEffects, /, *args: Wire, set_debug_info: bool = True, @@ -90,9 +100,9 @@ def add_op( """Adds an op to the dataflow graph builder. Set `set_debug_info=False` to avoid automatic debug information attachment. """ + op, effects = (op.op, []) if isinstance(op, Pure) else op op_node = self._raw.add_op(op, *args) - if may_have_side_effect(op): - self._handle_side_effects(op_node) + self._handle_side_effects(op_node, effects) if set_debug_info and debug_conditions_fulfilled(self.current_ast_node): assert self.current_ast_node is not None # for type-checker @@ -101,19 +111,36 @@ def add_op( ) return op_node - def _handle_side_effects(self, op_node: ToNode) -> None: - if self._last_side_effect is None: - self._propagate_side_effects() - self._last_side_effect = self.input_node - else: - assert not isinstance(self._raw.hugr[self._last_side_effect].op, ops.Output) + def _handle_side_effects(self, op_node: ToNode, effects: Iterable[Effect]) -> None: + """Updates Hugr to reflect `op_node` having effects `effects`. + Does nothing if effects is empty (or the node already has those effects).""" node = op_node.to_node() - if self._last_side_effect != node: # avoid self-loops when propagating - self._raw.add_state_order(self._last_side_effect, node) - self._last_side_effect = node + to_propagate = set() # Effects newly added to our container + + def get_last_node(e: Effect) -> Node: + last = self._last_side_effect.get(e) + if last is None: + to_propagate.add(e) + last = self.input_node + else: + assert not isinstance(self._raw.hugr[last].op, Output) + self._last_side_effect[e] = node + return last + + prev_nodes = {get_last_node(e) for e in effects} + # Avoid cycles and duplicate edges: + prev_nodes.discard(node) + for prev in self._raw.hugr.incoming_order_links(node): + prev_nodes.discard(prev) + + for prev in prev_nodes: + self._raw.add_state_order(prev, node) + + if to_propagate: + self._propagate_side_effects(to_propagate) @abstractmethod - def _propagate_side_effects(self) -> None: + def _propagate_side_effects(self, effects: Iterable[Effect]) -> None: """Subclasses must implement to mark the container node as side-effecting within any parent/ancestor builder""" @@ -121,6 +148,7 @@ def call( self, func: ToNode, *args: Wire, + effects: Sequence[Effect], instantiation: ht.FunctionType | None = None, type_args: Sequence[ht.TypeArg] | None = None, set_debug_info: bool = True, @@ -131,7 +159,7 @@ def call( call = self._raw.call( func, *args, instantiation=instantiation, type_args=type_args ) - self._handle_side_effects(call) + self._handle_side_effects(call, effects) if set_debug_info and debug_conditions_fulfilled(self.current_ast_node): assert self.current_ast_node is not None # for type-checker call.metadata[HugrDebugInfo] = make_location_record(self.current_ast_node) @@ -190,16 +218,17 @@ class TailLoopBuilder(_DFBuilderRaw[TailLoop]): def set_loop_outputs(self, predicate: Wire, *outputs: Wire) -> None: self._raw.set_loop_outputs(predicate, *outputs) - if self._last_side_effect is not None: - self._handle_side_effects(self._raw.output_node) + self._handle_side_effects( + self._raw.output_node, list(self._last_side_effect.keys()) + ) - def _propagate_side_effects(self) -> None: - self.parent._handle_side_effects(self._raw) + def _propagate_side_effects(self, effects: Iterable[Effect]) -> None: + self.parent._handle_side_effects(self._raw, effects) @dataclass class FunctionBuilder(_DFBuilderRaw[hf.Function]): - def _propagate_side_effects(self) -> None: + def _propagate_side_effects(self, effects: Iterable[Effect]) -> None: pass # No parent @override @@ -213,10 +242,10 @@ class CaseBuilder(_DFBuilderRaw[Case]): parent: Conditional grandparent: DFBuilder - def _propagate_side_effects(self) -> None: + def _propagate_side_effects(self, effects: Iterable[Effect]) -> None: # No need to do anything in the Conditional, # but the Conditional itself needs to be ordered inside its parent - self.grandparent._handle_side_effects(self.parent) + self.grandparent._handle_side_effects(self.parent, effects) @dataclass @@ -251,12 +280,13 @@ class BlockBuilder(_DFBuilderRaw[Block]): parent: Cfg grandparent: DFBuilder - def _propagate_side_effects(self) -> None: + def _propagate_side_effects(self, effects: Iterable[Effect]) -> None: # No need to do anything in the CFG, but the CFG itself # needs to be ordered inside its parent, - self.grandparent._handle_side_effects(self.parent) + self.grandparent._handle_side_effects(self.parent, effects) def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self._raw.set_outputs(branching, *other_outputs) - if self._last_side_effect is not None: - self._handle_side_effects(self._raw.output_node) + self._handle_side_effects( + self._raw.output_node, list(self._last_side_effect.keys()) + ) diff --git a/guppylang-internals/src/guppylang_internals/compiler/builder/ops.py b/guppylang-internals/src/guppylang_internals/compiler/builder/ops.py new file mode 100644 index 000000000..b2a1b4c9e --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/compiler/builder/ops.py @@ -0,0 +1,20 @@ +from hugr import ops +from hugr.tys import Sum, Type, TypeRow + +from guppylang_internals.compiler.builder import OpWithEffects, Pure + + +def make_tuple(tys: TypeRow | None = None) -> OpWithEffects: + return Pure(ops.MakeTuple(tys)) + + +def unpack_tuple(tys: TypeRow | None = None) -> OpWithEffects: + return Pure(ops.UnpackTuple(tys)) + + +def tag(tag: int, rows: Sum) -> OpWithEffects: + return Pure(ops.Tag(tag, rows)) + + +def some(ty: Type) -> OpWithEffects: + return Pure(ops.Some(ty)) diff --git a/guppylang-internals/src/guppylang_internals/compiler/cfg_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/cfg_compiler.py index 2f565eb9a..c6b1dcdb8 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/cfg_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/cfg_compiler.py @@ -1,7 +1,7 @@ import functools from collections.abc import Sequence -from hugr import Wire, ops +from hugr import Wire from hugr import tys as ht from hugr.build import cfg as hc from hugr.hugr.node_port import ToNode @@ -13,7 +13,7 @@ Signature, ) from guppylang_internals.checker.core import Place, Variable -from guppylang_internals.compiler.builder import BlockBuilder, DFBuilder +from guppylang_internals.compiler.builder import BlockBuilder, DFBuilder, ops from guppylang_internals.compiler.core import ( CompilerContext, DFContainer, @@ -109,7 +109,7 @@ def compile_bb( else: # Even if we don't branch, we still have to add a `Sum(())` predicates branch_port = dfg.builder.add_op( - ops.Tag(0, ht.UnitSum(1)), set_debug_info=False + ops.tag(0, ht.UnitSum(1)), set_debug_info=False ) # Finally, we have to add the block output. @@ -206,7 +206,7 @@ def choose_vars_for_tuple_sum( for i, var_row in enumerate(output_vars): case = conditional.add_case(i) outputs = [case.inputs()[all_vars_idxs[v.id]] for v in var_row] - tag = case.add_op(ops.Tag(i, sum_type), *outputs) + tag = case.add_op(ops.tag(i, sum_type), *outputs) case.set_outputs(tag) return conditional diff --git a/guppylang-internals/src/guppylang_internals/compiler/core.py b/guppylang-internals/src/guppylang_internals/compiler/core.py index b0ebc1b34..255bd7c0b 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/core.py +++ b/guppylang-internals/src/guppylang_internals/compiler/core.py @@ -3,11 +3,12 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -from hugr import Hugr, Wire, ops +from hugr import Hugr, Wire from hugr import tys as ht from hugr.build import function as hf from hugr.build.dfg import DefinitionBuilder from hugr.hugr.base import OpVarCov +from hugr.ops import Module from hugr.std import PRELUDE from hugr.std.collections.array import EXTENSION as ARRAY_EXTENSION from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION @@ -19,6 +20,7 @@ TupleAccess, Variable, ) +from guppylang_internals.compiler.builder import ops from guppylang_internals.definition.common import ( CompilableDef, CompiledDef, @@ -85,7 +87,7 @@ class CompilerContext(ToHugrContext): themselves (i.e. `compile_inner` has not yet been called). """ - module: DefinitionBuilder[ops.Module] + module: DefinitionBuilder[Module] #: The definitions compiled so far. For generic definitions, their id can occur #: multiple times here with respectively different monomorphizations. See @@ -107,7 +109,7 @@ class CompilerContext(ToHugrContext): def __init__( self, - module: DefinitionBuilder[ops.Module], + module: DefinitionBuilder[Module], exported_defs: set[DefId], file_table: StringTable | None = None, ) -> None: @@ -227,7 +229,7 @@ def __getitem__(self, place: Place) -> Wire: raise InternalGuppyError(f"Couldn't obtain a port for `{place}`") child_types = [child.ty.to_hugr(self.ctx) for child in children] child_wires = [self[child] for child in children] - wire = self.builder.add_op(ops.MakeTuple(child_types), *child_wires)[0] + wire = self.builder.add_op(ops.make_tuple(child_types), *child_wires)[0] for child in children: if child.ty.linear: self.locals.pop(child.id) @@ -240,7 +242,7 @@ def __setitem__(self, place: Place, port: Wire) -> None: is_return = isinstance(place, Variable) and is_return_var(place.name) if isinstance(place.ty, StructType) and not is_return: hugr_fields_ty = [t.ty.to_hugr(self.ctx) for t in place.ty.fields] - unpack = self.builder.add_op(ops.UnpackTuple(hugr_fields_ty), port) + unpack = self.builder.add_op(ops.unpack_tuple(hugr_fields_ty), port) for field, field_port in zip(place.ty.fields, unpack, strict=True): self[FieldAccess(place, field, None)] = field_port # If we had a previous wire assigned to this place, we need forget about it. @@ -249,7 +251,7 @@ def __setitem__(self, place: Place, port: Wire) -> None: # Same for tuples. elif isinstance(place.ty, TupleType) and not is_return: hugr_elem_tys = [ty.to_hugr(self.ctx) for ty in place.ty.element_types] - unpack = self.builder.add_op(ops.UnpackTuple(hugr_elem_tys), port) + unpack = self.builder.add_op(ops.unpack_tuple(hugr_elem_tys), port) for idx, (elem, elem_port) in enumerate( zip(place.ty.element_types, unpack, strict=True) ): @@ -319,30 +321,6 @@ def get_parent_type(defn: Definition) -> "RawDef | None": ] -def may_have_side_effect(op: ops.Op) -> bool: - """Checks whether an operation could have a side-effect. - - We need to insert implicit state order edges between these kinds of nodes to ensure - they are executed in the correct order, even if there is no data dependency. - """ - match op: - case ops.ExtOp() as ext_op: - return ext_op.op_def().qualified_name() in EXTENSION_OPS_WITH_SIDE_EFFECTS - case ops.Custom(op_name=op_name, extension=extension): - qualified_name = f"{extension}.{op_name}" if extension else op_name - return qualified_name in EXTENSION_OPS_WITH_SIDE_EFFECTS - case ops.Call() | ops.CallIndirect(): - # Conservative choice is to assume that all calls could have side effects. - # In the future we could inspect the call graph to figure out a more - # precise answer - return True - case _: - # There is no need to handle TailLoop (in case of non-termination) since - # TailLoops are only generated for array comprehensions which must have - # statically-guaranteed (finite) size. TODO revisit this for lists. - return False - - #: List of linear extension types that correspond to affine Guppy types and thus require #: insertion of an explicit drop operation. AFFINE_EXTENSION_TYS: list[str] = [ @@ -378,13 +356,6 @@ def requires_drop(ty: ht.Type) -> bool: return False -def drop_op(ty: ht.Type) -> ops.ExtOp: - """Returns the operation to drop affine values.""" - return GUPPY_EXTENSION.get_op("drop").instantiate( - [ht.TypeTypeArg(ty)], ht.FunctionType([ty], []) - ) - - def insert_drops(hugr: Hugr[OpVarCov]) -> None: """Inserts explicit drop ops for unconnected ports into the Hugr. TODO: This is a quick workaround until we can properly insert these drops during @@ -403,5 +374,8 @@ def insert_drops(hugr: Hugr[OpVarCov]) -> None: and isinstance(kind, ht.ValueKind) and requires_drop(kind.ty) ): - drop = hugr.add_node(drop_op(kind.ty), parent=data.parent) + drop_op = GUPPY_EXTENSION.get_op("drop").instantiate( + [ht.TypeTypeArg(kind.ty)], ht.FunctionType([kind.ty], []) + ) + drop = hugr.add_node(drop_op, parent=data.parent) hugr.add_link(port, drop.inp(0)) diff --git a/guppylang-internals/src/guppylang_internals/compiler/expr_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/expr_compiler.py index 564a024d9..02ffea5de 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/expr_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/expr_compiler.py @@ -8,7 +8,8 @@ import hugr.std.int import hugr.std.logic import hugr.std.prelude -from hugr import Node, Wire, ops +from hugr import Node, Wire +from hugr import ops as hops from hugr import tys as ht from hugr import val as hv @@ -16,7 +17,12 @@ from guppylang_internals.cfg.builder import tmp_vars from guppylang_internals.checker.core import Variable, contains_subscript from guppylang_internals.checker.errors.generic import UnsupportedError -from guppylang_internals.compiler.builder import CondBuilder, DFBuilder +from guppylang_internals.compiler.builder import ( + CondBuilder, + DFBuilder, + Pure, + ops, +) from guppylang_internals.compiler.core import ( DEBUG_EXTENSION, CompilerBase, @@ -71,10 +77,12 @@ list_new, ) from guppylang_internals.std._internal.compiler.prelude import ( + barrier_op, build_panic, make_error, panic, ) +from guppylang_internals.tys import Effect from guppylang_internals.tys.builtin import ( bool_type, get_element_type, @@ -286,12 +294,12 @@ def visit_List(self, node: ast.List) -> Wire: def _unpack_tuple(self, wire: Wire, types: Sequence[Type]) -> Sequence[Wire]: """Add a tuple unpack operation to the graph""" types = [t.to_hugr(self.ctx) for t in types] - return list(self.builder.add_op(ops.UnpackTuple(types), wire)) + return list(self.builder.add_op(ops.unpack_tuple(types), wire)) def _pack_tuple(self, wires: Sequence[Wire], types: Sequence[Type]) -> Wire: """Add a tuple pack operation to the graph""" types = [t.to_hugr(self.ctx) for t in types] - return self.builder.add_op(ops.MakeTuple(types), *wires) + return self.builder.add_op(ops.make_tuple(types), *wires) def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire: """Groups function return values into a tuple""" @@ -334,7 +342,7 @@ def visit_LocalCall(self, node: LocalCall) -> Wire: args = self._compile_call_args(node.args, func_ty) call = self.builder.add_op( - ops.CallIndirect(func_ty.to_hugr(self.ctx)), + (hops.CallIndirect(func_ty.to_hugr(self.ctx)), func_ty.effects), func, *args, ) @@ -393,7 +401,10 @@ def _compile_tensor_with_leftovers( consumed_args, other_args = args[0:input_len], args[input_len:] consumed_wires = self._compile_call_args(consumed_args, func_ty) call = self.builder.add_op( - ops.CallIndirect(func_ty.to_hugr(self.ctx)), + ( + hops.CallIndirect(func_ty.to_hugr(self.ctx)), + func_ty.effects, + ), func, *consumed_wires, ) @@ -464,7 +475,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire: # since it is not implemented via a dunder method if isinstance(node.op, ast.Not): arg = self.visit(node.operand) - return self.builder.add_op(hugr.std.logic.Not, arg) + return self.builder.add_op(Pure(hugr.std.logic.Not), arg) raise InternalGuppyError("Node should have been removed during type checking.") @@ -527,12 +538,10 @@ def visit_AbortExpr(self, node: AbortExpr) -> Wire: def visit_BarrierExpr(self, node: BarrierExpr) -> Wire: hugr_tys = [get_type(e).to_hugr(self.ctx) for e in node.args] - op = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier").instantiate( - [ht.ListArg([ht.TypeTypeArg(ty) for ty in hugr_tys])], - ht.FunctionType.endo(hugr_tys), - ) - barrier_n = self.builder.add_op(op, *(self.visit(e) for e in node.args)) + barrier_n = self.builder.add_op( + barrier_op(hugr_tys), *(self.visit(e) for e in node.args) + ) self._update_inout_ports(node.args, iter(barrier_n), node.func_ty) return self._pack_returns([], NoneType()) @@ -550,7 +559,10 @@ def visit_StateOutputExpr(self, node: StateOutputExpr) -> Wire: [standard_array_type(ht.Qubit, num_qubits_arg)], ) - op = ops.ExtOp(DEBUG_EXTENSION.get_op("StateResult"), signature=sig, args=args) + op = ( + hops.ExtOp(DEBUG_EXTENSION.get_op("StateResult"), signature=sig, args=args), + [Effect.ANY], + ) qubit_arr_in: Wire if not node.array_len: @@ -676,17 +688,17 @@ def _build_generators( break_pred_hugr_ty = ht.Either([iter_ty.to_hugr(self.ctx)], []) with stop_case: self.dfg[break_pred.place] = self.builder.add_op( - ops.Tag(1, break_pred_hugr_ty) + ops.tag(1, break_pred_hugr_ty) ) # Otherwise, we continue, set the break predicate to false, and insert # the iterator for the next loop iteration stack.enter_context(hasnext_case) next_wire = self.dfg[next_var.place] - elt, it = self.builder.add_op(ops.UnpackTuple(), next_wire) + elt, it = self.builder.add_op(ops.unpack_tuple(), next_wire) compiler.dfg = self.dfg compiler._assign(gen.target, elt) self.dfg[break_pred.place] = self.builder.add_op( - ops.Tag(0, break_pred_hugr_ty), it + ops.tag(0, break_pred_hugr_ty), it ) # Enter nested conditionals for each if guard on the generator for if_expr in gen.ifs: @@ -717,7 +729,7 @@ def pack_returns( types = type_to_row(return_ty) assert len(returns) == len(types) hugr_tys = [t.to_hugr(ctx) for t in types] - return builder.add_op(ops.MakeTuple(hugr_tys), *returns) + return builder.add_op(ops.make_tuple(hugr_tys), *returns) assert len(returns) == 1, ( f"Expected a single return value. Got {returns}. return type {return_ty}" ) @@ -735,7 +747,7 @@ def unpack_wire( if isinstance(return_ty, TupleType | NoneType): types = type_to_row(return_ty) hugr_tys = [t.to_hugr(ctx) for t in types] - return list(builder.add_op(ops.UnpackTuple(hugr_tys), wire).outputs()) + return list(builder.add_op(ops.unpack_tuple(hugr_tys), wire).outputs()) return [wire] diff --git a/guppylang-internals/src/guppylang_internals/compiler/hugr_extension.py b/guppylang-internals/src/guppylang_internals/compiler/hugr_extension.py index 1dd7346d6..01486a4a4 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/hugr_extension.py +++ b/guppylang-internals/src/guppylang_internals/compiler/hugr_extension.py @@ -9,6 +9,8 @@ import hugr.tys as ht from hugr import ops +from guppylang_internals.compiler.builder import OpWithEffects, Pure + if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -77,7 +79,7 @@ class PartialOp(ops.AsExtOp): @classmethod def from_closure( cls, closure_ty: ht.FunctionType, captured_tys: Sequence[ht.Type] - ) -> PartialOp: + ) -> OpWithEffects: """An operation that partially evaluates a function. args: @@ -93,10 +95,12 @@ def from_closure( assert captured_tys == closure_ty.input[: len(captured_tys)] other_inputs = closure_ty.input[len(captured_tys) :] - return cls( - captured_inputs=list(captured_tys), - other_inputs=list(other_inputs), - outputs=list(closure_ty.output), + return Pure( + cls( + captured_inputs=list(captured_tys), + other_inputs=list(other_inputs), + outputs=list(closure_ty.output), + ) ) def op_def(self) -> he.OpDef: diff --git a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py index f7bc4e83a..7650e3f79 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py @@ -6,6 +6,7 @@ from guppylang_internals.ast_util import get_type from guppylang_internals.checker.core import SubscriptAccess, contains_subscript from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back +from guppylang_internals.compiler.builder import Pure from guppylang_internals.compiler.cfg_compiler import compile_cfg from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.compiler.expr_compiler import ExprCompiler @@ -75,10 +76,12 @@ def compile_modified_block( if modified_block.has_dagger(): dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty]) call = dfg.builder.add_op( - ops.ExtOp( - dagger_op_def, - dagger_ty, - [in_out_arg, other_in_arg], + Pure( # This is generation of the daggered version, not calling it (below) + ops.ExtOp( + dagger_op_def, + dagger_ty, + [in_out_arg, other_in_arg], + ) ), call, ) @@ -87,10 +90,13 @@ def compile_modified_block( for power in modified_block.power: num = expr_compiler.compile(power.iter, dfg) call = dfg.builder.add_op( - ops.ExtOp( - power_op_def, - power_ty, - [in_out_arg, other_in_arg], + # This is generation of the powered version, not calling it (below) + Pure( + ops.ExtOp( + power_op_def, + power_ty, + [in_out_arg, other_in_arg], + ) ), call, num, @@ -112,10 +118,13 @@ def compile_modified_block( output_fn_ty = ht.FunctionType( [std_array, *hugr_ty.input], [std_array, *hugr_ty.output] ) - op = ops.ExtOp( - control_op_def, - ht.FunctionType([input_fn_ty], [output_fn_ty]), - [qubit_num, in_out_arg, other_in_arg], + # Compilation of the controlled version is pure, not calling it (below) + op = Pure( + ops.ExtOp( + control_op_def, + ht.FunctionType([input_fn_ty], [output_fn_ty]), + [qubit_num, in_out_arg, other_in_arg], + ) ) call = dfg.builder.add_op(op, call) # Update types: later modifiers see the newly wrapped function type @@ -146,7 +155,7 @@ def compile_modified_block( # Call the modified block. call = dfg.builder.add_op( - ops.CallIndirect(), + (ops.CallIndirect(), body_ty.effects), call, *ctrl_args, *args, diff --git a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py index 509382821..9be258e5a 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py @@ -2,11 +2,12 @@ import functools from collections.abc import Sequence -from hugr import Wire, ops +from hugr import Wire from guppylang_internals.ast_util import AstVisitor, get_type from guppylang_internals.checker.core import Variable, contains_subscript from guppylang_internals.compiler.builder import DFBuilder +from guppylang_internals.compiler.builder.ops import unpack_tuple from guppylang_internals.compiler.core import ( CompilerBase, CompilerContext, @@ -103,7 +104,7 @@ def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None: # Unpack the RHS tuple left, starred, right = lhs.pattern.left, lhs.pattern.starred, lhs.pattern.right types = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(lhs))] - unpack = self.builder.add_op(ops.UnpackTuple(types), port) + unpack = self.builder.add_op(unpack_tuple(types), port) ports = list(unpack) # Assign left and right @@ -206,7 +207,7 @@ def visit_Return(self, node: ast.Return) -> None: row: list[tuple[Wire, Type]] if isinstance(return_ty, TupleType): types = [e.to_hugr(self.ctx) for e in return_ty.element_types] - unpack = self.builder.add_op(ops.UnpackTuple(types), port) + unpack = self.builder.add_op(unpack_tuple(types), port) row = list(zip(unpack, return_ty.element_types, strict=True)) else: row = [(port, return_ty)] diff --git a/guppylang-internals/src/guppylang_internals/definition/custom.py b/guppylang-internals/src/guppylang_internals/definition/custom.py index 001a1c66c..d0fa326d5 100644 --- a/guppylang-internals/src/guppylang_internals/definition/custom.py +++ b/guppylang-internals/src/guppylang_internals/definition/custom.py @@ -20,7 +20,10 @@ from guppylang_internals.checker.core import Context, Globals from guppylang_internals.checker.expr_checker import check_call, synthesize_call from guppylang_internals.checker.func_checker import check_signature -from guppylang_internals.compiler.builder import DFBuilder, FunctionBuilder +from guppylang_internals.compiler.builder import ( + DFBuilder, + FunctionBuilder, +) from guppylang_internals.compiler.core import ( CompilerContext, DFContainer, @@ -344,6 +347,14 @@ def compile_call( ) as compiler: return compiler.compile_with_inouts(args) + @property + def effects(self) -> Sequence[Effect]: + """The effects of the function.""" + # ALAN ?? if self.has_signature: + return self.ty.effects + # else: + # return [] + class CustomCallChecker(ABC): """Abstract base class for custom function call type checkers.""" @@ -413,7 +424,7 @@ class CustomInoutCallCompiler(ABC): ctx: CompilerContext node: AstNode ty: ht.FunctionType - func: CustomMonoFunctionDef | None + func: CustomMonoFunctionDef _depth = 0 @@ -425,7 +436,7 @@ def _setup( ctx: CompilerContext, node: AstNode, hugr_ty: ht.FunctionType, - func: CustomMonoFunctionDef | None, + func: CustomMonoFunctionDef, ) -> Generator["CustomInoutCallCompiler", None, None]: """ A context manager to temporarily set up the compiler with required arguments, @@ -532,7 +543,7 @@ def __init__( @override def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: op = self.op(self.ty, self.type_args, self.ctx) - node = self.builder.add_op(op, *args) + node = self.builder.add_op((op, self.func.effects), *args) num_returns = ( len(type_to_row(self.func.ty.output)) if self.func else len(self.ty.output) ) @@ -582,6 +593,8 @@ def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]: type_args, ht.FunctionType(self.ty.input, self.ty.output), ) + # ALAN assume panics if any borrowed? + clone_op = (clone_op, [Effect.ANY]) return list(self.builder.add_op(clone_op, arg)) case _: pass diff --git a/guppylang-internals/src/guppylang_internals/definition/enum.py b/guppylang-internals/src/guppylang_internals/definition/enum.py index 057da1df4..75be5597a 100644 --- a/guppylang-internals/src/guppylang_internals/definition/enum.py +++ b/guppylang-internals/src/guppylang_internals/definition/enum.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import ClassVar, Generic, TypeVar -from hugr import Wire, ops +from hugr import Wire from guppylang_internals.ast_util import AstNode from guppylang_internals.checker.core import Globals @@ -12,6 +12,7 @@ UnexpectedError, UnsupportedError, ) +from guppylang_internals.compiler.builder.ops import tag from guppylang_internals.compiler.core import GlobalConstId from guppylang_internals.definition.common import ( CheckableDef, @@ -285,7 +286,7 @@ def compile(self, wires: list[Wire]) -> list[Wire]: assert isinstance(inst_enum_type, EnumType) # for mypy return list( self.builder.add_op( - ops.Tag(self.variant_idx, inst_enum_type.to_hugr(self.ctx)), + tag(self.variant_idx, inst_enum_type.to_hugr(self.ctx)), *wires, ) ) diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index 522ab3059..3c31f87e9 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -345,7 +345,7 @@ def compile_call( """Compiles a call to the function.""" num_returns = len(type_to_row(ty.output)) with dfg.builder.set_ast_context(call_ast): - call = dfg.builder.call(func, *args) + call = dfg.builder.call(func, *args, effects=ty.effects) return CallReturnWires( regular_returns=list(call[:num_returns]), inout_returns=list(call[num_returns:]), diff --git a/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py b/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py index 790624007..384c06833 100644 --- a/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py +++ b/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py @@ -4,7 +4,7 @@ import hugr.build.function as hf from guppylang.defs import GuppyDefinition -from hugr import Node, Wire, envelope, ops, val +from hugr import Node, Wire, envelope, val from hugr import tys as ht from hugr.build.dfg import DefinitionBuilder, OpVar from hugr.debug_info import DILocation, DISubprogram @@ -21,6 +21,7 @@ check_signature, ) from guppylang_internals.compiler.builder import FunctionBuilder +from guppylang_internals.compiler.builder.ops import unpack_tuple from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.debug_mode import debug_mode_enabled from guppylang_internals.definition.common import ( @@ -270,14 +271,15 @@ def compile_outer( angle_wires = [name_to_param[name] for name in param_order] # Need to convert all angles to rotations. for angle in angle_wires: - [halfturns] = outer_func.add_op(ops.UnpackTuple([FLOAT_T]), angle) + [halfturns] = outer_func.add_op(unpack_tuple([FLOAT_T]), angle) rotation = outer_func.add_op(from_halfturns_unchecked(), halfturns) param_wires.append(rotation) - # Pass all arguments to call node. Note that since we are using a - # FunctionBuilder, this will default to assuming that the target function - # is side-effecting, so may produce more order edges than necessary. - call_node = outer_func.call(hugr_func, *(input_list + bool_wires + param_wires)) + # Pass all arguments to call node. We assume that the target function has no + # side-effects, since it came from a circuit. + call_node = outer_func.call( + hugr_func, *(input_list + bool_wires + param_wires), effects=[] + ) # Add debug info metadata to the call node inside the outer function definition. if debug_mode_enabled(): # Function stub case. diff --git a/guppylang-internals/src/guppylang_internals/definition/struct.py b/guppylang-internals/src/guppylang_internals/definition/struct.py index 2515e59ce..493b2630a 100644 --- a/guppylang-internals/src/guppylang_internals/definition/struct.py +++ b/guppylang-internals/src/guppylang_internals/definition/struct.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from typing import ClassVar -from hugr import Wire, ops +from hugr import Wire from guppylang_internals.ast_util import AstNode from guppylang_internals.checker.core import Globals @@ -12,6 +12,7 @@ UnexpectedError, UnsupportedError, ) +from guppylang_internals.compiler.builder.ops import make_tuple from guppylang_internals.compiler.core import GlobalConstId from guppylang_internals.definition.common import ( CheckableDef, @@ -215,7 +216,7 @@ class ConstructorCompiler(CustomCallCompiler): """Compiler for the `__new__` constructor method of a struct.""" def compile(self, args: list[Wire]) -> list[Wire]: - return list(self.builder.add_op(ops.MakeTuple(), *args)) + return list(self.builder.add_op(make_tuple(), *args)) constructor_sig = FunctionType( inputs=[ diff --git a/guppylang-internals/src/guppylang_internals/definition/traced.py b/guppylang-internals/src/guppylang_internals/definition/traced.py index 96d84194e..a0c8645d2 100644 --- a/guppylang-internals/src/guppylang_internals/definition/traced.py +++ b/guppylang-internals/src/guppylang_internals/definition/traced.py @@ -200,7 +200,7 @@ def compile_call( """Compiles a call to the function.""" num_returns = len(type_to_row(self.ty.output)) with dfg.builder.set_ast_context(node): - call = dfg.builder.call(self.func_def, *args) + call = dfg.builder.call(self.func_def, *args, effects=self.ty.effects) return CallReturnWires( regular_returns=list(call[:num_returns]), inout_returns=list(call[num_returns:]), diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/arithmetic.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/arithmetic.py index b34d64780..b820a0580 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/arithmetic.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/arithmetic.py @@ -8,6 +8,7 @@ from hugr import tys as ht from hugr.std.int import int_t +from guppylang_internals.compiler.builder import OpWithEffects, Pure from guppylang_internals.std._internal.compiler.prelude import error_type from guppylang_internals.tys.ty import NumericType @@ -47,41 +48,43 @@ def _instantiate_int_op( int_width: int | Sequence[int], inp: list[ht.Type], out: list[ht.Type], -) -> ops.ExtOp: +) -> OpWithEffects: op_def = hugr.std.int.INT_OPS_EXTENSION.get_op(name) int_width = [int_width] if isinstance(int_width, int) else int_width - return ops.ExtOp( - op_def, - ht.FunctionType(inp, out), - [ht.BoundedNatArg(w) for w in int_width], + return Pure( + ops.ExtOp( + op_def, + ht.FunctionType(inp, out), + [ht.BoundedNatArg(w) for w in int_width], + ) ) -def ieq(width: int) -> ops.ExtOp: +def ieq(width: int) -> OpWithEffects: """Returns a `std.arithmetic.int.ieq` operation.""" return _instantiate_int_op("ieq", width, [int_t(width), int_t(width)], [ht.Bool]) -def ine(width: int) -> ops.ExtOp: +def ine(width: int) -> OpWithEffects: """Returns a `std.arithmetic.int.ine` operation.""" return _instantiate_int_op("ine", width, [int_t(width), int_t(width)], [ht.Bool]) -def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp: +def iwiden_u(from_width: int, to_width: int) -> OpWithEffects: """Returns an unsigned `std.arithmetic.int.widen_u` operation.""" return _instantiate_int_op( "iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)] ) -def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp: +def iwiden_s(from_width: int, to_width: int) -> OpWithEffects: """Returns a signed `std.arithmetic.int.widen_s` operation.""" return _instantiate_int_op( "iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)] ) -def inarrow_u(from_width: int, to_width: int) -> ops.ExtOp: +def inarrow_u(from_width: int, to_width: int) -> OpWithEffects: """Returns an unsigned `std.arithmetic.int.narrow_u` operation.""" return _instantiate_int_op( "inarrow_u", @@ -91,7 +94,7 @@ def inarrow_u(from_width: int, to_width: int) -> ops.ExtOp: ) -def inarrow_s(from_width: int, to_width: int) -> ops.ExtOp: +def inarrow_s(from_width: int, to_width: int) -> OpWithEffects: """Returns a signed `std.arithmetic.int.narrow_s` operation.""" return _instantiate_int_op( "inarrow_s", @@ -111,26 +114,26 @@ def _instantiate_convert_op( inp: list[ht.Type], out: list[ht.Type], args: list[ht.TypeArg] | None = None, -) -> ops.ExtOp: +) -> OpWithEffects: op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name) - return ops.ExtOp(op_def, ht.FunctionType(inp, out), args or []) + return Pure(ops.ExtOp(op_def, ht.FunctionType(inp, out), args or [])) -def convert_ifromusize() -> ops.ExtOp: +def convert_ifromusize() -> OpWithEffects: """Returns a `std.arithmetic.conversions.ifromusize` operation.""" return _instantiate_convert_op("ifromusize", [ht.USize()], [INT_T]) -def convert_itousize() -> ops.ExtOp: +def convert_itousize() -> OpWithEffects: """Returns a `std.arithmetic.conversions.itousize` operation.""" return _instantiate_convert_op("itousize", [INT_T], [ht.USize()]) -def convert_ifrombool() -> ops.ExtOp: +def convert_ifrombool() -> OpWithEffects: """Returns a `std.arithmetic.conversions.ifrombool` operation.""" return _instantiate_convert_op("ifrombool", [ht.Bool], [int_t(0)]) -def convert_itobool() -> ops.ExtOp: +def convert_itobool() -> OpWithEffects: """Returns a `std.arithmetic.conversions.itobool` operation.""" return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool]) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/array.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/array.py index a0c0d8a94..9f6fd6901 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/array.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/array.py @@ -9,14 +9,18 @@ from hugr import tys as ht from hugr.std.collections.borrow_array import EXTENSION +from guppylang_internals.compiler.builder import OpWithEffects, Pure from guppylang_internals.definition.custom import CustomCallCompiler from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.error import InternalGuppyError from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize from guppylang_internals.std._internal.compiler.prelude import build_unwrap_right +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import ConstArg, TypeArg if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from guppylang_internals.ast_util import AstNode from guppylang_internals.compiler.builder import DFBuilder @@ -32,9 +36,16 @@ def _instantiate_array_op( length: ht.TypeArg, inp: list[ht.Type], out: list[ht.Type], -) -> ops.ExtOp: - return EXTENSION.get_op(name).instantiate( - [length, ht.TypeTypeArg(elem_ty)], ht.FunctionType(inp, out) + # Almost all (borrow-)array operations can panic if relevant element(s) are + # borrowed. Allow overriding for the minority that do not. + # Usual warning about mutable default arguments applies, but Sequence is read-only. + effects: Sequence[Effect] = [Effect.ANY], +) -> OpWithEffects: + return ( + EXTENSION.get_op(name).instantiate( + [length, ht.TypeTypeArg(elem_ty)], ht.FunctionType(inp, out) + ), + effects, ) @@ -57,16 +68,16 @@ def standard_array_type(elem_ty: ht.Type, length: ht.TypeArg) -> ht.ExtType: return defn.instantiate([length, elem_arg]) -def array_new(elem_ty: ht.Type, length: int) -> ops.ExtOp: +def array_new(elem_ty: ht.Type, length: int) -> OpWithEffects: """Returns an operation that creates a new fixed length array.""" length_arg = ht.BoundedNatArg(length) arr_ty = array_type(elem_ty, length_arg) return _instantiate_array_op( - "new_array", elem_ty, length_arg, [elem_ty] * length, [arr_ty] - ) + "new_array", elem_ty, length_arg, [elem_ty] * length, [arr_ty], effects=[] + ) # never panics -def array_unpack(elem_ty: ht.Type, length: int) -> ops.ExtOp: +def array_unpack(elem_ty: ht.Type, length: int) -> OpWithEffects: """Returns an operation that unpacks a fixed length array.""" length_arg = ht.BoundedNatArg(length) arr_ty = array_type(elem_ty, length_arg) @@ -75,7 +86,7 @@ def array_unpack(elem_ty: ht.Type, length: int) -> ops.ExtOp: ) -def array_get(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def array_get(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `get` operation.""" assert elem_ty.type_bound() == ht.TypeBound.Copyable arr_ty = array_type(elem_ty, length) @@ -84,7 +95,7 @@ def array_get(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) -def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `set` operation.""" arr_ty = array_type(elem_ty, length) return _instantiate_array_op( @@ -96,7 +107,7 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) -def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp: +def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> OpWithEffects: """Returns an operation that pops an element from the left of an array.""" assert length > 0 length_arg = ht.BoundedNatArg(length) @@ -108,11 +119,13 @@ def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp: ) -def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp: +def array_discard_empty(elem_ty: ht.Type) -> OpWithEffects: """Returns an operation that discards an array of length zero.""" arr_ty = array_type(elem_ty, ht.BoundedNatArg(0)) - return EXTENSION.get_op("discard_empty").instantiate( - [ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], []) + return Pure( + EXTENSION.get_op("discard_empty").instantiate( + [ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], []) + ) ) @@ -121,7 +134,7 @@ def array_scan( length: ht.TypeArg, new_elem_ty: ht.Type, accumulators: list[ht.Type], -) -> ops.ExtOp: +) -> OpWithEffects: """Returns an operation that maps and folds a function across an array.""" ty_args = [ length, @@ -135,49 +148,65 @@ def array_scan( *accumulators, ] outs = [array_type(new_elem_ty, length), *accumulators] - return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs)) + return ( + EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs)), + [Effect.ANY], # can panic if any element is borrowed + ) -def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: +def array_map( + elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type +) -> OpWithEffects: """Returns an operation that maps a function across an array.""" return array_scan(elem_ty, length, new_elem_ty, accumulators=[]) -def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: - """Returns an array `repeat` operation.""" - return EXTENSION.get_op("repeat").instantiate( - [length, ht.TypeTypeArg(elem_ty)], - ht.FunctionType( - [ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)] - ), +def array_repeat( + elem_ty: ht.Type, length: ht.TypeArg, effects: Iterable[Effect] +) -> OpWithEffects: + """Returns an array `repeat` operation for a function, of no arguments + to one element, with the specified effects.""" + func_ty = ht.FunctionType([], [elem_ty]) + return _instantiate_array_op( + "repeat", + elem_ty, + length, + [func_ty], + [array_type(elem_ty, length)], + effects=list( + set(effects) + ), # As the function, as it'll invoke the function many times ) -def array_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def array_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array operation to convert a value of the `borrow_array` type used by Guppy into a standard `array`. """ - return EXTENSION.get_op("to_array").instantiate( - [length, ht.TypeTypeArg(elem_ty)], - ht.FunctionType( - [array_type(elem_ty, length)], [standard_array_type(elem_ty, length)] - ), + return _instantiate_array_op( + "to_array", + elem_ty, + length, + [array_type(elem_ty, length)], + [standard_array_type(elem_ty, length)], ) -def std_array_to_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def std_array_to_array(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array operation to convert the standard `array` type into the `borrow_array` type used by Guppy. """ - return EXTENSION.get_op("from_array").instantiate( - [length, ht.TypeTypeArg(elem_ty)], - ht.FunctionType( - [standard_array_type(elem_ty, length)], [array_type(elem_ty, length)] - ), + return _instantiate_array_op( + "from_array", + elem_ty, + length, + [standard_array_type(elem_ty, length)], + [array_type(elem_ty, length)], + effects=[], # Cannot panic: a standard array always has every element ) -def barray_borrow(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def barray_borrow(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `borrow` operation.""" arr_ty = array_type(elem_ty, length) return _instantiate_array_op( @@ -185,7 +214,7 @@ def barray_borrow(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) -def barray_return(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def barray_return(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `return` operation.""" arr_ty = array_type(elem_ty, length) return _instantiate_array_op( @@ -193,27 +222,34 @@ def barray_return(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) -def barray_discard_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def barray_discard_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `discard_all_borrowed` operation.""" arr_ty = array_type(elem_ty, length) return _instantiate_array_op("discard_all_borrowed", elem_ty, length, [arr_ty], []) -def barray_new_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def barray_new_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `new_all_borrowed` operation.""" arr_ty = array_type(elem_ty, length) - return _instantiate_array_op("new_all_borrowed", elem_ty, length, [], [arr_ty]) + return _instantiate_array_op( + "new_all_borrowed", elem_ty, length, [], [arr_ty], effects=[] + ) -def barray_is_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def barray_is_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `is_borrowed` operation.""" arr_ty = array_type(elem_ty, length) return _instantiate_array_op( - "is_borrowed", elem_ty, length, [arr_ty, ht.USize()], [arr_ty, ht.Bool] + "is_borrowed", + elem_ty, + length, + [arr_ty, ht.USize()], + [arr_ty, ht.Bool], + effects=[], ) -def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `clone` operation for arrays none of whose elements are borrowed.""" assert elem_ty.type_bound() == ht.TypeBound.Copyable @@ -221,7 +257,7 @@ def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: return _instantiate_array_op("clone", elem_ty, length, [arr_ty], [arr_ty, arr_ty]) -def array_swap(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: +def array_swap(elem_ty: ht.Type, length: ht.TypeArg) -> OpWithEffects: """Returns an array `swap` operation. Swaps two elements at given indices in-place. @@ -235,6 +271,7 @@ def array_swap(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: length, [arr_ty, ht.USize(), ht.USize()], [ht.Either([arr_ty], [arr_ty])], + effects=[], # ALAN TODO CHECK: Do we swap borrowedness? ) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/either.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/either.py index af1118a9a..833362408 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/either.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/either.py @@ -1,10 +1,11 @@ from abc import ABC from collections.abc import Sequence -from hugr import Wire, ops +from hugr import Wire from hugr import tys as ht from hugr import val as hv +from guppylang_internals.compiler.builder.ops import tag from guppylang_internals.compiler.expr_compiler import unpack_wire from guppylang_internals.definition.custom import ( CustomCallCompiler, @@ -71,7 +72,7 @@ def compile(self, args: list[Wire]) -> list[Wire]: [inp] = args # Unpack the single input into a row inp_row = unpack_wire(inp, inp_arg.ty, self.builder, self.ctx, self.node) - return [self.builder.add_op(ops.Tag(self.tag, ty), *inp_row)] + return [self.builder.add_op(tag(self.tag, ty), *inp_row)] class EitherTestCompiler(EitherCompiler): @@ -86,7 +87,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: for i in [0, 1]: case = cond.add_case(i) val = hv.TRUE if i == self.tag else hv.FALSE - either = case.add_op(ops.Tag(i, self.either_ty), *case.inputs()) + either = case.add_op(tag(i, self.either_ty), *case.inputs()) case.set_outputs(case.load(val), either) [res, either] = cond.outputs() return CallReturnWires(regular_returns=[res], inout_returns=[either]) @@ -105,9 +106,9 @@ def compile(self, args: list[Wire]) -> list[Wire]: for i in [0, 1]: case = cond.add_case(i) if i == self.tag: - out = case.add_op(ops.Tag(1, ht.Option(*target_tys)), *case.inputs()) + out = case.add_op(tag(1, ht.Option(*target_tys)), *case.inputs()) else: - out = case.add_op(ops.Tag(0, ht.Option(*target_tys))) + out = case.add_op(tag(0, ht.Option(*target_tys))) case.set_outputs(out) return list(cond.outputs()) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/frozenarray.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/frozenarray.py index 08910d037..7152f737c 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/frozenarray.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/frozenarray.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING, Final -from hugr import Wire, ops +from hugr import Wire from hugr import tys as ht from hugr.std.collections.static_array import EXTENSION, StaticArray -from guppylang_internals.compiler.builder import FunctionBuilder +from guppylang_internals.compiler.builder import FunctionBuilder, OpWithEffects, Pure from guppylang_internals.compiler.core import GlobalConstId from guppylang_internals.compiler.expr_compiler import unpack_wire from guppylang_internals.definition.custom import CustomCallCompiler @@ -25,13 +25,15 @@ from hugr.build import function as hf -def static_array_get(elem_ty: ht.Type) -> ops.ExtOp: +def static_array_get(elem_ty: ht.Type) -> OpWithEffects: """Returns the static array `get` operation.""" assert elem_ty.type_bound() == ht.TypeBound.Copyable arr_ty = StaticArray(elem_ty) - return EXTENSION.get_op("get").instantiate( - [ht.TypeTypeArg(elem_ty)], - ht.FunctionType([arr_ty, ht.USize()], [ht.Option(elem_ty)]), + return Pure( + EXTENSION.get_op("get").instantiate( + [ht.TypeTypeArg(elem_ty)], + ht.FunctionType([arr_ty, ht.USize()], [ht.Option(elem_ty)]), + ) ) @@ -65,6 +67,10 @@ def compile(self, args: list[Wire]) -> list[Wire]: type_args = [ht.TypeTypeArg(elem_ty)] with self.builder.set_ast_context(self.node): out = self.builder.call( - self.getitem_func(), *args, instantiation=inst, type_args=type_args + self.getitem_func(), + *args, + instantiation=inst, + type_args=type_args, + effects=[], ) return unpack_wire(out, ty_arg.ty, self.builder, self.ctx) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/list.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/list.py index 335e67e9e..bca71df88 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/list.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/list.py @@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, TypeVar import hugr.std.collections.list -from hugr import Wire, ops +from hugr import Wire from hugr import tys as ht +from hugr.ops import DfParentOp, ExtOp from hugr.std.collections.list import List, ListVal +from guppylang_internals.compiler.builder import OpWithEffects, Pure, ops from guppylang_internals.definition.custom import CustomCallCompiler from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.error import InternalGuppyError @@ -36,16 +38,18 @@ def _instantiate_list_op( name: str, elem_type: ht.Type, inp: list[ht.Type], out: list[ht.Type] -) -> ops.ExtOp: +) -> OpWithEffects: op_def = hugr.std.collections.list.EXTENSION.get_op(name) - return ops.ExtOp( - op_def, - ht.FunctionType(inp, out), - [ht.TypeTypeArg(elem_type)], + return Pure( + ExtOp( + op_def, + ht.FunctionType(inp, out), + [ht.TypeTypeArg(elem_type)], + ) ) -def list_pop(elem_type: ht.Type) -> ops.ExtOp: +def list_pop(elem_type: ht.Type) -> OpWithEffects: """Returns a list `pop` operation.""" list_type = List(elem_type) return _instantiate_list_op( @@ -53,13 +57,13 @@ def list_pop(elem_type: ht.Type) -> ops.ExtOp: ) -def list_push(elem_type: ht.Type) -> ops.ExtOp: +def list_push(elem_type: ht.Type) -> OpWithEffects: """Returns a list `push` operation.""" list_type = List(elem_type) return _instantiate_list_op("push", elem_type, [list_type, elem_type], [list_type]) -def list_get(elem_type: ht.Type) -> ops.ExtOp: +def list_get(elem_type: ht.Type) -> OpWithEffects: """Returns a list `get` operation.""" list_type = List(elem_type) return _instantiate_list_op( @@ -67,29 +71,31 @@ def list_get(elem_type: ht.Type) -> ops.ExtOp: ) -def list_set(elem_type: ht.Type) -> ops.ExtOp: +def list_set(elem_type: ht.Type) -> OpWithEffects: """Returns a list `set` operation.""" list_type = List(elem_type) return _instantiate_list_op( "set", elem_type, [list_type, ht.USize(), elem_type], + # Return supplied element if out of range else element removed from list [list_type, ht.Either([elem_type], [elem_type])], ) -def list_insert(elem_type: ht.Type) -> ops.ExtOp: +def list_insert(elem_type: ht.Type) -> OpWithEffects: """Returns a list `insert` operation.""" list_type = List(elem_type) return _instantiate_list_op( "insert", elem_type, [list_type, ht.USize(), elem_type], + # Return supplied element if out of range else unit [list_type, ht.Either([elem_type], [ht.Unit])], ) -def list_length(elem_type: ht.Type) -> ops.ExtOp: +def list_length(elem_type: ht.Type) -> OpWithEffects: """Returns a list `length` operation.""" list_type = List(elem_type) return _instantiate_list_op( @@ -128,7 +134,7 @@ def build_linear_getitem( # implementation of the list type ensures that linear element types are turned # into optionals. elem_opt_ty = ht.Option(elem_ty) - none = self.builder.add_op(ops.Tag(0, elem_opt_ty)) + none = self.builder.add_op(ops.tag(0, elem_opt_ty)) idx = self.builder.add_op(convert_itousize(), idx) list_wire, result = self.builder.add_op( list_set(elem_opt_ty), list_wire, idx, none @@ -183,7 +189,7 @@ def build_linear_setitem( """Lowers a call to `array.__setitem__` for linear arrays.""" # Embed the element into an optional elem_opt_ty = ht.Option(elem_ty) - elem = self.builder.add_op(ops.Some(elem_ty), elem) + elem = self.builder.add_op(ops.some(elem_ty), elem) idx = self.builder.add_op(convert_itousize(), idx) list_wire, result = self.builder.add_op( list_set(elem_opt_ty), list_wire, idx, elem @@ -276,7 +282,7 @@ def build_linear_push( """Lowers a call to `list.push` for linear lists.""" # Wrap element into an optional elem_opt_ty = ht.Option(elem_ty) - elem_opt = self.builder.add_op(ops.Some(elem_ty), elem) + elem_opt = self.builder.add_op(ops.some(elem_ty), elem) list_wire = self.builder.add_op(list_push(elem_opt_ty), list_wire, elem_opt) return CallReturnWires(regular_returns=[], inout_returns=[list_wire]) @@ -315,7 +321,7 @@ def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Call compile_with_inouts instead") -P = TypeVar("P", bound=ops.DfParentOp) +P = TypeVar("P", bound=DfParentOp) def list_new(builder: DFBuilder, elem_type: ht.Type, args: list[Wire]) -> Wire: @@ -342,6 +348,6 @@ def _list_new_linear(builder: DFBuilder, elem_type: ht.Type, args: list[Wire]) - lst = builder.load(ListVal([], elem_ty=elem_opt_ty)) push_op = list_push(elem_opt_ty) for elem in args: - elem_opt = builder.add_op(ops.Some(elem_type), elem) + elem_opt = builder.add_op(ops.some(elem_type), elem) lst = builder.add_op(push_op, lst, elem_opt) return lst diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/mem.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/mem.py index f0d37640a..a7c2a6fb7 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/mem.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/mem.py @@ -11,6 +11,8 @@ class WithOwnedCompiler(CustomInoutCallCompiler): def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: [val, func] = args - [out, val] = self.builder.add_op(ops.CallIndirect(), func, val) + [out, val] = self.builder.add_op( + (ops.CallIndirect(), self.func.effects), func, val + ) outs = unpack_wire(out, get_type(self.node), self.builder, self.ctx) return CallReturnWires(regular_returns=outs, inout_returns=[val]) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/option.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/option.py index 53158bcd9..1cf74ff2f 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/option.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/option.py @@ -1,9 +1,10 @@ from abc import ABC -from hugr import Wire, ops +from hugr import Wire from hugr import tys as ht from hugr import val as hv +from guppylang_internals.compiler.builder.ops import tag from guppylang_internals.compiler.expr_compiler import unpack_wire from guppylang_internals.definition.custom import ( CustomCallCompiler, @@ -37,7 +38,7 @@ def __init__(self, tag: int): self.tag = tag def compile(self, args: list[Wire]) -> list[Wire]: - return [self.builder.add_op(ops.Tag(self.tag, self.option_ty), *args)] + return [self.builder.add_op(tag(self.tag, self.option_ty), *args)] class OptionTestCompiler(OptionCompiler): @@ -52,7 +53,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: for i in [0, 1]: case = cond.add_case(i) val = hv.TRUE if i == self.tag else hv.FALSE - opt = case.add_op(ops.Tag(i, self.option_ty), *case.inputs()) + opt = case.add_op(tag(i, self.option_ty), *case.inputs()) case.set_outputs(case.load(val), opt) [res, opt] = cond.outputs() return CallReturnWires(regular_returns=[res], inout_returns=[opt]) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/platform.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/platform.py index bef3930f3..a6d9ed3eb 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/platform.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/platform.py @@ -18,6 +18,7 @@ array_to_std_array, ) from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION +from guppylang_internals.tys import Effect from guppylang_internals.tys.arg import Argument, ConstArg from guppylang_internals.tys.builtin import get_element_type from guppylang_internals.tys.const import BoundConstVar, ConstValue @@ -57,7 +58,7 @@ def compile(self, args: list[Wire]) -> list[Wire]: args.append(tys.BoundedNatArg(NumericType.INT_WIDTH)) op = RESULT_EXTENSION.get_op(self.op_name) sig = tys.FunctionType(input=[hugr_ty], output=[]) - self.builder.add_op(op.instantiate(args, sig), value) + self.builder.add_op((op.instantiate(args, sig), [Effect.ANY]), value) return [] @@ -96,7 +97,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: if self.with_int_width: args.append(tys.BoundedNatArg(NumericType.INT_WIDTH)) op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args) - self.builder.add_op(op, arr) + self.builder.add_op((op, [Effect.ANY]), arr) return CallReturnWires([], [out_arr]) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/prelude.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/prelude.py index a82148f45..8eb24061d 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/prelude.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/prelude.py @@ -12,7 +12,12 @@ from hugr import tys as ht from hugr import val as hv -from guppylang_internals.compiler.builder import DFBuilder, FunctionBuilder +from guppylang_internals.compiler.builder import ( + DFBuilder, + FunctionBuilder, + OpWithEffects, + Pure, +) from guppylang_internals.compiler.core import CompilerContext, GlobalConstId from guppylang_internals.definition.custom import ( CustomCallCompiler, @@ -21,6 +26,7 @@ from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.error import InternalGuppyError from guppylang_internals.nodes import AbortKind +from guppylang_internals.tys import Effect if TYPE_CHECKING: from collections.abc import Callable @@ -57,7 +63,7 @@ def __str__(self) -> str: def panic( inputs: list[ht.Type], outputs: list[ht.Type], kind: AbortKind = AbortKind.Panic -) -> ops.ExtOp: +) -> OpWithEffects: """Returns an operation that panics.""" name = "panic" if kind == AbortKind.Panic else "exit" op_def = hugr.std.PRELUDE.get_op(name) @@ -66,15 +72,16 @@ def panic( ht.ListArg([ht.TypeTypeArg(ty) for ty in outputs]), ] sig = ht.FunctionType([error_type(), *inputs], outputs) - return ops.ExtOp(op_def, sig, args) + return (ops.ExtOp(op_def, sig, args), [Effect.ANY]) -def make_error() -> ops.ExtOp: - """Returns an operation that makes an error.""" +def make_error() -> OpWithEffects: + """Returns an operation that makes an error (and returns the error as a value, + does nothing to raise it).""" op_def = hugr.std.PRELUDE.get_op("MakeError") args: list[ht.TypeArg] = [] sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()]) - return ops.ExtOp(op_def, sig, args) + return Pure(ops.ExtOp(op_def, sig, args)) # ------------------------------------------------------ @@ -272,6 +279,7 @@ def unwrap_result( func_call = builder.call( func, either, + effects=[Effect.ANY], # panics instantiation=concrete_ty, type_args=type_args, ) @@ -303,22 +311,25 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: output=[ht.Either([error_type()], self.ty.output)], ) op = self.op(opt_func_type, self.type_args, self.ctx) - either = self.builder.add_op(op, *args) + either = self.builder.add_op((op, self.func.effects), *args) result = unwrap_result(self.builder, self.ctx, either) return CallReturnWires(regular_returns=[result], inout_returns=[]) +def barrier_op(tys: ht.TypeRow) -> OpWithEffects: + """Returns an operation that represents a barrier on the given types.""" + op_def = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier") + args = [ht.ListArg([ht.TypeTypeArg(ty) for ty in tys])] + return Pure(op_def.instantiate(args, ht.FunctionType.endo(tys))) + + class BarrierCompiler(CustomCallCompiler): """Compiler for the `barrier` function.""" def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: tys = [t for arg in args if (t := self.builder.get_wire_type(arg))] - op = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier").instantiate( - [ht.ListArg([ht.TypeTypeArg(ty) for ty in tys])] - ) - - barrier_n = self.builder.add_op(op, *args) + barrier_n = self.builder.add_op(barrier_op(tys), *args) return CallReturnWires( regular_returns=[], inout_returns=[barrier_n[i] for i in range(len(tys))] diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/qsystem.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/qsystem.py index 64a4f21c2..daf2e46d7 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/qsystem.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/qsystem.py @@ -2,6 +2,7 @@ from hugr import tys as ht from hugr.std.int import int_t +from guppylang_internals.compiler.builder import Pure from guppylang_internals.definition.custom import CustomInoutCallCompiler from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.std._internal.compiler.arithmetic import inarrow_s, iwiden_s @@ -19,8 +20,12 @@ class RandomIntCompiler(CustomInoutCallCompiler): def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: [ctx] = args [rnd, ctx] = self.builder.add_op( - external_op("RandomInt", [], ext=QSYSTEM_RANDOM_EXTENSION)( - ht.FunctionType([RNGCONTEXT_T], [int_t(5), RNGCONTEXT_T]), (), self.ctx + Pure( + external_op("RandomInt", [], ext=QSYSTEM_RANDOM_EXTENSION)( + ht.FunctionType([RNGCONTEXT_T], [int_t(5), RNGCONTEXT_T]), + (), + self.ctx, + ) ), ctx, ) @@ -36,10 +41,12 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: self.builder, bound_sum, "bound must be a 32-bit integer" ) [rnd, ctx] = self.builder.add_op( - external_op("RandomIntBounded", [], ext=QSYSTEM_RANDOM_EXTENSION)( - ht.FunctionType([RNGCONTEXT_T, int_t(5)], [int_t(5), RNGCONTEXT_T]), - (), - self.ctx, + Pure( + external_op("RandomIntBounded", [], ext=QSYSTEM_RANDOM_EXTENSION)( + ht.FunctionType([RNGCONTEXT_T, int_t(5)], [int_t(5), RNGCONTEXT_T]), + (), + self.ctx, + ) ), ctx, bound, diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/quantum.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/quantum.py index 29745ae3a..d555559fa 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/quantum.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/quantum.py @@ -9,6 +9,8 @@ from hugr import tys as ht from hugr.std.float import FLOAT_T +from guppylang_internals.compiler.builder import OpWithEffects, Pure +from guppylang_internals.compiler.builder.ops import unpack_tuple from guppylang_internals.definition.custom import CustomInoutCallCompiler from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.std._internal.compiler.tket_exts import ( @@ -17,6 +19,7 @@ QUANTUM_EXTENSION, ROTATION_EXTENSION, ) +from guppylang_internals.tys import Effect # ---------------------------------------------- # --------- tket.* extensions ----------------- @@ -30,10 +33,15 @@ ROTATION_T = ht.ExtType(ROTATION_T_DEF) -def from_halfturns_unchecked() -> ops.ExtOp: - return ops.ExtOp( - ROTATION_EXTENSION.get_op("from_halfturns_unchecked"), - ht.FunctionType([FLOAT_T], [ROTATION_T]), +def from_halfturns_unchecked() -> OpWithEffects: + """Return an operation that converts a float to a rotation, + panicking if the angle is not finite.""" + return ( + ops.ExtOp( + ROTATION_EXTENSION.get_op("from_halfturns_unchecked"), + ht.FunctionType([FLOAT_T], [ROTATION_T]), + ), + [Effect.ANY], ) @@ -79,8 +87,10 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: ) [q] = args [q, bit] = self.builder.add_op( - quantum_op(self.opname, ext=self.ext)( - ht.FunctionType([ht.Qubit], [ht.Qubit, return_ty]), (), self.ctx + Pure( + quantum_op(self.opname, ext=self.ext)( + ht.FunctionType([ht.Qubit], [ht.Qubit, return_ty]), (), self.ctx + ) ), q, ) @@ -97,16 +107,18 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: from guppylang_internals.std._internal.util import quantum_op [*qs, angle] = args - [halfturns] = self.builder.add_op(ops.UnpackTuple([FLOAT_T]), angle) + [halfturns] = self.builder.add_op(unpack_tuple([FLOAT_T]), angle) [rotation] = self.builder.add_op(from_halfturns_unchecked(), halfturns) qs = self.builder.add_op( - quantum_op(self.opname)( - ht.FunctionType( - [ht.Qubit for _ in qs] + [ROTATION_T], [ht.Qubit for _ in qs] - ), - (), - self.ctx, + Pure( + quantum_op(self.opname)( + ht.FunctionType( + [ht.Qubit for _ in qs] + [ROTATION_T], [ht.Qubit for _ in qs] + ), + (), + self.ctx, + ) ), *qs, rotation, diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/wasm.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/wasm.py index 094e6e5fe..1fa85d666 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/wasm.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/wasm.py @@ -1,6 +1,7 @@ from hugr import Wire, ops from hugr import tys as ht +from guppylang_internals.compiler.builder import Pure from guppylang_internals.definition.custom import CustomInoutCallCompiler from guppylang_internals.definition.value import CallReturnWires from guppylang_internals.error import InternalGuppyError @@ -11,6 +12,7 @@ WASM_EXTENSION, ConstWasmModule, ) +from guppylang_internals.tys import Effect from guppylang_internals.tys.builtin import ( wasm_module_name, ) @@ -36,7 +38,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: WASM_EXTENSION.get_op("get_context"), ht.FunctionType([ht.USize()], [ht.Option(ctx_ty)]), ) - node = self.builder.add_op(get_ctx_op, ctx_wire) + node = self.builder.add_op((get_ctx_op, [Effect.ANY]), ctx_wire) opt_w: Wire = node[0] err = "Failed to spawn WASM context" out_node = build_unwrap(self.builder, opt_w, err) @@ -50,7 +52,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: assert len(args) == 1 ctx = args[0] op = WASM_EXTENSION.get_op("dispose_context").instantiate([]) - self.builder.add_op(op, ctx) + self.builder.add_op((op, [Effect.ANY]), ctx) return CallReturnWires(regular_returns=[], inout_returns=[]) @@ -122,7 +124,7 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: ht.FunctionType([module_ty], [func_ty]), ) - wasm_func = self.builder.add_op(wasm_opdef, wasm_module) + wasm_func = self.builder.add_op(Pure(wasm_opdef), wasm_module) # Call the function call_op = WASM_EXTENSION.get_op("call").instantiate( @@ -130,13 +132,13 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [result_ty]), ) - result = self.builder.add_op(call_op, args[0], wasm_func, *args[1:]) + result = self.builder.add_op(Pure(call_op), args[0], wasm_func, *args[1:]) read_opdef = WASM_EXTENSION.get_op("read_result").instantiate( [output_row_arg], ht.FunctionType([result_ty], [ctx_ty, *wasm_sig.output]), ) - data = self.builder.add_op(read_opdef, result) + data = self.builder.add_op(Pure(read_opdef), result) match list(data[:]): case [ctx]: return CallReturnWires(regular_returns=[], inout_returns=[ctx]) diff --git a/guppylang-internals/src/guppylang_internals/tracing/function.py b/guppylang-internals/src/guppylang_internals/tracing/function.py index 44398e1d2..28fef169f 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/function.py +++ b/guppylang-internals/src/guppylang_internals/tracing/function.py @@ -3,8 +3,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar -from hugr import ops - from guppylang_internals.ast_util import AstNode, with_loc, with_type from guppylang_internals.cfg.builder import tmp_vars from guppylang_internals.checker.core import ( @@ -18,6 +16,7 @@ from guppylang_internals.checker.errors.type_errors import TypeMismatchError from guppylang_internals.checker.unitary_checker import BBUnitaryChecker from guppylang_internals.compiler.builder import FunctionBuilder +from guppylang_internals.compiler.builder.ops import unpack_tuple from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.compiler.expr_compiler import ExprCompiler from guppylang_internals.definition.value import CallableDef @@ -138,7 +137,7 @@ def trace_function( out_tys = type_to_row(out_obj._ty) if len(out_tys) > 1: regular_returns: list[Wire] = list( - builder.add_op(ops.UnpackTuple(), out_obj._use_wire(None)).outputs() + builder.add_op(unpack_tuple(), out_obj._use_wire(None)).outputs() ) elif len(out_tys) > 0: regular_returns = [out_obj._use_wire(None)] diff --git a/guppylang-internals/src/guppylang_internals/tracing/unpacking.py b/guppylang-internals/src/guppylang_internals/tracing/unpacking.py index 02f7adc6a..021abc687 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/unpacking.py +++ b/guppylang-internals/src/guppylang_internals/tracing/unpacking.py @@ -1,13 +1,13 @@ from typing import Any, TypeVar -from hugr import ops +from hugr.ops import DfParentOp from guppylang_internals.ast_util import AstNode from guppylang_internals.checker.errors.comptime_errors import ( IllegalComptimeExpressionError, ) from guppylang_internals.checker.expr_checker import python_value_to_guppy_type -from guppylang_internals.compiler.builder import DFBuilder +from guppylang_internals.compiler.builder import DFBuilder, ops from guppylang_internals.compiler.core import CompilerContext from guppylang_internals.compiler.expr_compiler import python_value_to_hugr from guppylang_internals.error import GuppyComptimeError, GuppyError @@ -29,7 +29,7 @@ from guppylang_internals.tys.const import ConstValue from guppylang_internals.tys.ty import EnumType, NoneType, StructType, TupleType -P = TypeVar("P", bound=ops.DfParentOp) +P = TypeVar("P", bound=DfParentOp) def unpack_guppy_object( @@ -49,13 +49,13 @@ def unpack_guppy_object( case NoneType(): return None case TupleType(element_types=tys): - unpack = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)) + unpack = builder.add_op(ops.unpack_tuple(), obj._use_wire(None)) return tuple( unpack_guppy_object(GuppyObject(ty, wire), builder, frozen) for ty, wire in zip(tys, unpack.outputs(), strict=True) ) case StructType() as ty: - unpack = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)) + unpack = builder.add_op(ops.unpack_tuple(), obj._use_wire(None)) field_values = [ unpack_guppy_object(GuppyObject(field.ty, wire), builder, frozen) for field, wire in zip(ty.fields, unpack.outputs(), strict=True) @@ -98,12 +98,14 @@ def guppy_object_from_py( case TracingDefMixin() as defn: return defn.to_guppy_object() case None: - return GuppyObject(NoneType(), builder.add_op(ops.MakeTuple())) + return GuppyObject(NoneType(), builder.add_op(ops.make_tuple())) case tuple(vs): objs = [guppy_object_from_py(v, builder, node, ctx) for v in vs] return GuppyObject( TupleType([obj._ty for obj in objs]), - builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), + builder.add_op( + ops.make_tuple(), *(obj._use_wire(None) for obj in objs) + ), ) case GuppyStructObject(_ty=struct_ty, _field_values=values): wires = [] @@ -117,7 +119,7 @@ def guppy_object_from_py( f"unexpected type. Expected `{f.ty}`, got `{obj._ty}`." ) wires.append(obj._use_wire(None)) - return GuppyObject(struct_ty, builder.add_op(ops.MakeTuple(), *wires)) + return GuppyObject(struct_ty, builder.add_op(ops.make_tuple(), *wires)) case GuppyEnumObject(_ty=enum_ty, _wire=wire): return GuppyObject(enum_ty, wire) case list(vs) if len(vs) > 0: @@ -171,7 +173,7 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DFBuilder) -> bool: case tuple(vs): assert isinstance(obj._ty, TupleType) wire_iterator = builder.add_op( - ops.UnpackTuple(), obj._use_wire(None) + ops.unpack_tuple(), obj._use_wire(None) ).outputs() for v, ty, out_wire in zip( vs, obj._ty.element_types, wire_iterator, strict=True @@ -182,7 +184,7 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DFBuilder) -> bool: case GuppyStructObject(_ty=ty, _field_values=values): assert obj._ty == ty wire_iterator = builder.add_op( - ops.UnpackTuple(), obj._use_wire(None) + ops.unpack_tuple(), obj._use_wire(None) ).outputs() for field, out_wire in zip(ty.fields, wire_iterator, strict=True): v = values[field.name] diff --git a/tests/error/effects_errors/array_read.err b/tests/error/effects_errors/array_read.err new file mode 100644 index 000000000..a7334ac15 --- /dev/null +++ b/tests/error/effects_errors/array_read.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:6:11) + | +4 | @guppy(effects=[]) +5 | def pure_func(arr: array[int, 3]) -> int: +6 | return arr[0] + arr[1] + arr[2] + | ^^^ Callee of type `forall L, n: nat. (array[L, n], int) -> L` + | has effects `[ANY]` not allowed inside `pure_func` + +Note: + | +3 | +4 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/effects_errors/array_read.py b/tests/error/effects_errors/array_read.py new file mode 100644 index 000000000..a829073a1 --- /dev/null +++ b/tests/error/effects_errors/array_read.py @@ -0,0 +1,8 @@ +from guppylang.std.builtins import array +from guppylang.decorator import guppy + +@guppy(effects=[]) +def pure_func(arr: array[int, 3]) -> int: + return arr[0] + arr[1] + arr[2] + +pure_func.compile_function() \ No newline at end of file diff --git a/tests/error/wasm_errors/wasm_effects.err b/tests/error/wasm_errors/wasm_effects.err new file mode 100644 index 000000000..51d79a6e2 --- /dev/null +++ b/tests/error/wasm_errors/wasm_effects.err @@ -0,0 +1,14 @@ +Error: Too many effects (at $FILE:14:10) + | +12 | @guppy(effects=[]) +13 | def main() -> int: +14 | mod = Foo(0) + | ^^^ Call to `Foo` has effects `[ANY]` not allowed inside `main` + +Note: + | +11 | +12 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/wasm_errors/wasm_effects.py b/tests/error/wasm_errors/wasm_effects.py new file mode 100644 index 000000000..4048e172a --- /dev/null +++ b/tests/error/wasm_errors/wasm_effects.py @@ -0,0 +1,19 @@ +from guppylang import guppy +from guppylang_internals.decorator.wasm import wasm, wasm_module +from guppylang.std.quantum import qubit + +from tests.util import get_wasm_file + +@wasm_module(get_wasm_file()) +class Foo: + @wasm + def two(self: "Foo") -> int: ... + +@guppy(effects=[]) +def main() -> int: + mod = Foo(0) + q = mod.two() + mod.discard() + return q + +main.compile() diff --git a/tests/error/wasm_errors/wasm_effects2.err b/tests/error/wasm_errors/wasm_effects2.err new file mode 100644 index 000000000..ea1361d02 --- /dev/null +++ b/tests/error/wasm_errors/wasm_effects2.err @@ -0,0 +1,15 @@ +Error: Too many effects (at $FILE:15:4) + | +13 | @guppy(effects=[]) +14 | def dis(x: Foo @ owned) -> None: +15 | x.discard() + | ^^^^^^^^^ Call to `discard` has effects `[ANY]` not allowed inside + | `dis` + +Note: + | +12 | +13 | @guppy(effects=[]) + | ----------------- Allowed effects declared here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/wasm_errors/wasm_effects2.py b/tests/error/wasm_errors/wasm_effects2.py new file mode 100644 index 000000000..82c02c9fe --- /dev/null +++ b/tests/error/wasm_errors/wasm_effects2.py @@ -0,0 +1,24 @@ +from guppylang import guppy +from guppylang_internals.decorator.wasm import wasm, wasm_module +from guppylang.std.builtins import owned +from guppylang.std.quantum import qubit + +from tests.util import get_wasm_file + +@wasm_module(get_wasm_file()) +class Foo: + @wasm + def two(self: "Foo") -> int: ... + +@guppy(effects=[]) +def dis(x: Foo @ owned) -> None: + x.discard() + +@guppy +def main() -> int: + mod = Foo(0) + q = mod.two() + dis(mod) + return q + +main.compile() diff --git a/tests/integration/test_array_effects.py b/tests/integration/test_array_effects.py new file mode 100644 index 000000000..0823f8e46 --- /dev/null +++ b/tests/integration/test_array_effects.py @@ -0,0 +1,45 @@ +"""Tests of effects annotation.""" + +import pytest + +from guppylang_internals.error import GuppyTypeError + +from guppylang.std.builtins import array +from guppylang.emulator.exceptions import EmulatorError +from guppylang.decorator import guppy, Effect + + +def test_pure_array_new(validate): + @guppy(effects=[]) + def pure_func(x: int) -> array[int, 3]: + return array(x, x, x) + + validate(pure_func.compile_function()) + + +@pytest.mark.parametrize( + ("fx", "err_type", "msg"), + [ + ([], GuppyTypeError, "TooManyEffectsError"), + ([Effect.ANY], EmulatorError, "Array element is already borrowed"), + ], +) +def test_array_read_after_borrow(fx, err_type, msg, run_int_fn): + @guppy.struct + class MyStruct: + i: int + + T = guppy.type_var("T", copyable=False) + + @guppy(effects=fx) + def read(arr: array[T, 3]) -> T: + return arr.take(1) + + @guppy + def main() -> int: + arr = array(MyStruct(1), MyStruct(2), MyStruct(3)) + read(arr) + return arr[1].i + + with pytest.raises(err_type, match=msg): + run_int_fn(main, expected=0xDEADBEEF)