Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 139 additions & 121 deletions examples/qaoa_maxcut_example.ipynb

Large diffs are not rendered by default.

63 changes: 59 additions & 4 deletions guppylang/src/guppylang/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@

import guppylang
from guppylang.emulator import EmulatorBuilder, EmulatorInstance
from guppylang.emulator._args import (
EntrypointArgSpec,
unsupported_arg_reason,
wrap_entrypoint_with_args,
)
from guppylang.emulator.builder import Platform
from guppylang.emulator.exceptions import EmulatorBuildError

Expand Down Expand Up @@ -80,6 +85,13 @@ def input_names(self) -> str:
return ", ".join(f"`{x}`" for x in self.args)


@dataclass(frozen=True)
class UnsupportedEntrypointArgError(Error):
title: ClassVar[str] = "Unsupported entrypoint argument type"
span_label: ClassVar[str] = "{reason}"
reason: str


@dataclass(frozen=True)
class GuppyDefinition(TracingDefMixin):
"""A general Guppy definition."""
Expand Down Expand Up @@ -138,11 +150,17 @@ def emulator(
) -> EmulatorInstance:
"""Compile this function for emulation with the selene-sim emulator.

Calls `compile()` to get the HUGR package and then builds it using the
provided `EmulatorBuilder` configuration or a default one.
Compiles the function to a HUGR package and builds it using the provided
`EmulatorBuilder` configuration or a default one.

See :py:mod:`guppylang.emulator` for more details on the emulator.

If the entrypoint takes parameters, they become *runtime arguments*: the
entrypoint is wrapped so that each argument is read at run time, and values
are supplied to :py:meth:`EmulatorInstance.run` (or
:py:meth:`EmulatorInstance.run_per_shot`). Only `bool`, signed `int`,
`float`, and arrays of those are supported as argument types.


Args:
n_qubits: The number of qubits to allocate for the function. If it is not
Expand All @@ -160,13 +178,22 @@ def emulator(
Returns:
An `EmulatorInstance` that can be used to run the function in an emulator.
"""
mod = self.compile()
mod = self.compile_function()

if libs is not None:
mod = mod.link(*libs)

arg_specs = self._entrypoint_arg_specs()

if builder is None:
builder = EmulatorBuilder().with_platform(platform)

if arg_specs:
from selene_argreader_plugin import ArgReaderPlugin

wrap_entrypoint_with_args(mod, [spec.name for spec in arg_specs])
builder = builder.link_utility(ArgReaderPlugin())

qubits = n_qubits
if (
isinstance(self.wrapped, RawFunctionDef)
Expand All @@ -193,7 +220,35 @@ def emulator(
)
)

return builder.build(mod, n_qubits=qubits)
return builder.build(mod, n_qubits=qubits, arg_specs=arg_specs)

@pretty_errors
def _entrypoint_arg_specs(self) -> tuple[EntrypointArgSpec, ...]:
"""Validate and collect the runtime argument schema of the entrypoint.

Returns an empty tuple if the entrypoint takes no arguments. Raises a
`GuppyError` if any argument has an unsupported type.
"""
compiled_def = ENGINE.compiled.get((self.id, ()))
if not (
isinstance(compiled_def, CompiledCallableDef)
and len(compiled_def.ty.inputs) > 0
):
return ()

defined_at = cast("ast.FunctionDef", compiled_def.defined_at)
names = compiled_def.ty.input_names or []
specs: list[EntrypointArgSpec] = []
for i, inp in enumerate(compiled_def.ty.inputs):
reason = unsupported_arg_reason(inp.ty)
if reason is not None:
raise GuppyError(
UnsupportedEntrypointArgError(
span=to_span(defined_at.args.args[i]), reason=reason
)
)
specs.append(EntrypointArgSpec(name=names[i], ty=inp.ty))
return tuple(specs)

def compile(self) -> Package:
"""
Expand Down
261 changes: 261 additions & 0 deletions guppylang/src/guppylang/emulator/_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
"""Support for runtime arguments to emulator entrypoint functions.

A Guppy execution entrypoint is normally required to take no inputs. To support
*runtime arguments* (e.g. for variational workflows where the same program is run
many times with different parameters), the emulator compiles an entrypoint with
parameters by *wrapping* it:

* the wrapper takes no inputs,
* for each parameter it reads the value at runtime via the ``tket.argreader``
``read_arg`` op (tagged with the parameter name), and
* it calls the original entrypoint with those values.

Argument *values* are supplied at run time through selene's ``ArgProvider`` (see
:meth:`EmulatorInstance.run <guppylang.emulator.EmulatorInstance.run>` and
:meth:`run_per_shot <guppylang.emulator.EmulatorInstance.run_per_shot>`), keyed by
the parameter name.

This is currently an emulator-only (selene) feature: the ``read_arg`` op is only
lowered by the selene compiler, and argument values are provided through selene's
``ArgProvider``.

Only a restricted set of argument types is supported: ``bool``, signed ``int``,
``float``, and (one-dimensional) arrays of those. Unsigned ``nat`` is deliberately
not supported so that a single generic ``read_arg`` op suffices without signedness
ambiguity; take an ``int`` and convert in-script instead.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from guppylang_internals.tys.builtin import (
get_array_length,
get_element_type,
is_array_type,
is_bool_type,
)
from guppylang_internals.tys.const import ConstValue
from guppylang_internals.tys.ty import NumericType

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from guppylang_internals.tys.ty import Type
from hugr.package import Package

#: Human-readable description of the supported entrypoint argument types.
SUPPORTED_ARG_TYPES_MSG = "`bool`, `int`, `float`, and arrays of those"

#: Python value types accepted by selene's argument provider.
ArgValue = int | float | bool | list[int] | list[float] | list[bool]


@dataclass(frozen=True)
class EntrypointArgSpec:
"""A single runtime argument expected by a wrapped entrypoint."""

name: str
ty: Type


class EntrypointArgValueError(ValueError):
"""Raised when runtime argument values don't match the entrypoint signature."""


def _is_supported_scalar(ty: Type) -> bool:
"""Whether ``ty`` is a supported scalar entrypoint argument type."""
if is_bool_type(ty):
return True
if isinstance(ty, NumericType):
return ty.kind in (NumericType.Kind.Int, NumericType.Kind.Float)
return False


def is_supported_arg_type(ty: Type) -> bool:
"""Whether ``ty`` may be used as an entrypoint runtime argument."""
if _is_supported_scalar(ty):
return True
if is_array_type(ty):
return _is_supported_scalar(get_element_type(ty))
return False


def unsupported_arg_reason(ty: Type) -> str | None:
"""Return ``None`` if ``ty`` is a supported argument type, otherwise a
human-readable explanation of why it is not."""
if is_supported_arg_type(ty):
return None
if isinstance(ty, NumericType) and ty.kind is NumericType.Kind.Nat:
return (
"Unsigned `nat` arguments are not supported. "
"Use a signed `int` argument and convert in your program if needed."
)
if is_array_type(ty):
elem = get_element_type(ty)
if is_array_type(elem):
return "Nested arrays are not supported as entrypoint arguments."
return (
f"Arrays of `{elem}` are not supported as entrypoint arguments. "
f"Supported element types are {SUPPORTED_ARG_TYPES_MSG}."
)
return (
f"Type `{ty}` is not supported as an entrypoint argument. "
f"Supported types are {SUPPORTED_ARG_TYPES_MSG}."
)


def _array_length(ty: Type) -> int | None:
"""Return the (concrete) length of an array type, or ``None`` if unknown."""
length = get_array_length(ty)
if isinstance(length, ConstValue) and isinstance(length.value, int):
return length.value
return None


def _value_error(ty: Type, value: object) -> str | None:
"""Return ``None`` if ``value`` is a valid argument of guppy type ``ty``,
otherwise a human-readable reason why it is not."""
if is_bool_type(ty):
return None if isinstance(value, bool) else "expected a `bool`"
if isinstance(ty, NumericType):
if ty.kind is NumericType.Kind.Int:
if isinstance(value, bool):
return "expected an `int`, got a `bool`"
return None if isinstance(value, int) else "expected an `int`"
if ty.kind is NumericType.Kind.Float:
if isinstance(value, bool):
return "expected a `float`, got a `bool`"
return None if isinstance(value, (int, float)) else "expected a `float`"
if is_array_type(ty):
elem = get_element_type(ty)
n = _array_length(ty)
if not isinstance(value, (list, tuple)):
return f"expected an array of length {n}"
if len(value) == 0:
return "arrays must be non-empty"
if n is not None and len(value) != n:
return f"expected an array of length {n}, got {len(value)}"
for item in value:
reason = _value_error(elem, item)
if reason is not None:
return f"array element {reason}"
return None
return f"unsupported argument type `{ty}`"


def _validate_record(
specs: Sequence[EntrypointArgSpec],
record: Mapping[str, object],
*,
shot: int | None = None,
) -> None:
where = f" for shot {shot}" if shot is not None else ""
expected = {spec.name for spec in specs}
given = set(record)
if missing := sorted(expected - given):
raise EntrypointArgValueError(
f"Missing value(s){where} for entrypoint argument(s): "
+ ", ".join(f"`{name}`" for name in missing)
)
if extra := sorted(given - expected):
raise EntrypointArgValueError(
f"Unexpected entrypoint argument(s){where}: "
+ ", ".join(f"`{name}`" for name in extra)
)
for spec in specs:
reason = _value_error(spec.ty, record[spec.name])
if reason is not None:
raise EntrypointArgValueError(
f"Invalid value{where} for entrypoint argument `{spec.name}`: {reason}"
)


def validate_constant_args(
specs: Sequence[EntrypointArgSpec], values: Mapping[str, object]
) -> None:
"""Validate constant argument values against the entrypoint signature."""
_validate_record(specs, values)


def validate_per_shot_args(
specs: Sequence[EntrypointArgSpec],
per_shot: Sequence[Mapping[str, object]],
) -> None:
"""Validate a list of per-shot argument records against the signature."""
if not per_shot:
raise EntrypointArgValueError(
"`run_per_shot` requires at least one shot's arguments."
)
for shot, record in enumerate(per_shot):
_validate_record(specs, record, shot=shot)


def wrap_entrypoint_with_args(package: Package, arg_names: Sequence[str]) -> None:
"""Rewrite the entrypoint of ``package`` so that it takes no inputs.

The original entrypoint ``f(a, b, ...)`` is replaced as the package entrypoint
by a no-input wrapper that reads each argument at runtime (via ``read_arg``,
tagged with the corresponding name from ``arg_names``) and calls ``f``.

Mutates ``package`` in place.
"""
import tket_exts
from guppylang_internals.std._internal.compiler.array import (
standard_array_type,
std_array_to_array,
)
from hugr import Wire, ops
from hugr import tys as ht
from hugr.build.function import Function
from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXT

borrow_array_def = BORROW_ARRAY_EXT.types["borrow_array"]

def read_arg_wire(wrapper: Function, name: str, ty: ht.Type) -> Wire:
"""Read a single runtime argument, bridging array representations.

Entrypoint array parameters are lowered to ``borrow_array``, but the
``read_arg`` extern fills a standard ``array``. For arrays we therefore read
a standard ``array`` and convert it to the ``borrow_array`` the entrypoint
expects (the mirror of how the result compiler converts the other way).
"""
if isinstance(ty, ht.ExtType) and ty.type_def is borrow_array_def:
length_arg, elem_arg = ty.args
assert isinstance(elem_arg, ht.TypeTypeArg)
elem_ty = elem_arg.ty
std_ty = standard_array_type(elem_ty, length_arg)
std_wire = wrapper.add_op(tket_exts.argreader.read_arg(name, std_ty))[0]
return wrapper.add_op(std_array_to_array(elem_ty, length_arg), std_wire)[0]
return wrapper.add_op(tket_exts.argreader.read_arg(name, ty))[0]

for module in package.modules:
entrypoint = module.entrypoint
op = module[entrypoint].op
if not isinstance(op, ops.FuncDefn) or len(op.inputs) == 0:
continue

input_types = list(op.inputs)
output_types = list(op.signature.body.output)
if len(arg_names) != len(input_types):
raise ValueError(
"Mismatch between entrypoint parameter names "
f"({len(arg_names)}) and HUGR inputs ({len(input_types)})."
)

wrapper = Function.new_nested(
ops.FuncDefn("__wrapped_entrypoint", [], visibility="Public"),
module,
module.module_root,
)
arg_wires = [
read_arg_wire(wrapper, name, ty)
for name, ty in zip(arg_names, input_types, strict=True)
]
call_node = wrapper.call(entrypoint, *arg_wires)
wrapper.set_outputs(*(call_node[i] for i in range(len(output_types))))
module.entrypoint = wrapper.parent_node
return

raise ValueError("No entrypoint with input parameters found in package.")
Loading
Loading