diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index 314a486..700c727 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -56,9 +56,8 @@ jobs: - name: Setup dependencies. run: uv sync --python ${{ matrix.python-version }} - # TODO: Fix lints and re-enable - #- name: Type check with mypy - # run: uv run mypy . + - name: Type check with mypy + run: uv run mypy . - name: Check formatting with ruff run: uv run ruff format --check diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index b4a8483..6f9c3d0 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -53,33 +53,19 @@ jobs: sudo make install - name: Get cargo binstall uses: cargo-bins/cargo-binstall@main - - name: Install capnproto-rust plugin - run: cargo binstall capnpc - - name: Regenerate the Rust capnp code - run: | - capnp compile \ - -orust:impl/rs/src \ - --src-prefix=impl \ - impl/capnp/jeff.capnp - - name: Regenerate the C++ capnp code + - uses: extractions/setup-just@v3 + + - name: Regenerate the capnp bindings run: | - patch -p0 < impl/capnp/cpp_namespace.patch - capnp compile \ - -oc++:impl/cpp/src \ - --src-prefix=impl \ - impl/capnp/jeff.capnp - patch -p0 -R < impl/capnp/cpp_namespace.patch - - name: Re-encode the test .jeff files - run: ./examples/encode_examples.sh - - name: Copy the latest capnp schema to the python package - run: cp impl/capnp/jeff.capnp impl/py/src/jeff-format/data/jeff.capnp + just update-capnp + - name: Check if the generated capnproto code is up to date run: | git diff --exit-code \ impl/rs/src/capnp/ \ impl/cpp/src/capnp/ \ examples/ \ - impl/py/src/jeff-format/data/jeff.capnp + impl/py/src/jeff-format/capnp/jeff.capnp if [ $? -ne 0 ]; then echo "The capnp generated code is not up to date" echo "Please run 'just update-capnp' and commit the changes" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62463ca..630de70 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,25 +40,24 @@ repos: - id: ruff-format name: ruff format description: Format python code with `ruff`. - entry: uv run ruff format + entry: uv run ruff format impl/py language: system files: \.py$ pass_filenames: false - id: ruff-check name: ruff description: Check python code with `ruff`. - entry: uv run ruff check --fix --exit-non-zero-on-fix + entry: uv run ruff check --fix --exit-non-zero-on-fix impl/py + language: system + files: \.py$ + pass_filenames: false + - id: mypy-check + name: mypy + description: Check python code with `mypy`. + entry: uv run mypy . language: system files: \.py$ pass_filenames: false - # TODO: Fix lints and re-enable - #- id: mypy-check - # name: mypy - # description: Check python code with `mypy`. - # entry: uv run mypy . - # language: system - # files: \.py$ - # pass_filenames: false - id: cargo-fmt name: cargo format description: Format rust code with `cargo fmt`. diff --git a/impl/py/pyproject.toml b/impl/py/pyproject.toml index a051394..94bb695 100644 --- a/impl/py/pyproject.toml +++ b/impl/py/pyproject.toml @@ -27,13 +27,15 @@ classifiers = [ ] dependencies = [ - # TODO: Temporarily disabled - # "pycapnp ~= 2.0.0" + "pycapnp~=2.0.0", ] [tool.hatch.build.targets.wheel] packages = ["src/jeff"] +[tool.uv.sources] +pycapnp = { git = "https://github.com/mlxd/pycapnp", branch = "mlxd/update_gh_actions" } + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/impl/py/src/jeff/__init__.py b/impl/py/src/jeff/__init__.py index 535acb9..b875a5b 100644 --- a/impl/py/src/jeff/__init__.py +++ b/impl/py/src/jeff/__init__.py @@ -13,1947 +13,163 @@ Disregarding this advice could lead to unexpected behaviour. All classes come with pretty-print string representation. Note that parsing is a non-goal. -""" - -from __future__ import annotations - -import textwrap -from abc import ABC, abstractmethod -from typing import Any, Iterable - -# from .capnp import load_schema - -# TODO: Temporarily disabled -# schema = load_schema() -schema = None - -# TODO: add remaining op instructions -# TODO: add methods to convert read-only data to cached (builder) instances, remove '_update_cache' -# TODO: introduce JeffString to reduce reliance on string table searching? -# TODO: parent field propagation (like '_func', '_parent', etc.) could be improved -# TODO: add metadata support - -######### -# Enums # -######### - -FloatPrecisions = (32, 64) - -Paulis = ("i", "x", "y", "z") - -KnownGates = ( - "gphase", - "i", - "x", - "y", - "z", - "s", - "t", - "r1", - "rx", - "ry", - "rz", - "h", - "u", - "swap", -) - -################ -# Core classes # -################ - - -class _Empty: - """Sentinal value for uninitialized fields.""" - - -class JeffType(ABC): - """Type information for values. This class is immutable. - Some types carry additional data like a bitwidth.""" - - _raw_data: schema.Type = None - - @staticmethod - def from_encoding(type: schema.Type): - cls = { - "qubit": QubitType, - "qureg": QuregType, - "int": IntType, - "intArray": IntArrayType, - "float": FloatType, - "floatArray": FloatArrayType, - }[str(type.which)] - obj = cls.__new__(cls) - obj._raw_data = type - return obj - - # Python integration - - def __eq__(self, other): - if not isinstance(other, type(self)): - return False - - if hasattr(self, "bitwidth"): - return self.bitwidth == other.bitwidth - - return True - - -class QubitType(JeffType): - """Specialization of the JeffType for qubit values.""" - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.qubit = None - self._raw_data = new_data.as_reader() - - # Python integration - - def __str__(self) -> str: - return "qubit" - - -class QuregType(JeffType): - """Specialization of the JeffType for qureg values.""" - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.qureg = None - self._raw_data = new_data.as_reader() - - # Python integration - - def __str__(self) -> str: - return "qureg" - - -class IntType(JeffType): - """Specialization of the JeffType for integer values.""" - - _bitwidth: int = _Empty - - def __init__(self, bitwidth: int): - self._bitwidth = bitwidth - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.int = self.bitwidth - self._raw_data = new_data.as_reader() - - # static fields - - @property - def bitwidth(self) -> int: - if self._bitwidth is not _Empty: - return self._bitwidth - - return self._raw_data.int - - # Python integration - - def __str__(self) -> str: - return f"int{self.bitwidth}" - - -class IntArrayType(JeffType): - """Specialization of the JeffType for integer arrays.""" - - _bitwidth: int = _Empty - - def __init__(self, bitwidth: int): - self._bitwidth = bitwidth - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.intArray = self.bitwidth - self._raw_data = new_data.as_reader() - - # static fields - - @property - def bitwidth(self) -> int: - if self._bitwidth is not _Empty: - return self._bitwidth - - return self._raw_data.intArray - - # Python integration - - def __str__(self) -> str: - return f"int{self._bitwidth}[]" - - -class FloatType(JeffType): - """Specialization of the JeffType for floating point values.""" - - _bitwidth: int = _Empty - - def __init__(self, bitwidth: int): - assert bitwidth in FloatPrecisions - self._bitwidth = bitwidth - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.float = f"float{self.bitwidth}" - self._raw_data = new_data.as_reader() - - # static fields - - @property - def bitwidth(self) -> int: - if self._bitwidth is not _Empty: - return self._bitwidth - - return 32 if self._raw_data.float == "float32" else 64 - - # Python integration - - def __str__(self) -> str: - return f"float{self.bitwidth}" - - -class FloatArrayType(JeffType): - """Specialization of the JeffType for floating point arrays.""" - - _bitwidth: int = _Empty - - def __init__(self, bitwidth: int): - assert bitwidth in FloatPrecisions - self._bitwidth = bitwidth - - def _refresh(self, new_data: schema.Type.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - new_data.floatArray = f"float{self.bitwidth}" - self._raw_data = new_data.as_reader() - - # static fields - - @property - def bitwidth(self) -> int: - if self._bitwidth is not _Empty: - return self._bitwidth - - return 32 if self._raw_data.floatArray == "float32" else 64 - - # Python integration - - def __str__(self) -> str: - return f"float{self.bitwidth}[]" - - -class JeffValue: - """Program values represent dataflow between oprations, and defines the data type used. - This class is immutable, and holds an indentifier for the unique edge in the program. In an - encoded program, the indentifier is the index into the parent function's value table, whereas - during program consrtruction the identifier is the object's instance id. - """ - - _raw_data: schema.Value = None - _func: FunctionDef = None - # The value table index is used in reader mode both for pretty printing and comparing values. - _val_idx: int = None - - # cached attributes - _type: JeffType = _Empty - - def __init__(self, type: JeffType): - self._type = type - - @staticmethod - def from_encoding(idx: int, func: FunctionDef): - obj = JeffValue.__new__(JeffValue) - obj._raw_data = func._raw_data.definition.values[idx] - obj._func = func - obj._val_idx = idx - return obj - - def _refresh(self, new_data: schema.Value.Builder): - """For immutable classes, just write the cached data into the encoding buffer.""" - self.type._refresh(new_data.type) - self._raw_data = new_data.as_reader() - - # static attributes - - @property - def type(self): - if self._type is not _Empty: - return self._type - - return JeffType.from_encoding(self._raw_data.type) - - # convenience methods - - @property - def id(self) -> int: - if self._val_idx is not None: - return self._val_idx - - return id(self) - - # Python integration - - def __str__(self): - return f"%{self.id}:{self.type}" - - def __eq__(self, other): - if not isinstance(other, JeffValue): - return False - - if self._val_idx is not None: - return self._func is other._func and self._val_idx == other._val_idx - - return self is other - - -class JeffOp: - """A generic container for all operations in the program. The common fields include input and - output values, as well as the kind of operation represented. All ops have a main kind - (like QubitOp and IntOp), as well as a subkind (like alloc, add, etc.). Some operations store - additional data as well, which can be primitive types as well extra classes defined in the API. - """ - - _is_dirty: bool - _raw_data: schema.Op = None - _func: FunctionDef = None - - # cached attributes - _kind: str = _Empty - _subkind: str = _Empty - _inputs: list[JeffValue] = _Empty - _outputs: list[JeffValue] = _Empty - _instruction_data: JeffGate | JeffSCF | Any | None = _Empty - - def __init__( - self, - kind: str, - subkind: str, - inputs: list[JeffValue], - outputs: list[JeffValue], - instruction_data=None, - ): - self._kind = kind - self._subkind = subkind - self._inputs = inputs - self._outputs = outputs - if isinstance(instruction_data, (JeffGate, JeffSCF)): - instruction_data._op = self - self._instruction_data = instruction_data - self._mark_dirty() - - @staticmethod - def from_encoding(op: schema.Op, func: FunctionDef): - obj = JeffOp.__new__(JeffOp) - obj._raw_data = op - obj._func = func - obj._mark_clean() - return obj - - @property - def is_dirty(self) -> bool: - """Whether the object has been modified since the last time it was encoded. Also returns - True if the object has never been written out (e.g. after instantiation).""" - return self._is_dirty - - def _mark_clean(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = False - - def _mark_dirty(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = True - if self._func: - self._func._mark_dirty() - - def _refresh(self, new_data: schema.Op.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - - _inputs = self.inputs - inputs = new_data.init("inputs", len(_inputs)) - for i, val in enumerate(_inputs): - inputs[i] = val._val_idx # no need to search the value table - - _ouputs = self.outputs - outptus = new_data.init("outputs", len(_ouputs)) - for i, val in enumerate(_ouputs): - outptus[i] = val._val_idx - - instruction_group = new_data.instruction.init(self.kind) - - _data = self.instruction_data - if isinstance(_data, JeffGate): - gate = instruction_group.init("gate") - _data._refresh(gate, string_table) - elif isinstance(_data, JeffSCF): - _data._refresh(instruction_group, string_table) - else: # TODO: array ops might need different initialization due to lists - setattr(instruction_group, self.subkind, _data) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._instruction_data = self.instruction_data - if isinstance(self._instruction_data, (JeffGate, JeffSCF)): - self._instruction_data._update_cache() - - # cached fields - - @property - def inputs(self) -> list[JeffValue]: - if self._inputs is not _Empty: - return self._inputs - - return [ - JeffValue.from_encoding(inp, self._func) for inp in self._raw_data.inputs - ] - - @inputs.setter - def inputs(self, inputs: list[JeffValue]): - self._inputs = inputs - self._mark_dirty() - - @property - def outputs(self) -> list[JeffValue]: - if self._outputs is not _Empty: - return self._outputs - - return [ - JeffValue.from_encoding(out, self._func) for out in self._raw_data.outputs - ] - - @outputs.setter - def outputs(self, outputs: list[JeffValue]): - self._outputs = outputs - self._mark_dirty() - - @property - def instruction_data(self) -> JeffGate | JeffSCF | Any | None: - """Get instruction details if they exist. Sometimes this is basic data, othertimes another class.""" - if self._instruction_data is not _Empty: - return self._instruction_data - - if self.kind == "qubit" and self.subkind == "gate": - gate = self._raw_data.instruction.qubit.gate - return JeffGate.from_encoding(gate, self) - elif self.kind == "scf": - scf = self._raw_data.instruction.scf - return JeffSCF.from_encoding(scf, self) - - return getattr(getattr(self._raw_data.instruction, self.kind), self.subkind) - - @instruction_data.setter - def instruction_data(self, data): - if isinstance(data, (JeffGate, JeffSCF)): - data._op = self - self._instruction_data = data - - self._mark_dirty() - - # static fields - - @property - def kind(self) -> str: - if self._kind is not _Empty: - return self._kind - - return str(self._raw_data.instruction.which) - - @property - def subkind(self) -> str: - if self._subkind is not _Empty: - return self._subkind - - return str(getattr(self._raw_data.instruction, self.kind).which) - - # convenience methods - - @property - def instruction_name(self) -> str: - return f"{self.kind}.{self.subkind}" - - # Python integrations - - def __str__(self): - string = "" - - if outputs := self.outputs: - string += ", ".join(str(out) for out in outputs) - string += " = " - - string += f"{self.instruction_name} " - - string += ", ".join(str(inp) for inp in self.inputs) - - if (data := self.instruction_data) is not None: - string += f" {data}" - - return string - - -class JeffRegion: - """A region is container for operations, and defines input and output ports. Regions do not - allow value edges across it.""" - - _is_dirty: bool - _raw_data: schema.Region = None - _parent: FunctionDef | JeffSCF = None - - # cached attributes - _sources: list[JeffValue] = _Empty - _targets: list[JeffValue] = _Empty - _operations: list[JeffOp] = _Empty - - def __init__( - self, - sources: list[JeffValue], - targets: list[JeffValue], - operations: list[JeffOp], - ): - self._sources = sources - self._targets = targets - if func := self.parent_func: - for op in operations: - op._func = func - self._operations = operations - self._mark_dirty() - - @staticmethod - def from_encoding(region: schema.Region, parent: FunctionDef | JeffOp): - obj = JeffRegion.__new__(JeffRegion) - obj._raw_data = region - obj._parent = parent - obj._mark_clean() - return obj - - @property - def is_dirty(self) -> bool: - """Whether the object has been modified since the last time it was encoded. Also returns - True if the object has never been written out (e.g. after instantiation).""" - return self._is_dirty - - def _mark_clean(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = False - - def _mark_dirty(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = True - if self._parent: - self._parent._mark_dirty() - - def _refresh(self, new_data: schema.Region.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - - _sources = self.sources - sources = new_data.init("sources", len(_sources)) - for i, val in enumerate(_sources): - sources[i] = val._val_idx # no need to search the value table - - _targets = self.targets - targets = new_data.init("targets", len(_targets)) - for i, val in enumerate(_targets): - targets[i] = val._val_idx - - _operations = self.operations - operations = new_data.init("operations", len(_operations)) - for i, op in enumerate(_operations): - op._refresh(operations[i], string_table) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._sources = self.sources - self._targets = self.targets - self._operations = self.operations - for op in self._operations: - op._update_cache() - - # settable fields - - @property - def sources(self) -> list[JeffValue]: - if self._sources is not _Empty: - return self._sources - - return [ - JeffValue.from_encoding(source, self.parent_func) - for source in self._raw_data.sources - ] - - @sources.setter - def sources(self, sources: list[JeffValue]): - self._sources = sources - self._mark_dirty() - - @property - def targets(self) -> list[JeffValue]: - if self._targets is not _Empty: - return self._targets - - return [ - JeffValue.from_encoding(target, self.parent_func) - for target in self._raw_data.targets - ] - - @targets.setter - def targets(self, targets: list[JeffValue]): - # no need - self._targets = targets - self._mark_dirty() - - @property - def operations(self) -> list[JeffOp]: - if self._operations is not _Empty: - return self._operations - - return [ - JeffOp.from_encoding(op, self.parent_func) - for op in self._raw_data.operations - ] - - @operations.setter - def operations(self, operations: list[JeffOp]): - if func := self.parent_func: - for op in operations: - op._func = func - self._operations = operations - self._mark_dirty() - - # convenience methods - - @property - def parent_func(self) -> FunctionDef | None: - if isinstance(self._parent, FunctionDef): - return self._parent - elif isinstance(self._parent, JeffSCF): - return getattr(self._parent._op, "_func", None) - - return None - - # Python integration - - def __getitem__(self, idx): - if self._operations is not _Empty: - return self._operations[idx] - - return JeffOp.from_encoding(self._raw_data.operations[idx], self.parent_func) - - def __str__(self): - string = "" - - string += " in :" - if sources := self.sources: - string += f" {', '.join(str(src) for src in sources)}" - string += "\n" - - for op in self: - string += f"{textwrap.indent(str(op), ' ')}\n" - - string += " out:" - if targets := self.targets: - string += f" {', '.join(str(tgt) for tgt in targets)}" - string += "" - - return string - - -class JeffFunc(ABC): - """Jeff supports both function definitions (with a body) and declarations (with a signature). - For both the name is stored as a string attribute.""" - - _is_dirty: bool - _raw_data: schema.Function = None - _module: JeffModule = None - - # cached attributes - _name: str = _Empty - - @staticmethod - def from_encoding(func: schema.Function, module: JeffModule): - """Construct a function from encoded data. This provides a zero-copy view of the data.""" - cls = {"definition": FunctionDef, "declaration": FunctionDecl}[str(func.which)] - obj = cls.__new__(cls) - obj._raw_data = func - obj._module = module - obj._mark_clean() - return obj - - @property - def is_dirty(self) -> bool: - """Whether the object has been modified since the last time it was encoded. Also returns - True if the object has never been written out (e.g. after instantiation).""" - return self._is_dirty - - def _mark_clean(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = False - - def _mark_dirty(self): - """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" - self._is_dirty = True - if self._module: - self._module._mark_dirty() - - # settable fields - - @property - def name(self) -> str: - if self._name is not _Empty: - return self._name - - assert not self._module.is_dirty, ( - "The parent module is dirty and no name has been cached. " - "Please call `refresh` on the module to access this attribute." - ) - - idx = self._raw_data.name - return self._module.string_table[idx] - - @name.setter - def name(self, name: str): - self._name = name - self._mark_dirty() - - # convenience methods - - @property - @abstractmethod - def function_type(self) -> tuple[list[JeffType], list[JeffType]]: - """Return the input/output type signature of the function.""" - - # Python integration - - def __str__(self): - input_types, output_types = self.function_type - - string = f"func @{self.name}" - string += f"({', '.join(str(ty) for ty in input_types)})" - string += " -> " - string += f"({', '.join(str(ty) for ty in output_types)})" - - if isinstance(self, FunctionDef): - string += f":\n{self.body}" - else: - assert isinstance(self, FunctionDecl) - string += ";" - - return string - - -class FunctionDef(JeffFunc): - """Function definitions contain a single region determining the call signature of the function. - The encoded object also contains a value table for all typed values in the program.""" - - # cached attributes - _body: JeffRegion = _Empty - - def __init__(self, name: str, body: JeffRegion): - self._name = name - body._parent = self - self._body = body - self._mark_dirty() - - def _refresh(self, new_data: schema.Function.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - definition = new_data.init("definition") - - _values = self._compute_values() - values = definition.init("values", len(_values)) - for i, val in enumerate(_values): - val._refresh(values[i]) - # updating the value index here means we don't need to pass the value table - # down to operations etc, since the index is stored in the value itself - val._val_idx = i - - self.body._refresh(definition.body, string_table) - - # strings are stored as indices in the encoded format - new_data.name = string_table.index(self.name) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._name = self.name - self._body = self.body - self.body._update_cache() - - # settable fields - - @property - def body(self) -> JeffRegion: - if self._body is not _Empty: - return self._body - - return JeffRegion.from_encoding(self._raw_data.definition.body, self) - - @body.setter - def body(self, body: JeffRegion): - for op in body.operations: - op._func = self - body._parent = self - self._body = body - self._mark_dirty() - - # encoding-only fields - - @property - def value_table(self) -> Iterable[schema.Value]: - assert not self.is_dirty, ( - "The FunctionDef contains some cached data, but the value table is only accessible " - "from the encoded format. Please call `refresh` before accessing this attribute." - ) - - return self._raw_data.definition.values - - def _compute_values(self) -> list[JeffValue]: - """Get a fresh value table from the cached program. Requires whole-function traversal.""" - values = [] - regions = [self.body] - - while regions: - current_region = regions.pop(0) - - for val in current_region.sources: - values.append(val) - - for op in current_region: - for val in op.outputs: - values.append(val) - - data = op.instruction_data - if isinstance(data, SwitchSCF): - for branch in data.branches: - regions.append(branch) - if data.default: - regions.append(data.default) - elif isinstance(data, ForSCF): - regions.append(data.body) - elif isinstance(data, (WhileSCF, DoWhileSCF)): - regions.append(data.condition) - regions.append(data.body) - - return values - - # convenience methods - - @property - def sources(self) -> list[JeffValue]: - return self.body.sources - - @property - def targets(self) -> list[JeffValue]: - return self.body.targets - - @property - def function_type(self) -> tuple[list[JeffType], list[JeffType]]: - input_types = [inp.type for inp in self.sources] - output_types = [out.type for out in self.targets] - return input_types, output_types - - # Python integration - - def __getitem__(self, idx): - if self._body is not _Empty: - return self._body[idx] - - return JeffOp.from_encoding( - self._raw_data.definition.body.operations[idx], self - ) - - -class FunctionDecl(JeffFunc): - """Function declarations contain only the input/output type signature.""" - - # cached attributes - _inputs: list[JeffType] = _Empty - _outputs: list[JeffType] = _Empty - def __init__(self, name: str, inputs: list[JeffType], outputs: list[JeffType]): - self._name = name - self._inputs = inputs - self._outputs = outputs - self._mark_dirty() +.. data:: JEFF_VERSION - def _refresh(self, new_data: schema.Function.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - declaration = new_data.init("declaration") + Current version of the *jeff* format. - _inputs = self.inputs - inputs = declaration.init("inputs", len(_inputs)) - for i, input in enumerate(_inputs): - input._refresh(inputs[i]) - - _outputs = self.outputs - outputs = declaration.init("outputs", len(_outputs)) - for i, output in enumerate(_outputs): - output._refresh(outputs[i]) - - # strings are stored as indices in the encoded format - new_data.name = string_table.index(self.name) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._name = self.name - self._inputs = self.inputs - self._outputs = self.outputs - - # settable fields - - @property - def inputs(self) -> list[JeffType]: - if self._inputs is not _Empty: - return self._inputs - - return [ - JeffType.from_encoding(inp.type) - for inp in self._raw_data.declaration.inputs - ] - - @inputs.setter - def inputs(self, inputs: list[JeffType]): - self._inputs = inputs - self._mark_dirty() - - @property - def outputs(self) -> list[JeffType]: - if self._outputs is not _Empty: - return self._outputs - - return [ - JeffType.from_encoding(out.type) - for out in self._raw_data.declaration.outputs - ] - - @outputs.setter - def outputs(self, outputs: list[JeffType]): - self._outputs = outputs - self._mark_dirty() - - # convenience methods - - @property - def function_type(self) -> tuple[list[JeffType], list[JeffType]]: - return self.inputs, self.outputs - - -class JeffModule: - """The module is the root node in the program. It's a container for functions, - as well as certain metadata. The encoded object also stores a string table for all string - attributes in the program.""" - - _is_dirty: bool - _raw_data: schema.Module = None - - # cached attributes - _functions: list[JeffFunc] = _Empty - _entrypoint: int = _Empty - _version: int = _Empty - _tool: str = _Empty - _tool_version: str = _Empty - - def __init__( - self, - functions: list[JeffFunc], - entrypoint: int = 0, - version: int = 0, - tool: str = "", - tool_version: str = "", - ): - """Build a JeffModule from its children fields. The data is cached until `write-out` - is called, upon which the data is encoded in the jeff binary format.""" - for func in functions: - func._module = self - self._functions = functions - self._entrypoint = entrypoint - self._version = version - self._tool = tool - self._tool_version = tool_version - self._mark_dirty() - - @staticmethod - def from_encoding(module: schema.Module): - """Construct a JeffModule from encoded data. This provides a zero-copy view of the data.""" - obj = JeffModule.__new__(JeffModule) - obj._raw_data = module - obj._mark_clean() - return obj - - @property - def is_dirty(self) -> bool: - """Whether the object has been modified since the last time it was encoded. Also returns - True if the object has never been written out (e.g. after instantiation).""" - return self._is_dirty - - def _mark_clean(self): - self._is_dirty = False - - def _mark_dirty(self): - self._is_dirty = True - - def refresh(self): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - if not self.is_dirty: - return - - # Reusing an existing message is a bad idea as any new allocations will leave the old ones - # in the message, bloating its size. - new_data = schema.Module.new_message() - - _strings = self._compute_strings() - strings = new_data.init("strings", len(_strings)) - for i, string in enumerate(_strings): - strings[i] = string - - functions = new_data.init("functions", len(self._functions)) - for i, func in enumerate(self._functions): - func._refresh(functions[i], _strings) - - new_data.entrypoint = self.entrypoint - new_data.version = self.version - new_data.tool = self.tool - new_data.toolVersion = self.tool_version - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def write_out(self, path: str = None): - """Write out the program to file. Only available on the module object as the root node. - Automatically calls `refresh` before writing. - """ - self.refresh() - - with open(path, "wb") as f: - self._raw_data.as_builder().write(f) - - # settable fields - - @property - def functions(self) -> list[JeffFunc]: - """For read-only access, iterate over / index into the module directly.""" - if self._functions is not _Empty: - return self._functions - - return [JeffFunc.from_encoding(func, self) for func in self._raw_data.functions] - - @functions.setter - def functions(self, functions: list[JeffFunc]): - for func in functions: - # "adopting" a read-only object will detach it from its original encoded message, - # so let's load any data associated to it into cache - func._update_cache() - func._module = self - self._functions = functions - self._mark_dirty() - - # encoding-only fields - - @property - def string_table(self) -> Iterable[str]: - assert not self.is_dirty, ( - "The JeffModule contains some cached data, but the string table is only accessible " - "from the encoded format. Please call `refresh` or `write_out` before accessing this " - "attribute." - ) - return self._raw_data.strings - - def _compute_strings(self) -> list[str]: - """Get a fresh string table from the cached program. Requires whole-program traversal.""" - strings = set() - regions = [] - - for func in self._functions: - strings.add(func.name) - regions.append(func.body) - - while regions: - current_region = regions.pop(0) - - for op in current_region: - data = op.instruction_data - if isinstance(data, CustomGate): - strings.add(data.name) - - if isinstance(data, SwitchSCF): - for branch in data.branches: - regions.append(branch) - if data.default: - regions.append(data.default) - elif isinstance(data, ForSCF): - regions.append(data.body) - elif isinstance(data, (WhileSCF, DoWhileSCF)): - regions.append(data.condition) - regions.append(data.body) - - return list(strings) - - # static fields - - @property - def entrypoint(self) -> int: - if self._entrypoint is not _Empty: - return self._entrypoint - - return self._raw_data.entrypoint - - @property - def version(self) -> int: - if self._version is not _Empty: - return self._version - - return self._raw_data.version - - @property - def tool(self) -> str: - if self._version is not _Empty: - return self._tool - - return self._raw_data.tool - - @property - def tool_version(self) -> str: - if self._tool_version is not _Empty: - return self._tool_version - - return self._raw_data.toolVersion +""" - # Python integration +from __future__ import annotations +from typing import Annotated, Type - def __getitem__(self, idx): - if self._functions is not _Empty: - return self._functions[idx] +from jeff.op.qubit.non_unitary import QubitAlloc, QubitFree +from jeff.op.scf import SwitchSCF +from jeff.region import Region - return JeffFunc.from_encoding(self._raw_data.functions[idx], self) +from .module import Module +from .op import JeffOp +from .op.qubit import PPRGate, QubitGate, Pauli +from .type import FloatType, IntType, JeffType, QubitType +from .value import Value - def __str__(self): - string = f"jeff v{self.version}" +from .capnp import schema - if self.tool: - string += f", {self.tool} v{self.tool_version}" - string += "\n\n" - for i, func in enumerate(self): - string += f"{'[entry] ' if i == self.entrypoint else ''}{func}\n" +# This is updated by our release-please workflow, triggered by this +# annotation: x-release-please-version +__version__ = "0.1.0" - return string +# Current version of the *jeff* format. +JEFF_VERSION = Annotated[str, "feet"] +# TODO: add remaining op instructions +# TODO: add metadata support ################ -# Instructions # +# Reading # ################ -class JeffGate(ABC): - """Instruction data for quantum gate operations.""" - - _is_dirty: bool - _raw_data: schema.QubitGate = None - _op: JeffOp = None - - # common cached fields - _num_controls: int = _Empty - _adjoint: bool = _Empty - _power: int = _Empty - - def from_encoding(gate: schema.QubitGate, op: JeffOp): - cls = {"custom": CustomGate, "wellKnown": WellKnowGate, "ppr": PPRGate}[ - str(gate.which) - ] - obj = cls.__new__(cls) - obj._raw_data = gate - obj._op = op - obj._mark_clean() - return obj - - @property - def is_dirty(self): - return self._is_dirty - - def _mark_clean(self): - self._is_dirty = False - - def _mark_dirty(self): - self._is_dirty = True - if self._op: - self._op._mark_dirty() - - def _refresh(self, new_data: schema.QubitGate.Builder): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - - new_data.controlQubits = self.num_controls - new_data.adjoint = self.adjoint - new_data.power = self.power - - # settable fields - - @property - def num_controls(self) -> int: - if self._num_controls is not _Empty: - return self._num_controls - - return self._raw_data.controlQubits - - @num_controls.setter - def num_controls(self, num_controls: int): - self._num_controls = num_controls - self._mark_dirty() - - @property - def adjoint(self) -> bool: - if self._adjoint is not _Empty: - return self._adjoint - - return self._raw_data.adjoint - - @adjoint.setter - def adjoint(self, adjoint: int): - self._adjoint = adjoint - self._mark_dirty() - - @property - def power(self) -> int: - if self._power is not _Empty: - return self._power - - return self._raw_data.power - - @power.setter - def power(self, power: int): - self._power = power - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "" - if num_controls := self.num_controls: - string += f"numControls={num_controls}, " - if self.adjoint: - string += "adjoint, " - if (power := self.power) != 1: - string += f"power={power}, " - return string - - -class WellKnowGate(JeffGate): - """Specialization of gate intruction data for well-known gates. Well-known gates must be one of - the gates defined in the spec. No additional data needs to be specified.""" - - _kind: str = _Empty - - def __init__(self, kind: str, num_controls: int, adjoint: bool, power: int): - assert kind in KnownGates - self._kind = kind - self._num_controls = num_controls - self._adjoint = adjoint - self._power = power - self._mark_dirty() - - def _refresh(self, new_data: schema.QubitGate.Builder, _string_table): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - - new_data.wellKnown = self.kind - super()._refresh(new_data) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - # settable fields - - @property - def kind(self): - if self._kind is not _Empty: - return self._kind - - return str(self._raw_data.wellKnown) - - @kind.setter - def kind(self, kind: str): - assert kind in KnownGates - self._kind = kind - self._mark_dirty() - - # convenience methods - - @property - def num_qubits(self): - match self.kind: - case "gphase": - return 0 - case ( - "i" - | "x" - | "y" - | "z" - | "s" - | "t" - | "r1" - | "rx" - | "ry" - | "rz" - | "h" - | "u" - ): - return 1 - case "swap": - return 2 - - assert False, "unknown gate" - - @property - def num_params(self): - match self.kind: - case "i" | "x" | "y" | "z" | "s" | "t" | "h": - return 0 - case "gphase" | "r1" | "rx" | "ry" | "rz": - return 1 - case "u": - return 3 - - assert False, "unknown gate" - - # Python integration - - def __str__(self): - string = f"({self.kind}, " - string += super().__str__() - string = string[:-2] + ")" - return string - - -class CustomGate(JeffGate): - """Specialization of gate intruction data for custom gates. Custom gates are identified by a - string name, and also have to provide the number of qubits and number float parameters.""" - - _name: str = _Empty - _num_qubits: int = _Empty - _num_params: int = _Empty - - def __init__( - self, - name: str, - num_qubits: int, - num_params: int, - num_controls: int, - adjoint: bool, - power: int, - ): - self._name = name - self._num_qubits = num_qubits - self._num_params = num_params - self._num_controls = num_controls - self._adjoint = adjoint - self._power = power - self._mark_dirty() - - def _refresh(self, new_data: schema.QubitGate.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - custom = new_data.init("custom") - - custom.name = string_table.index(self.name) - custom.numQubits = self.num_qubits - custom.numParams = self.num_params - super()._refresh(new_data) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._name = self.name - - # settable fields - - @property - def name(self) -> str: - if self._name is not _Empty: - return self._name - - assert ( - (func := self._op._func) and (mod := func._module) and not mod.is_dirty - ), ( - "The parent module is not present or dirty, and no name has been cached. " - "Please call `refresh` on the module to access this attribute." - ) - - return self._op._func._module.string_table[self._raw_data.custom.name] - - @name.setter - def name(self, name: str): - self._name = name - self._mark_dirty() - - @property - def num_qubits(self) -> int: - if self._num_qubits is not _Empty: - return self._num_qubits - - return self._raw_data.custom.numQubits - - @num_qubits.setter - def num_qubits(self, num_qubits: int): - self._num_qubits = num_qubits - self._mark_dirty() - - @property - def num_params(self) -> int: - if self._num_params is not _Empty: - return self._num_params - - return self._raw_data.custom.numParams - - @num_params.setter - def num_params(self, num_params: int): - self._num_params = num_params - self._mark_dirty() - - # Python integration - - def __str__(self): - string = f'("{self.name}", ' - string += f"numQubits={self.num_qubits}, " - if numParams := self.num_params: - string += f"numParams={numParams}, " - string += super().__str__() - string = string[:-2] + ")" - return string - - -class PPRGate(JeffGate): - """Specialization of gate intruction data for pauli-product rotation gates. Custom gates are - identified by a string name, and also have to provide the number of qubits and number float - parameters.""" - - _pauli_string: list[str] = _Empty - - def __init__( - self, pauli_string: list[str], num_controls: int, adjoint: bool, power: int - ): - self._pauli_string = pauli_string - self._num_controls = num_controls - self._adjoint = adjoint - self._power = power - self._mark_dirty() - - def _refresh(self, new_data: schema.QubitGate.Builder, _string_table): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - ppr = new_data.init("ppr") - - _pauli_string = self.pauli_string - pauli_string = ppr.init("pauliString", len(_pauli_string)) - for i, pauli in enumerate(_pauli_string): - pauli_string[i] = pauli - super()._refresh(new_data) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - # settable fields - - @property - def pauli_string(self) -> list[str]: - if self._pauli_string is not _Empty: - return self._pauli_string - - return [str(pauli) for pauli in self._raw_data.ppr.pauliString] - - @pauli_string.setter - def pauli_string(self, pauli_string: list[str]): - assert all(pauli in Paulis for pauli in pauli_string) - self._pauli_string = pauli_string - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "(PPR, " - string += f"pauliString={self.pauli_string}, " - string += super().__str__() - string = string[:-2] + ")" - return string - - -class JeffSCF(ABC): - """Instruction data for a structured control-flow (SCF) operations.""" - - _is_dirty: bool - _raw_data: schema.ScfOp = None - _op: JeffOp = None - - def from_encoding(scf: schema.ScfOp, op: JeffOp): - cls = { - "switch": SwitchSCF, - "for": ForSCF, - "while": WhileSCF, - "doWhile": DoWhileSCF, - }[str(scf.which)] - obj = cls.__new__(cls) - obj._raw_data = scf - obj._op = op - obj._mark_clean() - return obj - - @property - def is_dirty(self): - return self._is_dirty - - def _mark_clean(self): - self._is_dirty = False - - def _mark_dirty(self): - self._is_dirty = True - if self._op: - self._op._mark_dirty() - - -class SwitchSCF(JeffSCF): - """Switch-statement specialization of the JeffSCF instruction data class. - Switch operations contain a list of regions that are indexed into by an integer parameter, - as well as an optional default region that is triggered when the index is out of bounds. - All regions must have the same input/output port signature.""" - - _branches: list[JeffRegion] = _Empty - _default: JeffRegion = _Empty - - def __init__(self, branches: list[JeffRegion], default: JeffRegion = None): - for branch in branches: - branch._parent = self - self._branches = branches - if default: - default._parent = self - self._default = default - self._mark_dirty() - - def _refresh(self, new_data: schema.ScfOp.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - switch = new_data.init("switch") - - _branches = self.branches - branches = switch.init("branches", len(_branches)) - for i, branch in enumerate(_branches): - branch._refresh(branches[i], string_table) - - if _default := self.default: - _default._refresh(switch.default, string_table) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._branches = self.branches - self._default = self.default - for branch in self.branches: - branch._update_cache() - if default := self.default: - default._update_cache() - - # settable fields - - @property - def branches(self) -> list[JeffRegion]: - if self._branches is not _Empty: - return self._branches - - return [ - JeffRegion.from_encoding(branch, self) - for branch in self._raw_data.switch.branches - ] - - @branches.setter - def branches(self, branches: list[JeffRegion]): - for branch in branches: - branch._update_cache() - branch._parent = self - self._branches = branches - self._mark_dirty() - - @property - def default(self) -> JeffRegion | None: - if self._default is not _Empty: - return self._default - - if region := self._raw_data.switch.default: - return JeffRegion.from_encoding(region, self) - - @default.setter - def default(self, default: JeffRegion): - default._update_cache() - default._parent = self - self._default = default - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "\n" - - for i, branch in enumerate(self.branches): - string += f" case {i}:\n" - string += f"{textwrap.indent(str(branch), ' ')}" - - if branch := self.default: - string += "\n" - string += " default:\n" - string += f"{textwrap.indent(str(branch), ' ')}" - - return string - - -class ForSCF(JeffSCF): - """For-loop specialization of the JeffSCF instruction data class. - For loop operations contain a single region that represents the loop body. - The loop iterates from start to stop (exclusive) by step, maintaining state from region output - to input ports.""" - - _body: JeffRegion = _Empty - - def __init__(self, body: JeffRegion): - body._parent = self - self._body = body - self._mark_dirty() - - def _refresh(self, new_data: schema.ScfOp.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - # the 'for' member is not its own struct, instead it directly stores the body region - forloop = new_data.init("for") - - _body = self.body - _body._refresh(forloop, string_table) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - self._body = self.body - self._body._update_cache() - - # settable fields - - @property - def body(self) -> JeffRegion: - if self._body is not _Empty: - return self._body - - return JeffRegion.from_encoding(getattr(self._raw_data, "for"), self) - - @body.setter - def body(self, body: JeffRegion): - body._update_cache() - body._parent = self - self._body = body - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "\n" - string += " body:\n" - string += f"{textwrap.indent(str(self.body), ' ')}" - return string - - -class WhileSCF(JeffSCF): - """While-loop specialization of the JeffSCF instruction data class. - While loop operations contain two regions: a condition region and a body region. - The condition region is executed before each iteration and accepts the state as input, but - only produces a bool as output. The body region takes the same state as input and output.""" - - _condition: JeffRegion = _Empty - _body: JeffRegion = _Empty - - def __init__(self, condition: JeffRegion, body: JeffRegion): - condition._parent = self - self._condition = condition - body._parent = self - self._body = body - self._mark_dirty() - - def _refresh(self, new_data: schema.ScfOp.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - whileloop = new_data.init("while") - - _condition = self.condition - _condition._refresh(whileloop.condition, string_table) - - _body = self.body - _body._refresh(whileloop.body, string_table) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - - self._condition = self.condition - self._body = self.body - self._condition._update_cache() - self._body._update_cache() - - # settable fields - - @property - def condition(self) -> JeffRegion: - if self._condition is not _Empty: - return self._condition - - return JeffRegion.from_encoding( - getattr(self._raw_data, "while").condition, self - ) - - @condition.setter - def condition(self, condition: JeffRegion): - condition._update_cache() - condition._parent = self - self._condition = condition - self._mark_dirty() - - @property - def body(self) -> JeffRegion: - if self._body is not _Empty: - return self._body - - return JeffRegion.from_encoding(getattr(self._raw_data, "while").body, self) - - @body.setter - def body(self, body: JeffRegion): - body._update_cache() - body._parent = self - self._body = body - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "\n" - string += " while:\n" - string += f"{textwrap.indent(str(self.condition), ' ')}" - string += " do:\n" - string += f"{textwrap.indent(str(self.body), ' ')}" - return string - - -class DoWhileSCF(JeffSCF): - """Do-while-loop specialization of the JeffSCF instruction data class. - Do-while loop operations contain two regions: a body region and a condition region. - The body is executed first, then the condition is checked. The region sigantures are the same - as for the while loop.""" - - _body: JeffRegion = _Empty - _condition: JeffRegion = _Empty - - def __init__(self, body: JeffRegion, condition: JeffRegion): - body._parent = self - self._body = body - condition._parent = self - self._condition = condition - self._mark_dirty() - - def _refresh(self, new_data: schema.ScfOp.Builder, string_table: list[str]): - """Refresh this object's encoded data with cached modifications. Also refreshes all child - objects. This method guarantees that `is_dirty` is False after invocation. - When is `is_dirty` is already False, this method does nothing. - """ - doWhile = new_data.init("doWhile") - - self.body._refresh(doWhile.body, string_table) - self.condition._refresh(doWhile.condition, string_table) - - self._raw_data = new_data.as_reader() - self._mark_clean() - - def _update_cache(self): - """TESTING ONLY. Update the cached attributes of this object. This effectively transitions - the object from "reader" mode to "writer" mode, e.g. as part of building a new module.""" - - self._body = self.body - self._condition = self.condition - self._body._update_cache() - self._condition._update_cache() - - # settable fields - - @property - def body(self) -> JeffRegion: - if self._body is not _Empty: - return self._body - - return JeffRegion.from_encoding(self._raw_data.doWhile.body, self) - - @body.setter - def body(self, body: JeffRegion): - body._update_cache() - body._parent = self - self._body = body - self._mark_dirty() - - @property - def condition(self) -> JeffRegion: - if self._condition is not _Empty: - return self._condition - - return JeffRegion.from_encoding(self._raw_data.doWhile.condition, self) - - @condition.setter - def condition(self, condition: JeffRegion): - condition._update_cache() - condition._parent = self - self._condition = condition - self._mark_dirty() - - # Python integration - - def __str__(self): - string = "\n" - string += " do:\n" - string += f"{textwrap.indent(str(self.body), ' ')}" - string += " while:\n" - string += f"{textwrap.indent(str(self.condition), ' ')}" - return string - - -################# -# API functions # -################# - -# reading - - -def load_module(path: str): +def load_module(path: str) -> Module: """Load a jeff module from file.""" with open(path, "rb") as f: - return JeffModule.from_encoding(schema.Module.read(f)) + return Module._read_from_buffer(schema.Module.read(f)) -# building +################# +# Building # +################# -def qubit_alloc(): +def qubit_alloc() -> JeffOp: """Single qubit alloc operation.""" - inputs = [] - outputs = [JeffValue(QubitType())] - return JeffOp("qubit", "alloc", inputs, outputs) + inputs: list[Value] = [] + outputs = [Value(QubitType())] + return JeffOp(QubitAlloc(), inputs, outputs) -def qubit_free(qubit: JeffValue): +def qubit_free(qubit: Value) -> JeffOp: """Single qubit free operation.""" inputs = [qubit] - outputs = [] - return JeffOp("qubit", "free", inputs, outputs) + outputs: list[Value] = [] + return JeffOp(QubitFree(), inputs, outputs) def quantum_gate( name: str, - qubits: JeffValue | list[JeffValue], - params: list[float] = None, - control_qubits: list[JeffValue] = None, + qubits: Value | list[Value], + params: list[Value] | None = None, + *, + control_qubits: list[Value] | None = None, adjoint: bool = False, power: int = 1, -): +) -> JeffOp: """Instantiate a well-known or custom gate operation.""" - qubits = [qubits] if isinstance(qubits, JeffValue) else qubits + qubits = [qubits] if isinstance(qubits, Value) else qubits params = params or [] control_qubits = control_qubits or [] - if name in KnownGates: - gate = WellKnowGate(name, len(control_qubits), adjoint, power) - else: - gate = CustomGate( - name, len(qubits), len(params), len(control_qubits), adjoint, power - ) + _check_values(qubits, QubitType, "Qubit") + _check_values(control_qubits, QubitType, "Control qubit") + _check_values(params, FloatType, "Parameter") + + gate = QubitGate.from_gate_name( + name, + num_qubits=len(qubits), + num_params=len(params), + num_controls=len(control_qubits), + adjoint=adjoint, + power=power, + ) qubit_inputs = qubits + control_qubits inputs = qubit_inputs + params - outputs = [JeffValue(QubitType()) for _ in qubit_inputs] - return JeffOp("qubit", "gate", inputs, outputs, instruction_data=gate) + outputs = [Value(QubitType()) for _ in qubit_inputs] + return JeffOp(gate, inputs, outputs) def pauli_rotation( - angle: JeffValue, - pauli_string: str | list[str], - qubits: JeffValue | list[JeffValue], - control_qubits: list[JeffValue] = None, + angle: Value, + pauli_string: Pauli | str | list[Pauli | str], + qubits: Value | list[Value], + *, + control_qubits: list[Value] | None = None, adjoint: bool = False, power: int = 1, -): +) -> JeffOp: """Instantiate a Pauli-product rotation operation.""" - pauli_string = [pauli_string] if isinstance(pauli_string, str) else pauli_string - qubits = [qubits] if isinstance(qubits, JeffValue) else qubits + + if not isinstance(pauli_string, list): + pauli_string = [pauli_string] + + for i, pauli in enumerate(pauli_string): + if isinstance(pauli, str): + pauli_string[i] = Pauli.from_name(pauli) + + qubits = [qubits] if isinstance(qubits, Value) else qubits control_qubits = control_qubits or [] - assert len(pauli_string) == len(qubits) - ppr = PPRGate(pauli_string, len(control_qubits), adjoint, power) + assert len(pauli_string) == len(qubits), ( + f"Pauli string length {len(pauli_string)} must match number of qubits {len(qubits)}" + ) + _check_values(qubits, QubitType, "Qubit") + _check_values(control_qubits, QubitType, "Control qubit") + _check_values(angle, FloatType, "Pauli angle") + + ppr = PPRGate( + pauli_string, num_controls=len(control_qubits), adjoint=adjoint, power=power + ) inputs = qubits + control_qubits + [angle] - outputs = [JeffValue(QubitType()) for _ in inputs[:-1]] - return JeffOp("qubit", "gate", inputs, outputs, instruction_data=ppr) + outputs = [Value(QubitType()) for _ in inputs[:-1]] + return JeffOp(ppr, inputs, outputs) -def bitwise_not(x: JeffValue): - """Instantiate a bitwise NOT operation.""" - inputs = [x] - outputs = [JeffValue(x.type)] - return JeffOp("int", "not", inputs, outputs) +# TODO: Unimplemented +# def bitwise_not(x: Value): +# """Instantiate a bitwise NOT operation.""" +# inputs = [x] +# outputs = [Value(x.type)] +# #return JeffOp("", inputs, outputs) def switch_case( - index: JeffValue, - region_args: list[JeffValue], - branches: list[JeffRegion], - default: JeffRegion = None, -): - """Instantiate a switch-case operation. Cases run from 0 to len(branches)-1.""" + index: Value, + region_args: list[Value], + branches: list[Region], + default: Region | None = None, +) -> JeffOp: + """Instantiate a switch-case operation. Cases run from 0 to len(branches)-1. + + If the value of the index is out of bounds, the default branch is executed. + + :param index: The index value to switch on. + :param region_args: The arguments to pass to the regions. + :param branches: The branches to switch to. + :param default: The default branch to switch to if the index is out of bounds. + """ + _check_values(index, IntType, "Index") + for branch in branches + [default] if default else branches: assert len(branch.sources) == len(branches[0].sources), ( "all branches require the same number of sources" @@ -1973,6 +189,16 @@ def switch_case( scf = SwitchSCF(branches, default) inputs = [index] + region_args - outputs = [JeffValue(val.type) for val in branches[0].targets] + outputs = [Value(val.type) for val in branches[0].targets] + + return JeffOp(scf, inputs, outputs) + - return JeffOp("scf", "switch", inputs, outputs, instruction_data=scf) +def _check_values( + values: Value | list[Value], expected_type: Type[JeffType], name: str +) -> None: + """Check that the values have valid types.""" + values = [values] if isinstance(values, Value) else values + for value in values: + if not isinstance(value.type, expected_type): + raise ValueError(f"{name} {value} must be a {expected_type}") diff --git a/impl/py/src/jeff/capnp.py b/impl/py/src/jeff/capnp.py deleted file mode 100644 index c19ba47..0000000 --- a/impl/py/src/jeff/capnp.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Utility functions for tket extensions.""" - -import os -from pathlib import Path - -from typing import Any - - -def load_schema() -> Any: - import capnp - - capnp.remove_import_hook() - - import jeff - - # capnp warns about this environment variable being set - if "PWD" in os.environ: - del os.environ["PWD"] - - capnp_file = Path(jeff.__file__).joinpath("data", "jeff.capnp") - return capnp.load(capnp_file) diff --git a/impl/py/src/jeff/capnp/__init__.py b/impl/py/src/jeff/capnp/__init__.py new file mode 100644 index 0000000..d7355cf --- /dev/null +++ b/impl/py/src/jeff/capnp/__init__.py @@ -0,0 +1,98 @@ +"""Capnp schema writer and reader definitions.""" + +from __future__ import annotations + +import os +from pathlib import Path + +from typing import Any, Protocol, TypeVar + +from jeff.string_table import StringTable + + +def load_schema() -> Any: + import capnp # type: ignore[import-untyped] + + capnp.remove_import_hook() + + # capnp warns about this environment variable being set + if "PWD" in os.environ: + del os.environ["PWD"] + + capnp_file = Path(__file__).parent.joinpath("jeff.capnp") + return capnp.load(capnp_file) + + +# The capnp buffer reader for a `JeffCapnp` object. +Reader = TypeVar("Reader", contravariant=True) +# The capnp buffer writer for a `JeffCapnp` object. +Builder = TypeVar("Builder", contravariant=True) + + +class CapnpBuffer(Protocol[Reader, Builder]): + """Protocol for objects that have their data backed by a capnp-encoded buffer.""" + + @staticmethod + def _read_from_buffer(reader: Reader) -> CapnpBuffer[Reader, Builder]: + """Create a new object from a capnp-encoded buffer. + + The reader may be stored internally to avoid caching all the data in memory. + See + """ + + def _force_read_all(self) -> None: + """Force the object to read all the data from the buffer into memory, + and drop any internal references to a Reader. + + This is useful when transitioning an object from "reader" mode to "writer" mode. + + This call spreads recursively to all child objects. + """ + + def _write_to_buffer( + self, + new_data: Builder, + string_table: StringTable, + ) -> None: + """Write any cached modifications back to the capnp buffer.""" + + +class LazyUpdate(Protocol): + """Protocol for objects supported by a capnp buffer that write their modifications in a lazy + manner. + + An object may be marked as 'dirty' to indicate that it has been modified since the last time + it was written to a capnp buffer. Use `CapnpBuffer._write_to_buffer` to store the object's + modifications. + + When implementing this protocol, you may need to override the `_mark_dirty` if there are any + parent objects that should be marked as clean. + """ + + # Whether the properties have been modified. + _is_dirty: bool = True + + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + + def _mark_clean(self) -> None: + """Mark the object as clean. + + An object can be marked clean by itself, but a dirty tag is always propagated upward. + """ + self._is_dirty = False + + @property + def is_dirty(self) -> bool: + """Whether the object has been modified since the last time it was encoded. + + Also returns True if the object has never been written out (e.g. after instantiation). + """ + return self._is_dirty + + +schema = load_schema() diff --git a/impl/py/src/jeff/data/jeff.capnp b/impl/py/src/jeff/capnp/jeff.capnp similarity index 100% rename from impl/py/src/jeff/data/jeff.capnp rename to impl/py/src/jeff/capnp/jeff.capnp diff --git a/impl/py/src/jeff/function.py b/impl/py/src/jeff/function.py new file mode 100644 index 0000000..1f035ba --- /dev/null +++ b/impl/py/src/jeff/function.py @@ -0,0 +1,315 @@ +"""Function definitions and declarations.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING +from typing_extensions import override + +from jeff.op import JeffOp +from jeff.region import Region +from jeff.type import JeffType + + +from .capnp import CapnpBuffer, LazyUpdate, schema +from .string_table import StringTable +from .value import Value, ValueTable + +if TYPE_CHECKING: + from jeff.module import Module + + +class Function(ABC, CapnpBuffer[schema.Function, schema.Function.Builder], LazyUpdate): # type: ignore + """Function definition or declaration. + + Jeff supports both function definitions (with a body) and declarations (with a signature). + For both the name is stored as a string attribute. + """ + + # cached attributes + _name: str | None = None + + # The read-only buffer backing this object. + _raw_data: schema.Function = None # type: ignore + + # Reference to the containing module. + _module: Module | None = None + + @override + def _mark_dirty(self) -> None: + """An object can be marked clean by itself, but a dirty tag is always propagated upward.""" + self._is_dirty = True + if self._module: + self._module._mark_dirty() + + @staticmethod + def _read_from_buffer(func: schema.Function) -> Function: # type: ignore + """Construct a function from encoded data. This provides a zero-copy view of the data.""" + obj: Function + match func.which: + case "definition": + definition = func.definition + body = Region._read_from_buffer(definition.body) + values = ValueTable([]) + for i, val in enumerate(definition.values): + val = Value._read_from_buffer(val) + val.id = i + values.add(val) + + obj = FunctionDef.__new__(FunctionDef) + obj._body = body + obj._value_table = values + case "declaration": + declaration = func.declaration + # Declaration I/O values do not have an id, as they are not connected ports. + inputs = [Value._read_from_buffer(val) for val in declaration.inputs] + outputs = [Value._read_from_buffer(val) for val in declaration.outputs] + obj = FunctionDecl.__new__(FunctionDecl) + obj._inputs = inputs + obj._outputs = outputs + case _: + raise ValueError(f"unknown function type: {func.which}") + + obj._name = None + obj._raw_data = func + obj._mark_clean() + return obj + + def _force_read_all(self) -> None: + _ = self.name + self._raw_data = None + self._mark_dirty() + + # settable fields + + @property + def name(self) -> str: + if self._name is None: + if self._module is None: + raise ValueError( + "Name hasn't been assigned yet to function without parent module" + ) + if self._raw_data is None: + raise ValueError("Name hasn't been assigned yet") + + idx = self._raw_data.name + self._name = self._module.string_table[idx] + + return self._name + + @name.setter + def name(self, name: str) -> None: + self._name = name + self._mark_dirty() + + # convenience methods + + @property + @abstractmethod + def function_type(self) -> tuple[list[JeffType], list[JeffType]]: + """Return the input/output type signature of the function.""" + + @property + def is_definition(self) -> bool: + """Returns True if the function is a definition.""" + return isinstance(self, FunctionDef) + + @property + def is_declaration(self) -> bool: + """Returns True if the function is a declaration.""" + return isinstance(self, FunctionDecl) + + # Python integration + + def __str__(self) -> str: + input_types, output_types = self.function_type + + string = f"func @{self.name}" + string += f"({', '.join(str(ty) for ty in input_types)})" + string += " -> " + string += f"({', '.join(str(ty) for ty in output_types)})" + + if isinstance(self, FunctionDef): + string += f":\n{self.body}" + else: + assert isinstance(self, FunctionDecl) + string += ";" + + return string + + +class FunctionDef(Function): + """Function definitions. + + Contains a single region determining the call signature of the function. + The encoded object also contains a value table for all typed values in the program. + """ + + # The dataflow region defining this functions + _body: Region + + # A table of typed values containing the function hyperedge types and metadata + _value_table: ValueTable + + def __init__(self, name: str, body: Region): + value_table = ValueTable._collect_from_region(body) + value_table._func = self + + body._set_parent(self) + + self._name = name + self._body = body + self._value_table = value_table + + self._mark_dirty() + + def _force_read_all(self) -> None: + self._body._force_read_all() + self._value_table._force_read_all() + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.Function.Builder, # type: ignore + string_table: StringTable, + ) -> None: + definition = writer.init("definition") + values = definition.init("values", len(self._value_table)) + self._value_table._write_to_buffer(values, string_table) + self._body._write_to_buffer(definition.body, string_table) + + # strings are stored as indices in the encoded format + writer.name = string_table.index(self.name) + + self._raw_data = writer.as_reader() + self._mark_clean() + + # settable fields + + @property + def body(self) -> Region: + return self._body + + @body.setter + def body(self, body: Region) -> None: + for op in body.operations: + op._func = self + body._parent = self + self._body = body + self._mark_dirty() + + @property + def value_table(self) -> ValueTable: + return self._value_table + + @value_table.setter + def value_table(self, value_table: ValueTable) -> None: + value_table._func = self + self._value_table = value_table + self._mark_dirty() + + # convenience methods + + @property + def sources(self) -> list[Value]: + return self.body.sources + + @property + def targets(self) -> list[Value]: + return self.body.targets + + @property + def function_type(self) -> tuple[list[JeffType], list[JeffType]]: + input_types = [inp.type for inp in self.sources] + output_types = [out.type for out in self.targets] + return input_types, output_types + + def __getitem__(self, idx: int) -> JeffOp: + """Retrieve an operation in the function region by index.""" + return self.body[idx] + + def __len__(self) -> int: + """The number of operations in the function region.""" + return len(self.body) + + +class FunctionDecl(Function): + """Function declarations contain only the input/output type signature.""" + + # cached attributes + _inputs: list[Value] | None = None + _outputs: list[Value] | None = None + + def __init__(self, name: str, inputs: list[Value], outputs: list[Value]): + self._name = name + self._inputs = inputs + self._outputs = outputs + self._mark_dirty() + + def _force_read_all(self) -> None: + _ = self.inputs + _ = self.outputs + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.Function.Builder, # type: ignore + string_table: StringTable, + ) -> None: + writer.name = self.name + + declaration = writer.init("declaration") + + _inputs = self.inputs + inputs = declaration.init("inputs", len(_inputs)) + for i, input in enumerate(_inputs): + input._write_to_buffer(inputs[i], string_table) + + _outputs = self.outputs + outputs = declaration.init("outputs", len(_outputs)) + for i, output in enumerate(_outputs): + output._write_to_buffer(outputs[i], string_table) + + # strings are stored as indices in the encoded format + writer.name = string_table.index(self.name) + + self._raw_data = writer.as_reader() + self._mark_clean() + + # settable fields + + @property + def inputs(self) -> list[Value]: + if self._inputs is None: + self._inputs = [ + Value._read_from_buffer(inp) + for inp in self._raw_data.declaration.inputs + ] + return self._inputs + + @inputs.setter + def inputs(self, inputs: list[Value]) -> None: + self._inputs = inputs + self._mark_dirty() + + @property + def outputs(self) -> list[Value]: + if self._outputs is None: + self._outputs = [ + Value._read_from_buffer(out) + for out in self._raw_data.declaration.outputs + ] + return self._outputs + + @outputs.setter + def outputs(self, outputs: list[Value]) -> None: + self._outputs = outputs + self._mark_dirty() + + # convenience methods + + @property + def function_type(self) -> tuple[list[JeffType], list[JeffType]]: + input_types = [inp.type for inp in self.inputs] + output_types = [out.type for out in self.outputs] + return input_types, output_types diff --git a/impl/py/src/jeff/module.py b/impl/py/src/jeff/module.py new file mode 100644 index 0000000..0ddeddb --- /dev/null +++ b/impl/py/src/jeff/module.py @@ -0,0 +1,246 @@ +"""Top-level module definition.""" + +from __future__ import annotations +from typing import Iterator + +from jeff.function import Function, FunctionDef + +from .capnp import CapnpBuffer, LazyUpdate, schema +from .string_table import StringTable + + +class Module(CapnpBuffer[schema.Module, schema.Module.Builder], LazyUpdate): # type: ignore + """Jeff module. + + The module is the root node in the program. It's a container for functions, + as well as certain metadata. The encoded object also stores a string table + for all string attributes in the program. + + :attr functions: The functions in the module. + :attr entrypoint: The index of the entrypoint function in the `functions` list. + :attr version: The version of the *jeff* format used to encode this module. + :attr tool: The name of the tool that generated this module. + :attr tool_version: The version of the tool that generated this module. + :attr string_table: The string table for all string attributes in the program. + """ + + _raw_data: schema.Module = None # type: ignore + + # Cached list of functions in the module, indexed by their id. + # These are loaded lazily individually. + _functions: dict[int, Function] + _function_count: int + + # Index of the entrypoint function in the `functions` list. + _entrypoint: int + + # Version of the *jeff* format used to encode this module. + _version: int + + # Name of the tool that generated this module. + _tool: str + + # Version of the tool that generated this module. + _tool_version: str + + # String table for all string attributes in the program. + _string_table: StringTable + + def __init__( + self, + functions: list[Function], + *, + entrypoint: int = 0, + version: int = 0, + tool: str | None = None, + tool_version: str | None = None, + ): + self._string_table = StringTable([]) + + if tool is None or tool_version is None: + from jeff import __version__ as jeff_version + + tool = "jeff-py" + tool_version = jeff_version + + for func in functions: + func._module = self + if isinstance(func, FunctionDef): + self._string_table._update_with_function(func) + + self.functions = functions + self._entrypoint = entrypoint + self._version = version + self._tool = tool + self._tool_version = tool_version + self._mark_dirty() + + def refresh(self) -> None: + """Refresh this object's encoded data with cached modifications. + + Also refreshes all child objects. This method guarantees that `is_dirty` + is False after invocation, and When is `is_dirty` is already False, this + method does nothing. + """ + if not self.is_dirty: + return + + # Reusing an existing message is a bad idea as any new allocations will leave the old ones + # in the message, bloating its size. + new_data = schema.Module.new_message() + string_table = self._string_table + self._write_to_buffer(new_data, string_table) + + def write_out(self, path: str) -> None: + """Write out the program to file. Only available on the module object as the root node. + Automatically calls `refresh` before writing. + """ + self.refresh() + + with open(path, "wb") as f: + self._raw_data.as_builder().write(f) + + @staticmethod + def _read_from_buffer(module: schema.Module) -> Module: # type: ignore + """Construct a JeffModule from encoded data. This provides a zero-copy view of the data.""" + obj = Module.__new__(Module) + obj._raw_data = module + obj._function_count = len(module.functions) + obj._entrypoint = module.entrypoint + obj._version = module.version + obj._tool = module.tool + obj._tool_version = module.toolVersion + obj._string_table = StringTable._read_from_buffer(module.strings) + obj._mark_clean() + return obj + + def _force_read_all(self) -> None: + """Force the object to read all the data from the buffer into memory, + and drop any internal references to a Reader. + + This is useful when transitioning an object from "reader" mode to "writer" mode. + + This call spreads recursively to all child objects. + """ + for func in self.functions: + func._force_read_all() + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.Module.Builder, # type: ignore + string_table: StringTable, + ) -> None: + functions = writer.init("functions", len(self._functions)) + for i in range(len(self._functions)): + self[i]._write_to_buffer(functions[i], string_table) + + writer.entrypoint = self.entrypoint + writer.version = self.version + writer.tool = self.tool + writer.toolVersion = self.tool_version + + strings = writer.init("strings", len(string_table)) + string_table._write_to_buffer(strings, string_table) + + self._raw_data = writer.as_reader() + self._mark_clean() + + # settable fields + + @property + def functions(self) -> list[Function]: + """Returns a list of functions in the module.""" + return list(self) + + @functions.setter + def functions(self, functions: list[Function]) -> None: + """Set the functions in the module.""" + for func in functions: + # "adopting" a read-only object will detach it from its original encoded message, + # so let's load any data associated to it into cache + func._force_read_all() + func._module = self + if isinstance(func, FunctionDef): + self._string_table._update_with_function(func) + self._functions = {i: func for i, func in enumerate(functions)} + self._function_count = len(functions) + self._mark_dirty() + + # encoding-only fields + + @property + def string_table(self) -> StringTable: + """The string table for all string attributes in the program.""" + return self._string_table + + # static fields + + @property + def entrypoint(self) -> int: + """The index of the entrypoint function in the `functions` list.""" + return self._entrypoint + + @property + def version(self) -> int: + """The version of the *jeff* format used to encode this module.""" + return self._version + + @property + def tool(self) -> str: + """The name of the tool that generated this module.""" + return self._tool + + @property + def tool_version(self) -> str: + """The version of the tool that generated this module.""" + return self._tool_version + + def __getitem__(self, idx: int) -> Function: + """Returns the function at the given index.""" + if idx < 0 or idx >= self._function_count: + raise IndexError( + f"Index {idx} is out of bounds for module with {self._function_count} functions" + ) + if idx not in self._functions: + if self._raw_data is None: + msg = f"Module is incomplete. function {idx} has not been assigned yet." + raise ValueError(msg) + func = Function._read_from_buffer(self._raw_data.functions[idx]) + func._module = self + self._functions[idx] = func + return self._functions[idx] + + def __setitem__(self, idx: int, func: Function) -> None: + """Set the function at the given index. + + :raises: If idx is equal or larger than `len(self)`. + """ + if idx < 0 or idx >= self._function_count: + raise IndexError( + f"Index {idx} is out of bounds for module with {self._function_count} functions" + ) + func._module = self + self._functions[idx] = func + self._mark_dirty() + + def __len__(self) -> int: + """The number of functions in the module.""" + return self._function_count + + def __iter__(self) -> Iterator[Function]: + for i in range(len(self)): + yield self[i] + + def __str__(self) -> str: + string = f"jeff v{self.version}" + + if self.tool: + string += f", {self.tool} v{self.tool_version}" + string += "\n\n" + + for i, func in enumerate(self): + string += f"{'[entry] ' if i == self.entrypoint else ''}{func}\n" + + return string diff --git a/impl/py/src/jeff/op/__init__.py b/impl/py/src/jeff/op/__init__.py new file mode 100644 index 0000000..accf72e --- /dev/null +++ b/impl/py/src/jeff/op/__init__.py @@ -0,0 +1,217 @@ +"""Jeff operation definitions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from typing_extensions import override + +from jeff.capnp import CapnpBuffer, LazyUpdate, schema +from jeff.op.kind import OpKind, OpType +from jeff.op.qubit import QubitGate +from jeff.op.qubit.non_unitary import NonUnitaryOp +from jeff.op.scf import Scf +from jeff.string_table import StringTable +from jeff.value import Value + + +if TYPE_CHECKING: + from jeff.region import Region + from jeff.function import FunctionDef + + +class JeffOp(CapnpBuffer[schema.Op, schema.Op.Builder], LazyUpdate): # type: ignore + """A generic container for all operations in the program. + + The common fields include input and output values, as well as the kind of + operation represented. All ops have a main kind (like QubitOp and IntOp), as + well as a subkind (like alloc, add, etc.). Some operations store additional + data as well, which can be primitive types as well extra classes defined in + the API. + """ + + # The read-only buffer backing this object. + _raw_data: schema.Op | None = None # type: ignore + + # Reference to the containing function. + _func: FunctionDef | None = None + # Reference to the containing region. + _region: Region | None = None + + # Input values to the operation + _inputs: list[Value] | None = None + # Output values from the operation + _outputs: list[Value] | None = None + # The operation type + _op_type: OpType[Any, Any] + + def __init__( + self, + op_type: OpType[Any, Any], + inputs: list[Value], + outputs: list[Value], + ): + op_type._op = self + self._inputs = inputs + self._outputs = outputs + self._op_type = op_type + self._mark_dirty() + + @override + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._func: + self._func._mark_dirty() + + @staticmethod + def _read_from_buffer(op: schema.Op) -> JeffOp: # type: ignore + obj = JeffOp.__new__(JeffOp) + obj._raw_data = op + + match op.instruction.which: + case "qubit": + qubit = op.instruction.qubit + match qubit.which: + case "gate": + obj._op_type = QubitGate._read_from_buffer(qubit.gate) + case _: + obj._op_type = NonUnitaryOp._read_from_buffer(qubit) + case "scf": + obj._op_type = Scf._read_from_buffer(op.instruction.scf) + case _: + raise ValueError(f"unknown operation type: {op.instruction.which}") + obj._op_type._op = obj + obj._mark_clean() + return obj + + def _force_read_all(self) -> None: + _ = self.inputs + _ = self.outputs + self._op_type._force_read_all() + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.Op.Builder, # type: ignore + string_table: StringTable, + ) -> None: + """Write any cached modifications back to the capnp buffer.""" + + _inputs = self.inputs + inputs = writer.init("inputs", len(_inputs)) + for i, val in enumerate(_inputs): + inputs[i] = val.id # no need to search the value table + + _outputs = self.outputs + outputs = writer.init("outputs", len(_outputs)) + for i, val in enumerate(_outputs): + outputs[i] = val.id + + instruction_group = writer.instruction.init(self.op_type.op_kind.kind) + match self.op_type.op_kind: + case OpKind.QUBIT_GATE: + gate = instruction_group.init("gate") + self.op_type._write_to_buffer(gate, string_table) + case OpKind.SCF: + self.op_type._write_to_buffer(instruction_group, string_table) + case _: + raise ValueError(f"unknown operation type: {self.op_type.op_kind}") + + self._raw_data = writer.as_reader() + self._mark_clean() + + # cached fields + + @property + def inputs(self) -> list[Value]: + """The input values to the operation.""" + if self._inputs is None: + if self._func is None: + raise ValueError( + "Input values haven't been assigned yet to operation without parent function" + ) + if self._raw_data is None: + raise ValueError("Input values haven't been assigned yet") + self._inputs = [] + for input_id in self._raw_data.inputs: + val = self._func.value_table[input_id] + self._inputs.append(val) + return self._inputs + + @inputs.setter + def inputs(self, inputs: list[Value]) -> None: + if self._func is not None: + for val in inputs: + self._func.value_table.add(val) + self._inputs = inputs + self._mark_dirty() + + @property + def outputs(self) -> list[Value]: + if self._outputs is None: + if self._func is None: + raise ValueError( + "Output values haven't been assigned yet to operation without parent function" + ) + if self._raw_data is None: + raise ValueError("Output values haven't been assigned yet") + self._outputs = [] + for output_id in self._raw_data.outputs: + val = self._func.value_table[output_id] + self._outputs.append(val) + return self._outputs + + @outputs.setter + def outputs(self, outputs: list[Value]) -> None: + if self._func is not None: + for val in outputs: + self._func.value_table.add(val) + self._outputs = outputs + self._mark_dirty() + + @property + def op_type(self) -> OpType[Any, Any]: + return self._op_type + + @op_type.setter + def op_type(self, op_type: OpType[Any, Any]) -> None: + op_type._op = self + self._op_type = op_type + self._mark_dirty() + + # static fields + + @property + def kind(self) -> OpKind: + return self.op_type.op_kind + + # convenience methods + + def get_value(self, idx: int) -> Value: + """Returns a value from the function's value table.""" + if self._func is None: + raise ValueError( + "Value haven't been assigned yet to operation without parent function" + ) + return self._func.value_table[idx] + + # Python integrations + + def __str__(self) -> str: + string = "" + + if outputs := self.outputs: + string += ", ".join(str(out) for out in outputs) + string += " = " + + string += f"{self.kind} " + + string += ", ".join(str(inp) for inp in self.inputs) + + if (data := self.kind) is not None: + string += f" {data}" + + return string diff --git a/impl/py/src/jeff/op/kind.py b/impl/py/src/jeff/op/kind.py new file mode 100644 index 0000000..8fd73be --- /dev/null +++ b/impl/py/src/jeff/op/kind.py @@ -0,0 +1,46 @@ +"""Operation kinds.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Protocol + +from jeff.capnp import Builder, CapnpBuffer, Reader + +if TYPE_CHECKING: + from jeff.op import JeffOp + + +class OpKind(Enum): + """Categories of operation types. + + Each of these are represented by different classes in the API. + They roughly correspond to structs in the capnp schema. + """ + + QUBIT = "qubit" + QUBIT_GATE = "qubit.gate" + SCF = "scf" + + @property + def kind(self) -> str: + """The kind of operation.""" + return self.value.split(".")[0] + + @property + def subkind(self) -> str | None: + """The subkind of operation, if any.""" + if "." not in self.value: + return None + return self.value.split(".")[1] + + +class OpType(Protocol, CapnpBuffer[Reader, Builder]): + """A concrete jeff operation type.""" + + # Reference to the containing operation. + _op: JeffOp | None = None + + @property + def op_kind(self) -> OpKind: + """The kind of operation.""" diff --git a/impl/py/src/jeff/op/qubit/__init__.py b/impl/py/src/jeff/op/qubit/__init__.py new file mode 100644 index 0000000..4615576 --- /dev/null +++ b/impl/py/src/jeff/op/qubit/__init__.py @@ -0,0 +1,29 @@ +"""Quantum operations.""" + +from .gate import QubitGate, WellKnownGate, CustomGate, PPRGate, Pauli +from .non_unitary import ( + NonUnitaryOp, + QubitAlloc, + QubitFree, + QubitFreeZero, + QubitMeasure, + QubitMeasureNd, + QubitReset, +) +from .protocol import QubitOp + +__all__ = [ + "CustomGate", + "Pauli", + "NonUnitaryOp", + "PPRGate", + "QubitAlloc", + "QubitFree", + "QubitFreeZero", + "QubitGate", + "QubitMeasure", + "QubitMeasureNd", + "QubitOp", + "QubitReset", + "WellKnownGate", +] diff --git a/impl/py/src/jeff/op/qubit/gate.py b/impl/py/src/jeff/op/qubit/gate.py new file mode 100644 index 0000000..bd11b7e --- /dev/null +++ b/impl/py/src/jeff/op/qubit/gate.py @@ -0,0 +1,506 @@ +"""Qubit gates and operations.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +from typing_extensions import override + +from jeff.op.kind import OpKind + +from jeff.capnp import LazyUpdate, schema +from jeff.op.qubit import QubitOp +from jeff.string_table import StringTable + + +class Pauli(Enum): + I = "i" # noqa: E741 + X = "x" + Y = "y" + Z = "z" + + @staticmethod + def from_name(name: str) -> Pauli: + """Try to match a Pauli name to a known Pauli.""" + match name.lower(): + case "i": + return Pauli.I + case "x": + return Pauli.X + case "y": + return Pauli.Y + case "z": + return Pauli.Z + raise ValueError(f"unknown Pauli: {name}") + + +class KnownGate(Enum): + """A standard quantum gate. + + Well-known gates must be one of the gates defined in the spec. No additional + data needs to be specified.""" + + GPHASE = "gphase" + I = "i" # noqa: E741 + X = "x" + Y = "y" + Z = "z" + S = "s" + T = "t" + RX = "rx" + RY = "ry" + RZ = "rz" + H = "h" + U = "u" + SWAP = "swap" + + # convenience methods + + @property + def num_qubits(self) -> int: + match self: + case KnownGate.GPHASE: + return 0 + case ( + KnownGate.I + | KnownGate.X + | KnownGate.Y + | KnownGate.Z + | KnownGate.S + | KnownGate.T + | KnownGate.RX + | KnownGate.RY + | KnownGate.RZ + | KnownGate.H + | KnownGate.U + ): + return 1 + case KnownGate.SWAP: + return 2 + case _: + raise ValueError(f"unknown gate: {self.value}") + + @property + def num_params(self) -> int: + match self: + case ( + KnownGate.I + | KnownGate.X + | KnownGate.Y + | KnownGate.Z + | KnownGate.S + | KnownGate.T + | KnownGate.H + | KnownGate.SWAP + ): + return 0 + case KnownGate.GPHASE | KnownGate.RX | KnownGate.RY | KnownGate.RZ: + return 1 + case KnownGate.U: + return 3 + case _: + raise ValueError(f"unknown gate: {self.value}") + + @staticmethod + def from_name(name: str) -> KnownGate: + """Try to match a quantum gate name to a known gate. + + :raises ValueError: If the name does not match any known gate. + """ + name = name.lower() + for gate in KnownGate: + if name == gate.value: + return gate + raise ValueError(f"unknown gate: {name}") + + +class QubitGate( + ABC, + QubitOp[schema.QubitGate, schema.QubitGate.Builder], # type: ignore + LazyUpdate, +): + """Instruction data for quantum gate operations.""" + + _num_controls: int = 0 + _adjoint: bool = False + _power: int = 1 + + # The read-only buffer backing this object. + _raw_data: schema.QubitGate | None = None # type: ignore + + @staticmethod + def from_gate_name( + name: str, + *, + num_qubits: int, + num_params: int, + num_controls: int = 0, + adjoint: bool = False, + power: int = 1, + ) -> QubitGate: + """Return a qubit gate from a name, trying to match a well-known names if possible. + + :param name: The name of the gate. + :param num_qubits: The number of qubits the gate acts on. + :param num_params: The number of float parameters to the gate. + :param num_controls: The number of control qubits. + """ + try: + gate = KnownGate.from_name(name) + # Only use the well-known gate if it matches the number of qubits and parameters + if gate.num_qubits == num_qubits and gate.num_params == num_params: + return WellKnownGate( + gate, num_controls=num_controls, adjoint=adjoint, power=power + ) + except ValueError: + pass + + return CustomGate( + name, + num_qubits, + num_params, + num_controls=num_controls, + adjoint=adjoint, + power=power, + ) + + @override + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._op: + self._op._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: schema.QubitGate) -> QubitGate: # type: ignore + match reader.which: + case "custom": + gate = CustomGate._read_from_buffer(reader.custom) + case "wellKnown": + gate = WellKnownGate._read_from_buffer(reader.wellKnown) + case "ppr": + gate = PPRGate._read_from_buffer(reader.ppr) + case _: + raise ValueError(f"unknown gate type: {reader.which}") + gate._num_controls = reader.controlQubits + gate._adjoint = reader.adjoint + gate._power = reader.power + gate._mark_clean() + return gate + + def _force_read_all(self) -> None: + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.QubitGate.Builder, # type: ignore + string_table: StringTable, + ) -> None: + writer.controlQubits = self.num_controls + writer.adjoint = self.adjoint + writer.power = self.power + self._raw_data = writer.as_reader() + self._mark_clean() + + @property + @abstractmethod + def num_qubits(self) -> int: + """Number of qubits the gate acts on, not including control qubits.""" + + @property + @abstractmethod + def num_params(self) -> int: + """Number of float parameters to the gate.""" + + @property + def op_kind(self) -> OpKind: + """The kind of operation.""" + return OpKind.QUBIT_GATE + + @property + @abstractmethod + def qualified_name(self) -> str: + """The name of the gate, including the 'qubit.gate.' prefix.""" + + # settable fields + + @property + def num_controls(self) -> int: + """Number of control qubits.""" + return self._num_controls + + @num_controls.setter + def num_controls(self, num_controls: int) -> None: + """Set the number of control qubits.""" + self._num_controls = num_controls + self._mark_dirty() + + @property + def adjoint(self) -> bool: + """Whether the gate is adjoint.""" + return self._adjoint + + @adjoint.setter + def adjoint(self, adjoint: bool) -> None: + """Set whether the gate is adjoint.""" + self._adjoint = adjoint + self._mark_dirty() + + @property + def power(self) -> int: + """Times the gate is applied.""" + return self._power + + @power.setter + def power(self, power: int) -> None: + """Set the number of times the gate is applied.""" + self._power = power + self._mark_dirty() + + def _str_attributes(self) -> list[str]: + """Helper method used to list the attributes of the gate for the __str__ representation.""" + strings = [] + if num_controls := self.num_controls: + strings.append(f"numControls={num_controls}") + if self.adjoint: + strings.append("adjoint") + if (power := self.power) != 1: + strings.append(f"power={power}") + return strings + + +class WellKnownGate(QubitGate): + """A standard quantum gate. + + Well-known gates must be one of the gates defined in the spec. No additional data needs to be specified. + """ + + _kind: KnownGate + + def __init__( + self, + kind: KnownGate, + *, + num_controls: int = 0, + adjoint: bool = False, + power: int = 1, + ): + self._kind = kind + self._num_controls = num_controls + self._adjoint = adjoint + self._power = power + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.QubitGate.Builder, # type: ignore + string_table: StringTable, + ) -> None: + writer.wellKnown = self.kind + super()._write_to_buffer(writer, string_table) + + @property + def num_qubits(self) -> int: + return self.kind.num_qubits + + @property + def num_params(self) -> int: + return self.kind.num_params + + @property + def qualified_name(self) -> str: + """The name of the gate, including the 'qubit.gate.' prefix.""" + return f"{self.op_kind}.{self.kind.name}" + + @property + def kind(self) -> KnownGate: + """The kind of gate.""" + return self._kind + + @kind.setter + def kind(self, kind: KnownGate) -> None: + """Set the kind of gate.""" + self._kind = kind + self._mark_dirty() + + def __str__(self) -> str: + attrs = [self.qualified_name] + self._str_attributes() + return f"({', '.join(attrs)})" + + +class CustomGate(QubitGate): + """Custom quantum gate. + + Custom gates are identified by a string name, and also have to provide the + number of qubits and float parameters. + """ + + _name: str + _num_qubits: int + _num_params: int + + def __init__( + self, + name: str, + num_qubits: int, + num_params: int, + *, + num_controls: int = 0, + adjoint: bool = False, + power: int = 1, + ): + self._name = name + self._num_qubits = num_qubits + self._num_params = num_params + self._num_controls = num_controls + self._adjoint = adjoint + self._power = power + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.QubitGate.Builder, # type: ignore + string_table: StringTable, + ) -> None: + custom = writer.init("custom") + custom.name = string_table.index(self.name) + custom.numQubits = self.num_qubits + custom.numParams = self.num_params + super()._write_to_buffer(writer, string_table) + + @property + def qualified_name(self) -> str: + """The name of the gate, including the 'qubit.gate.' prefix.""" + return f"{self.op_kind}.{self.name}" + + # settable fields + + @property + def name(self) -> str: + if self._name is None: + assert ( + (func := self._op._func) and (mod := func._module) and not mod.is_dirty + ), ( + "The parent module is not present or dirty, and no name has been cached. " + "Please call `_read_from_buffer` on the module to access this attribute." + ) + + self._name = self._op._func._module.string_table[self._raw_data.custom.name] + + return self._name + + @name.setter + def name(self, name: str) -> None: + self._name = name + self._mark_dirty() + + @property + def num_qubits(self) -> int: + if self._num_qubits is None: + self._num_qubits = self._raw_data.custom.numQubits + return self._num_qubits + + @num_qubits.setter + def num_qubits(self, num_qubits: int) -> None: + self._num_qubits = num_qubits + self._mark_dirty() + + @property + def num_params(self) -> int: + if self._num_params is None: + self._num_params = self._raw_data.custom.numParams + return self._num_params + + @num_params.setter + def num_params(self, num_params: int) -> None: + self._num_params = num_params + self._mark_dirty() + + # Python integration + + def __str__(self) -> str: + strings = [f"{self.qualified_name}", f"numQubits={self.num_qubits}"] + if numParams := self.num_params: + strings += [f"numParams={numParams}"] + strings += self._str_attributes() + return f"({', '.join(strings)})" + + +class PPRGate(QubitGate): + """Pauli-product rotation gate.""" + + _pauli_string: list[Pauli] | None = None + + def __init__( + self, + pauli_string: list[Pauli | str], + *, + num_controls: int = 0, + adjoint: bool = False, + power: int = 1, + ): + self.pauli_string = pauli_string + self._num_controls = num_controls + self._adjoint = adjoint + self._power = power + self._mark_dirty() + + def _force_read_all(self) -> None: + _ = self.pauli_string + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.QubitGate.Builder, # type: ignore + string_table: StringTable, + ) -> None: + ppr = writer.init("ppr") + _pauli_string = self.pauli_string + pauli_string = ppr.init("pauliString", len(_pauli_string)) + for i, pauli in enumerate(_pauli_string): + pauli_string[i] = pauli.value + super()._write_to_buffer(writer, string_table) + + @property + def qualified_name(self) -> str: + """The name of the gate, including the 'qubit.gate.' prefix.""" + return f"{self.op_kind}.ppr" + + # settable fields + + @property + def pauli_string(self) -> list[Pauli]: + if self._pauli_string is None: + if self._raw_data is None: + msg = "Pauli string has not been assigned yet." + raise ValueError(msg) + self._pauli_string = [ + Pauli.from_name(pauli) for pauli in self._raw_data.ppr.pauliString + ] + return self._pauli_string + + @pauli_string.setter + def pauli_string(self, pauli_string: list[Pauli | str]) -> None: + self._pauli_string = [ + Pauli.from_name(p) if isinstance(p, str) else p for p in pauli_string + ] + self._mark_dirty() + + @property + def num_qubits(self) -> int: + return len(self.pauli_string) + + @property + def num_params(self) -> int: + return 1 + + def __str__(self) -> str: + attrs = [ + self.qualified_name, + f"pauliString={self.pauli_string}", + ] + self._str_attributes() + return f"({', '.join(attrs)})" diff --git a/impl/py/src/jeff/op/qubit/non_unitary.py b/impl/py/src/jeff/op/qubit/non_unitary.py new file mode 100644 index 0000000..0786a7d --- /dev/null +++ b/impl/py/src/jeff/op/qubit/non_unitary.py @@ -0,0 +1,178 @@ +"""Non unitary qubit operations""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing_extensions import override + +from jeff.capnp import LazyUpdate, schema +from jeff.string_table import StringTable +from jeff.op.kind import OpKind +from jeff.op.qubit.protocol import QubitOp + + +class NonUnitaryOp( + ABC, + QubitOp[schema.QubitOp, schema.QubitGate.Builder], # type: ignore + LazyUpdate, +): + """A non-unitary quantum operation. + + See QubitGate for unitary operations. + """ + + # The read-only buffer backing this object. + _raw_data: schema.QubitOp | None = None # type: ignore + + @override + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._op: + self._op._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: schema.QubitOp) -> NonUnitaryOp: # type: ignore + ops = [ + QubitAlloc(), + QubitFree(), + QubitFreeZero(), + QubitMeasure(), + QubitMeasureNd(), + QubitReset(), + ] + which = reader.which + for op in ops: + if op.name == which: + return op + else: + raise ValueError(f"unknown gate type: {reader.which}") + + def _force_read_all(self) -> None: + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.QubitGate.Builder, # type: ignore + string_table: StringTable, + ) -> None: + # Setting the op variant to `None` marks it as the current value. + setattr(writer, self.name, None) + self._raw_data = writer.as_reader() + self._mark_clean() + + @property + def op_kind(self) -> OpKind: + """The kind of operation.""" + return OpKind.QUBIT + + @property + @abstractmethod + def name(self) -> str: + """Name of the operation""" + + @property + def qualified_name(self) -> str: + """Full name of the operation, including the "qubit." prefix""" + return f"{self.op_kind}.{self.name}" + + def __str__(self) -> str: + return self.qualified_name + + +class QubitAlloc(NonUnitaryOp): + """Allocates a new qubit in the |0> state. + + Outputs: + - `qubit`: The newly allocated qubit. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "alloc" + + +class QubitFree(NonUnitaryOp): + """Frees a qubit. + + This operation makes no assumptions about the state of the qubit. + + Inputs: + - `qubit`: The qubit to free. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "free" + + +class QubitFreeZero(NonUnitaryOp): + """Frees a qubit in the |0> state. + + This operation can be used to avoid performing resets when it is known + that the qubit has already been reset. It is undefined behavior to free + a qubit that is not in the |0> state. + + Inputs: + - `qubit`: The qubit to free. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "freeZero" + + +class QubitMeasure(NonUnitaryOp): + """Perform a destructive measurement of a qubit in the computational basis. + + Inputs: + - `qubit`: The qubit to measure. + + Outputs: + - `int(1)`: The measurement result. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "measure" + + +class QubitMeasureNd(NonUnitaryOp): + """Perform a non-destructive measurement of a qubit in the computational basis. + + Inputs: + - `qubit`: The qubit to measure. + + Outputs: + - `qubit`: The measured qubit. + - `int(1)`: The measurement result. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "measureNd" + + +class QubitReset(NonUnitaryOp): + """Resets a qubit to the |0> state. + + Inputs: + - `qubit`: The qubit to reset. + + Outputs: + - `qubit`: The reset qubit. + """ + + @property + def name(self) -> str: + """Name of the operation""" + return "reset" diff --git a/impl/py/src/jeff/op/qubit/protocol.py b/impl/py/src/jeff/op/qubit/protocol.py new file mode 100644 index 0000000..8b525c3 --- /dev/null +++ b/impl/py/src/jeff/op/qubit/protocol.py @@ -0,0 +1,10 @@ +"""Common protocol definition for all quantum operations""" + +from typing import Protocol + +from jeff.capnp import Builder, Reader +from jeff.op.kind import OpType + + +class QubitOp(Protocol, OpType[Reader, Builder]): + """A qubit operation.""" diff --git a/impl/py/src/jeff/op/scf.py b/impl/py/src/jeff/op/scf.py new file mode 100644 index 0000000..5a55312 --- /dev/null +++ b/impl/py/src/jeff/op/scf.py @@ -0,0 +1,468 @@ +"""Structured control flow operations.""" + +from __future__ import annotations + +from abc import ABC +import textwrap +from typing import TYPE_CHECKING, Literal + +from jeff.op.kind import OpKind, OpType + +from ..capnp import LazyUpdate, schema +from ..string_table import StringTable + +if TYPE_CHECKING: + from ..region import Region + from ..function import FunctionDef + + +class Scf(ABC, OpType[schema.ScfOp, schema.ScfOp.Builder], LazyUpdate): # type: ignore + """A structured control flow operation.""" + + # The read-only buffer backing this object. + _raw_data: schema.ScfOp | None = None # type: ignore + + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._op: + self._op._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: schema.ScfOp) -> Scf: # type: ignore + scf: Scf + match reader.which: + case "switch": + switch = reader.switch + branches = [ + Region._read_from_buffer(branch) for branch in switch.branches + ] + if switch.default: + default = Region._read_from_buffer(switch.default) + else: + default = None + scf = SwitchSCF(branches, default) + case "for": + for_loop = getattr(reader, "for") + body = Region._read_from_buffer(for_loop) + scf = ForSCF(body) + case "while": + while_loop = getattr(reader, "while") + condition = Region._read_from_buffer(while_loop.condition) + body = Region._read_from_buffer(while_loop.body) + scf = WhileSCF(condition, body) + case "doWhile": + do_while = getattr(reader, "doWhile") + body = Region._read_from_buffer(do_while.body) + condition = Region._read_from_buffer(do_while.condition) + scf = DoWhileSCF(body, condition) + case _: + raise ValueError(f"unknown scf type: {reader.which}") + scf._raw_data = reader + scf._mark_clean() + return scf + + def _force_read_all(self) -> None: + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.ScfOp.Builder, # type: ignore + string_table: StringTable, + ) -> None: + self._raw_data = writer.as_reader() + self._mark_clean() + + @property + def op_kind(self) -> OpKind: + """The kind of operation.""" + return OpKind.SCF + + @property + def parent_func(self) -> FunctionDef | None: + """Returns the parent function to this scf, if any.""" + if self._op is None: + return None + return self._op._func + + +class SwitchSCF(Scf): + """Switch-statement operation. + + Switch operations contain a list of regions that are indexed into by an + integer parameter, as well as an optional default region that is triggered + when the index is out of bounds. + + All regions must have the same input/output port signature. + + :param branches: List of regions to switch between. + :param default: Optional default region to execute when the index is out of bounds. + """ + + _branches: list[Region] | None = None + + # The default branch, if any. + # We use a literal False to indicate that the switch does not have a default branch. + _default: Region | Literal[False] | None = None + + def __init__(self, branches: list[Region], default: Region | None = None): + self.branches = branches + self.default = default + self._mark_dirty() + + def _force_read_all(self) -> None: + branches = self.branches + default = self.default + for branch in branches: + branch._force_read_all() + if default is not None: + default._force_read_all() + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.ScfOp.Builder, # type: ignore + string_table: StringTable, + ) -> None: + switch = writer.init("switch") + + _branches = self.branches + branches = switch.init("branches", len(_branches)) + for i, branch in enumerate(_branches): + branch._write_to_buffer(branches[i], string_table) + + if _default := self.default: + _default._write_to_buffer(switch.default, string_table) + + super()._write_to_buffer(writer, string_table) + + # settable fields + + @property + def branches(self) -> list[Region]: + """Returns the list of branches in this switch operation. + + If the input index is out of bounds, the default branch is executed instead. + """ + if self._branches is None: + if self._raw_data is None: + raise ValueError("SwitchSCF branches haven't been assigned yet") + self._branches = [ + Region._read_from_buffer(branch) + for branch in self._raw_data.switch.branches + ] + for branch in self._branches: + branch._set_parent(self) + + return self._branches + + @branches.setter + def branches(self, branches: list[Region]) -> None: + """Set the list of branches in this switch operation.""" + for branch in branches: + branch._force_read_all() + branch._parent = self + self._branches = branches + self._mark_dirty() + + @property + def default(self) -> Region | None: + """Returns the default branch in this switch operation. + + If the input index is out of range of the branches, the default branch is executed instead. + """ + if self._default is None: + if self._raw_data is None: + raise ValueError("SwitchSCF default branch hasn't been assigned yet") + if region := self._raw_data.switch.default: + self._default = Region._read_from_buffer(region) + self._default._set_parent(self) + else: + self._default = False + + if self._default is None or self._default is False: + return None + return self._default + + @default.setter + def default(self, default: Region | None) -> None: + """Set the default branch in this switch operation.""" + if default is None: + self._default = False + else: + default._force_read_all() + default._set_parent(self) + self._default = default + self._mark_dirty() + + # Python integration + + def __str__(self) -> str: + string = "\n" + + for i, branch in enumerate(self.branches): + string += f" case {i}:\n" + string += f"{textwrap.indent(str(branch), ' ')}" + + if default := self.default: + string += "\n" + string += " default:\n" + string += f"{textwrap.indent(str(default), ' ')}" + + return string + + +class ForSCF(Scf): + """For-loop instruction. + + For loop operations contain a single region that represents the loop body. + The loop iterates from start to stop (exclusive) by step, maintaining state + from region output to input ports. + + :param body: The region to execute as the loop body. + """ + + _body: Region | None = None + + def __init__(self, body: Region): + body._set_parent(self) + self._body = body + self._mark_dirty() + + def _force_read_all(self) -> None: + body = self.body + body._force_read_all() + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.ScfOp.Builder, # type: ignore + string_table: StringTable, + ) -> None: + forloop = writer.init("for") + + _body = self.body + _body._write_to_buffer(forloop, string_table) + + super()._write_to_buffer(writer, string_table) + + # settable fields + + @property + def body(self) -> Region: + """Returns the region to execute as the loop body.""" + if self._body is None: + forloop = getattr(self._raw_data, "for") + self._body = Region._read_from_buffer(forloop) + self._body._set_parent(self) + + return self._body + + @body.setter + def body(self, body: Region) -> None: + """Set the region to execute as the loop body.""" + body._force_read_all() + body._set_parent(self) + self._body = body + self._mark_dirty() + + # Python integration + + def __str__(self) -> str: + string = "\n" + string += " body:\n" + string += f"{textwrap.indent(str(self.body), ' ')}" + return string + + +class WhileSCF(Scf): + """While-loop instruction. + + While loop operations contain two regions: a condition region and a body + region. + + The condition region is executed before each iteration and accepts the state + as input, but only produces a bool as output. The body region takes the same + state as input and output. + + :param condition: The region to execute as the loop condition. + :param body: The region to execute as the loop body. + """ + + _condition: Region | None = None + _body: Region | None = None + + def __init__(self, condition: Region, body: Region): + condition._set_parent(self) + body._set_parent(self) + self._condition = condition + self._body = body + self._mark_dirty() + + def _force_read_all(self) -> None: + condition = self.condition + body = self.body + condition._force_read_all() + body._force_read_all() + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.ScfOp.Builder, # type: ignore + string_table: StringTable, + ) -> None: + while_loop = writer.init("while") + + _condition = self.condition + _condition._write_to_buffer(while_loop.condition, string_table) + + _body = self.body + _body._write_to_buffer(while_loop.body, string_table) + + super()._write_to_buffer(writer, string_table) + + # settable fields + + @property + def condition(self) -> Region: + """Returns the region to execute as the loop condition.""" + if self._condition is None: + while_loop = getattr(self._raw_data, "while") + self._condition = Region._read_from_buffer(while_loop.condition) + self._condition._set_parent(self) + + return self._condition + + @condition.setter + def condition(self, condition: Region) -> None: + """Set the region to execute as the loop condition.""" + condition._force_read_all() + condition._set_parent(self) + self._condition = condition + self._mark_dirty() + + @property + def body(self) -> Region: + """Returns the region to execute as the loop body.""" + if self._body is None: + while_loop = getattr(self._raw_data, "while") + self._body = Region._read_from_buffer(while_loop.body) + self._body._set_parent(self) + + return self._body + + @body.setter + def body(self, body: Region) -> None: + """Set the region to execute as the loop body.""" + body._force_read_all() + body._set_parent(self) + self._body = body + self._mark_dirty() + + # Python integration + + def __str__(self) -> str: + string = "\n" + string += " while:\n" + string += f"{textwrap.indent(str(self.condition), ' ')}" + string += " do:\n" + string += f"{textwrap.indent(str(self.body), ' ')}" + return string + + +class DoWhileSCF(Scf): + """Do-while-loop instruction. + + Do-while loop operations contain two regions: a body region and a condition + region. + + The body is executed first, then the condition is checked. The region + signatures are the same as for the while loop. + + :param body: The region to execute as the loop body. + :param condition: The region to execute as the loop condition. + """ + + _body: Region | None = None + _condition: Region | None = None + + def __init__(self, body: Region, condition: Region): + body._set_parent(self) + condition._set_parent(self) + self._body = body + self._condition = condition + self._mark_dirty() + + def _force_read_all(self) -> None: + body = self.body + condition = self.condition + body._force_read_all() + condition._force_read_all() + super()._force_read_all() + + def _write_to_buffer( + self, + writer: schema.ScfOp.Builder, # type: ignore + string_table: StringTable, + ) -> None: + do_while = writer.init("doWhile") + + _body = self.body + _body._write_to_buffer(do_while.body, string_table) + + _condition = self.condition + _condition._write_to_buffer(do_while.condition, string_table) + + super()._write_to_buffer(writer, string_table) + + # settable fields + + @property + def body(self) -> Region: + """Returns the region to execute as the loop body.""" + if self._body is None: + do_while = getattr(self._raw_data, "doWhile") + self._body = Region._read_from_buffer(do_while.body) + self._body._set_parent(self) + + return self._body + + @body.setter + def body(self, body: Region) -> None: + """Set the region to execute as the loop body.""" + body._force_read_all() + body._set_parent(self) + self._body = body + self._mark_dirty() + + @property + def condition(self) -> Region: + """Returns the region to execute as the loop condition.""" + if self._condition is None: + do_while = getattr(self._raw_data, "doWhile") + self._condition = Region._read_from_buffer(do_while.condition) + self._condition._set_parent(self) + + return self._condition + + @condition.setter + def condition(self, condition: Region) -> None: + """Set the region to execute as the loop condition.""" + condition._force_read_all() + condition._set_parent(self) + self._condition = condition + self._mark_dirty() + + # Python integration + + def __str__(self) -> str: + string = "\n" + string += " do:\n" + string += f"{textwrap.indent(str(self.body), ' ')}" + string += " while:\n" + string += f"{textwrap.indent(str(self.condition), ' ')}" + return string diff --git a/impl/py/src/jeff/region.py b/impl/py/src/jeff/region.py new file mode 100644 index 0000000..338d0d7 --- /dev/null +++ b/impl/py/src/jeff/region.py @@ -0,0 +1,281 @@ +"""Dataflow regions.""" + +from __future__ import annotations + +from collections import deque +import textwrap +from typing import TYPE_CHECKING, Iterator +from typing_extensions import override + +from .capnp import CapnpBuffer, LazyUpdate, schema +from .string_table import StringTable +from .value import Value + +if TYPE_CHECKING: + from .op import JeffOp + from .op.scf import Scf + from .function import FunctionDef + + +class Region(CapnpBuffer[schema.Region, schema.Region.Builder], LazyUpdate): # type: ignore + """A region is container for operations, and defines input and output ports. Regions do not + allow value edges across it.""" + + # cached attributes + _sources: list[Value] | None = None + _targets: list[Value] | None = None + + # The list of operations in this region, indexed by their id. + # We lazily load each one individually. + _operations: dict[int, JeffOp] + _operation_count: int + + # The read-only buffer backing this object. + _raw_data: schema.Region | None = None # type: ignore + + # Reference to the containing function or scf. + _parent: FunctionDef | Scf | None = None + + def __init__( + self, + sources: list[Value], + targets: list[Value], + operations: list[JeffOp], + ): + self._sources = sources + self._targets = targets + self.operations = operations + self._mark_dirty() + + @override + def _mark_dirty(self) -> None: + self._is_dirty = True + if self._parent: + self._parent._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: schema.Region) -> Region: # type: ignore + region = Region.__new__(Region) + region._raw_data = reader + region._operation_count = len(reader.operations) + region._mark_clean() + return region + + def _force_read_all(self) -> None: + _ = self.sources + _ = self.targets + _ = self.operations + for op in self._operations.values(): + op._force_read_all() + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: schema.Region.Builder, # type: ignore + string_table: StringTable, + ) -> None: + _sources = self.sources + sources = writer.init("sources", len(_sources)) + for i, val in enumerate(_sources): + sources[i] = val.id + + _targets = self.targets + targets = writer.init("targets", len(_targets)) + for i, val in enumerate(_targets): + targets[i] = val.id + + ops_writer = writer.init("operations", self._operation_count) + for i in range(self._operation_count): + op = self[i] + op._write_to_buffer(ops_writer[i], string_table) + + self._raw_data = writer.as_reader() + self._mark_clean() + + # settable fields + + @property + def sources(self) -> list[Value]: + if self._sources is None: + if self.parent_func is None: + raise ValueError( + "Source values haven't been assigned yet to region without parent function" + ) + if self._raw_data is None: + raise ValueError("Source values haven't been assigned yet") + self._sources = [] + for source_id in self._raw_data.sources: + val = self.parent_func.value_table[source_id] + self._sources.append(val) + return self._sources + + @sources.setter + def sources(self, sources: list[Value]) -> None: + if self.parent_func is not None: + for val in sources: + self.parent_func.value_table.add(val) + self._sources = sources + self._mark_dirty() + + @property + def targets(self) -> list[Value]: + if self._targets is None: + if self.parent_func is None: + raise ValueError( + "Target values haven't been assigned yet to region without parent function" + ) + if self._raw_data is None: + raise ValueError("Target values haven't been assigned yet") + self._targets = [] + for target_id in self._raw_data.targets: + val = self.parent_func.value_table[target_id] + self._targets.append(val) + return self._targets + + @targets.setter + def targets(self, targets: list[Value]) -> None: + if self.parent_func is not None: + for val in targets: + self.parent_func.value_table.add(val) + self._targets = targets + self._mark_dirty() + + @property + def operations(self) -> list[JeffOp]: + """A copy of the list of operations in the region""" + return list(self) + + @operations.setter + def operations(self, operations: list[JeffOp]) -> None: + if (func := self.parent_func) is not None: + for op in operations: + op._func = func + self._operations = {i: op for i, op in enumerate(operations)} + self._operation_count = len(operations) + self._mark_dirty() + + def append_op(self, op: JeffOp) -> int: + """Append an operation to the region. + + :returns: The index of the new operation. + """ + if (func := self.parent_func) is not None: + op._func = func + idx = self._operation_count + self._operations[idx] = op + self._operation_count += 1 + self._mark_dirty() + return idx + + # convenience methods + + @property + def parent_func(self) -> FunctionDef | None: + """Returns the parent function to this region, if any.""" + match self._parent: + case FunctionDef(): + return self._parent + case Scf(): + return self._parent.parent_func + case _: + return None + + def _set_parent(self, parent: FunctionDef | Scf) -> None: + """Set the parent container of this region. + + This may be either a function if this is a top-level region in the definition, + or a scf if this is region is nested. + """ + self._parent = parent + if (func := self.parent_func) is not None: + for op in self._operations.values(): + op._func = func + for val in self.sources + self.targets: + func.value_table.add(val) + + def subregions_bfs(self) -> Iterator[Region]: + """Returns an iterator over all the subregions in this region, in breadth-first order. + + The iterator returns the region itself first, then all its subregions, then all their + subregions, and so on. + """ + from jeff.op.scf import SwitchSCF, ForSCF, WhileSCF, DoWhileSCF + + queue = deque([self]) + while region := queue.popleft(): + yield region + + for op in region: + match op.op_type: + case SwitchSCF(): + for branch in op.op_type.branches: + queue.append(branch) + if op.op_type.default: + queue.append(op.op_type.default) + case ForSCF(): + queue.append(op.op_type.body) + case WhileSCF() | DoWhileSCF(): + queue.append(op.op_type.condition) + queue.append(op.op_type.body) + case _: + pass + + # Python integration + + def __getitem__(self, idx: int) -> JeffOp: + """Retrieve an operation in the region by index.""" + if idx < 0 or idx >= self._operation_count: + raise IndexError( + f"Index {idx} is out of bounds for region with {self._operation_count} operations" + ) + if idx not in self._operations: + if self._raw_data is None: + msg = ( + f"Region is incomplete. operation {idx} has not been assigned yet." + ) + raise ValueError(msg) + op = JeffOp._read_from_buffer(self._raw_data.operations[idx]) + if (func := self.parent_func) is not None: + op._func = func + self._operations[idx] = op + return self._operations[idx] + + def __setitem__(self, idx: int, op: JeffOp) -> None: + """Set the value of an operation in the function + + :raises: If idx is equal or larger than `len(self)`. + """ + if idx >= self._operation_count: + msg = f"Index {idx} is out of bounds for region with {self._operation_count} operations" + raise IndexError(msg) + if (func := self.parent_func) is not None: + op._func = func + self._operations[idx] = op + self._mark_dirty() + + def __len__(self) -> int: + """The number of operations in the region.""" + return self._operation_count + + def __iter__(self) -> Iterator[JeffOp]: + for i in range(len(self)): + yield self[i] + + def __str__(self) -> str: + string = "" + + string += " in :" + if sources := self.sources: + string += f" {', '.join(str(src) for src in sources)}" + string += "\n" + + for op in self: + string += f"{textwrap.indent(str(op), ' ')}\n" + + string += " out:" + if targets := self.targets: + string += f" {', '.join(str(tgt) for tgt in targets)}" + string += "" + + return string diff --git a/impl/py/src/jeff/string_table.py b/impl/py/src/jeff/string_table.py new file mode 100644 index 0000000..c79ab38 --- /dev/null +++ b/impl/py/src/jeff/string_table.py @@ -0,0 +1,126 @@ +"""An indexed string table.""" + +from __future__ import annotations +from typing import Any, TYPE_CHECKING +from typing_extensions import override + +from jeff.capnp import CapnpBuffer, LazyUpdate + +if TYPE_CHECKING: + from jeff.module import Module + from jeff.function import FunctionDef + + +# TODO: What's the correct type for capnp list[str] reader and writers? +class StringTable(CapnpBuffer[Any, Any], LazyUpdate): + """An indexed string table. + + Lazy-loaded from a capnp string list buffer. + """ + + # Cached sparse list of strings. + # This may contain holes, but all indices will be below `self._len` + _string_table: dict[int, str] + # Reverse mapping from string to index + _reverse_table: dict[str, int] + # The number of elements in the table + _len = 0 + + # The buffer backing this list + _raw_data: Any | None = None + + # A reference to the module defining this table. + _module: Module | None = None + + def __init__(self, string_table: list[str]): + self._string_table = {i: s for i, s in enumerate(string_table)} + self._reverse_table = {s: i for i, s in enumerate(string_table)} + self._len = len(string_table) + self._mark_dirty() + + def _update_with_function(self, function: FunctionDef) -> None: + """Collect all the strings used in a function and update the table. + + Traverses the function operations and and its operations collecting any value defined in it, + adds them to a value table and sets the ids for each value. + """ + from jeff.op.qubit import CustomGate + + self.insert(function.name) + + for region in function.body.subregions_bfs(): + for op in region: + if isinstance(op.op_type, CustomGate): + self.insert(op.op_type.name) + + # TODO: Add metadata keys too, once we support them + + @override + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._module: + self._module._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: Any) -> StringTable: + # Do not read the strings eagerly, only fetch them when needed. + table = StringTable([]) + table._raw_data = reader + table._len = reader.len + table._mark_clean() + return table + + def _force_read_all(self) -> None: + for i in range(len(self)): + self[i] + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: Any, + string_table: StringTable, + ) -> None: + for i in range(self._len): + writer[i] = self[i] + self._raw_data = writer.as_reader() + self._mark_clean() + + def index(self, value: str) -> int: + """Returns the index of a string in the table.""" + return self._reverse_table[value] + + def insert(self, value: str) -> int: + """Inserts a string into the table and returns its index.""" + if value in self._reverse_table: + return self._reverse_table[value] + index = self._len + self[index] = value + return index + + def __len__(self) -> int: + return self._len + + def __getitem__(self, index: int) -> str: + if index < 0 or index >= self._len: + msg = f"Index {index} is out of bounds for string table with {self._len} strings" + raise IndexError(msg) + if index not in self._string_table: + if self._raw_data is None: + msg = f"String table is incomplete. index {index} has not been assigned yet." + raise ValueError(msg) + s = self._raw_data.index(index) + self._string_table[index] = s + self._reverse_table[s] = index + return self._string_table[index] + + def __setitem__(self, index: int, value: str) -> None: + self._string_table[index] = value + self._reverse_table[value] = index + if index >= self._len: + self._len = index + 1 + self._mark_dirty() diff --git a/impl/py/src/jeff/type.py b/impl/py/src/jeff/type.py new file mode 100644 index 0000000..b732c98 --- /dev/null +++ b/impl/py/src/jeff/type.py @@ -0,0 +1,150 @@ +"""Jeff type definitions.""" + +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from enum import Enum +from .capnp import schema, CapnpBuffer +from .string_table import StringTable + + +class FloatPrecision(Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + @staticmethod + def from_name(name: str) -> FloatPrecision: + match name.lower(): + case "float32": + return FloatPrecision.FLOAT32 + case "float64": + return FloatPrecision.FLOAT64 + raise ValueError(f"Unknown float precision: {name}") + + +class JeffType(ABC, CapnpBuffer[schema.Type, schema.Type.Builder]): # type: ignore + """A Jeff type.""" + + @staticmethod + def _read_from_buffer(reader: schema.Type): # type: ignore + match reader.which: + case "qubit": + return QubitType() + case "qureg": + return QuregType() + case "int": + bitwidth = reader.int + return IntType(bitwidth) + case "intArray": + bitwidth = reader.intArray + return IntArrayType(bitwidth) + case "float": + float_width = FloatPrecision.from_name(reader.float) + return FloatType(float_width) + case "floatArray": + float_width = FloatPrecision.from_name(reader.floatArray) + return FloatArrayType(float_width) + case _: + raise ValueError(f"Unknown type: {reader.which}") + + def _force_read_all(self) -> None: + pass + + +@dataclass(frozen=True) +class QubitType(JeffType): + """A qubit type.""" + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore + string_table: StringTable, + ) -> None: + new_data.qubit = None + + def __str__(self) -> str: + return "qubit" + + +@dataclass(frozen=True) +class QuregType(JeffType): + """A register of qubits with arbitrary size.""" + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore + string_table: StringTable, + ) -> None: + new_data.qureg = None + + def __str__(self) -> str: + return "qureg" + + +@dataclass(frozen=True) +class IntType(JeffType): + """An integer type, with a fixed bitwidth.""" + + bitwidth: int + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore + string_table: StringTable, + ) -> None: + new_data.int = self.bitwidth + + def __str__(self) -> str: + return f"int{self.bitwidth}" + + +@dataclass(frozen=True) +class IntArrayType(JeffType): + """Specialization of the JeffType for integer arrays.""" + + bitwidth: int + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore[name-defined] + string_table: StringTable, + ) -> None: + new_data.intArray = self.bitwidth + + def __str__(self) -> str: + return f"int{self.bitwidth}[]" + + +@dataclass(frozen=True) +class FloatType(JeffType): + """Specialization of the JeffType for floating point values.""" + + bitwidth: FloatPrecision + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore + string_table: StringTable, + ) -> None: + new_data.float = self.bitwidth.value + + def __str__(self) -> str: + return f"{self.bitwidth.value}" + + +@dataclass(frozen=True) +class FloatArrayType(JeffType): + """Specialization of the JeffType for floating point arrays.""" + + bitwidth: FloatPrecision + + def _write_to_buffer( + self, + new_data: schema.Type.Builder, # type: ignore + string_table: StringTable, + ) -> None: + new_data.floatArray = self.bitwidth + + def __str__(self) -> str: + return f"{self.bitwidth.value}[]" diff --git a/impl/py/src/jeff/value.py b/impl/py/src/jeff/value.py new file mode 100644 index 0000000..51d060b --- /dev/null +++ b/impl/py/src/jeff/value.py @@ -0,0 +1,224 @@ +"""Values identifying typed hyperedges in the program.""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import Any, TYPE_CHECKING +from typing_extensions import override + +from .capnp import schema, CapnpBuffer, LazyUpdate +from .type import JeffType +from .string_table import StringTable + +if TYPE_CHECKING: + from jeff.function import FunctionDef + from jeff.region import Region + + +# TODO: What's the correct type for capnp list[schema.Value] reader and writers? +class ValueTable(CapnpBuffer[Any, Any], LazyUpdate): + """An indexed value table.""" + + _value_table: dict[int, Value] + _len = 0 + + # TODO: Not the correct type. What's the type of a capnp list reader? + _raw_data: Any | None = None + + # A reference to the function definition defining this value table + _func: FunctionDef | None = None + + def __init__(self, value_table: list[Value]): + self._value_table = {i: s for i, s in enumerate(value_table)} + self._len = len(value_table) + self._mark_dirty() + + @staticmethod + def _collect_from_region(region: Region) -> ValueTable: + """Compute a value table from a region. + + Traverses the region and its operations collecting any value defined in it, + adds them to a value table and sets the ids for each value. + """ + # The value table to return + table = ValueTable([]) + + # Values which already have an id + value_dict: dict[int, Value] = {} + # Values which have not yet been assigned an id + unordered_values: deque[Value] = deque() + + def add_value(val: Value) -> None: + val._value_table = table + match val.id: + case None: + unordered_values.append(val) + case id: + value_dict[id] = val + + for region in region.subregions_bfs(): + for val in region.sources + region.targets: + add_value(val) + + for op in region: + for val in op.inputs + op.outputs: + add_value(val) + + # Assign ids to all values + taken_ids = deque(sorted(value_dict.keys())) + id = 0 + while unordered_values or taken_ids: + table._len = max(id + 1, table._len) + if id in value_dict: + table[id] = value_dict[id] + elif unordered_values: + val = unordered_values.popleft() + val.id = id + table[id] = val + # Bump the next assignable id + if not unordered_values and taken_ids: + id = taken_ids.popleft() + else: + id += 1 + + return table + + @override + def _mark_dirty(self) -> None: + """Mark the object as dirty. + + This call spreads recursively to parent objects. + """ + self._is_dirty = True + if self._func: + self._func._mark_dirty() + + @staticmethod + def _read_from_buffer(reader: Any) -> ValueTable: + # Do not read the values eagerly, only fetch them when needed. + table = ValueTable([]) + table._raw_data = reader + table._len = reader.len + table._mark_clean() + return table + + def _force_read_all(self) -> None: + for i in range(len(self)): + self[i] + self._raw_data = None + self._mark_dirty() + + def _write_to_buffer( + self, + writer: Any, + string_table: StringTable, + ) -> None: + for i in range(self._len): + self[i]._write_to_buffer(writer[i], string_table) + self._raw_data = writer.as_reader() + self._mark_clean() + + def add(self, value: Value) -> int: + """Add a value to the value table and return its id. + + If the value already has an id matching an existing entry, + checks that the types and metadata coincide. + + :returns: The id of the added value. + :raises ValueError: If the table already contains a different value with the same id. + """ + value._value_table = self + if value.id is None: + id = self._len + value.id = id + self._value_table[id] = value + self._len += 1 + self._mark_dirty() + return id + + if value.id not in self._value_table: + self._value_table[value.id] = value + self._len = max(self._len, value.id + 1) + self._mark_dirty() + return value.id + + if self._value_table[value.id] != value: + raise ValueError( + f"Value #{value} already exists in value table with type {self._value_table[value.id].type}" + ) + + return value.id + + def __getitem__(self, index: int) -> Value: + if index < 0 or index >= self._len: + msg = f"Index {index} is out of bounds for value table with {self._len} values" + raise IndexError(msg) + if index not in self._value_table: + if self._raw_data is None: + msg = ( + f"Value table is incomplete. id {index} has not been assigned yet." + ) + raise ValueError(msg) + value_reader = self._raw_data.index(index) + val = Value._read_from_buffer(value_reader) + val.id = index + val._value_table = self + self._value_table[index] = val + return self._value_table[index] + + def __setitem__(self, index: int, value: Value) -> None: + value.id = index + value._value_table = self + + self._value_table[index] = value + if index >= self._len: + self._len = index + 1 + self._mark_dirty() + + def __len__(self) -> int: + return self._len + + +@dataclass +class Value(CapnpBuffer[schema.Value, schema.Value.Builder]): # type: ignore + """Program values represent dataflow between operations, and defines the data type used. + + This class is immutable, and holds an identifier for the unique edge in the program. In an + encoded program, the identifier is the index into the parent function's value table, whereas + during program construction the identifier is the object's instance id. + + :attr id: The value table index or `None` if not yet assigned. + :attr type: The type of the value. + """ + + # The value table index is used in reader mode both for pretty printing and comparing values. + type: JeffType + id: int | None = None + + # TODO: Add register metadata + + # The value table that defines this value with id `self.id`, if available. + _value_table: ValueTable | None = None + + @staticmethod + def _read_from_buffer(reader: schema.Value) -> Value: # type: ignore + type = JeffType._read_from_buffer(reader.type) + return Value(type) + + def _force_read_all(self) -> None: + pass + + def _write_to_buffer( + self, + new_data: schema.Value.Builder, # type: ignore + string_table: StringTable, + ) -> None: + # For immutable classes, just write the cached data into the encoding buffer. + self.type._write_to_buffer(new_data.type, string_table) + + def __str__(self) -> str: + if self.id is None: + return f"%{self.type}" + else: + return f"%{self.id}:{self.type}" diff --git a/impl/py/tests/test_reader.py b/impl/py/tests/test_reader.py index e06e537..8b95ac5 100644 --- a/impl/py/tests/test_reader.py +++ b/impl/py/tests/test_reader.py @@ -1,2 +1,2 @@ def test_hello_world() -> None: - assert 2 + 2 != "🐟" + assert 2 + 2 != "🐟" # type: ignore[comparison-overlap] diff --git a/justfile b/justfile index db8699d..dace165 100644 --- a/justfile +++ b/justfile @@ -27,7 +27,7 @@ fix-rs: cargo clippy --all-targets --all-features --workspace --fix --allow-staged --allow-dirty # Auto-fix all the lints in the python code. fix-py: - uv run ruff check --fix + uv run ruff check --fix impl/py # Format all the code in the repository. format: format-rs format-py @@ -50,9 +50,9 @@ coverage-py: # Update the capnproto definitions. update-capnp: # Always use the latest version of capnproto-rust - cargo install capnpc + cargo binstall capnpc || cargo install capnpc # Copy the definition to the python package - cp impl/capnp/jeff.capnp impl/py/src/jeff-format/data/jeff.capnp + cp impl/capnp/jeff.capnp impl/py/src/jeff/capnp/jeff.capnp # Re-generate rust capnp files capnp compile -orust:impl/rs/src --src-prefix=impl impl/capnp/jeff.capnp # Re-generate c++ capnp files diff --git a/uv.lock b/uv.lock index bb9a3d7..0577474 100644 --- a/uv.lock +++ b/uv.lock @@ -186,6 +186,12 @@ wheels = [ name = "jeff-format" version = "0.1.0" source = { editable = "impl/py" } +dependencies = [ + { name = "pycapnp" }, +] + +[package.metadata] +requires-dist = [{ name = "pycapnp", git = "https://github.com/mlxd/pycapnp?branch=mlxd%2Fupdate_gh_actions" }] [[package]] name = "mypy" @@ -302,6 +308,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/a5/987a405322d78a73b66e39e4a90e4ef156fd7141bf71df987e50717c321b/pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8", size = 220965, upload-time = "2025-08-09T18:56:13.192Z" }, ] +[[package]] +name = "pycapnp" +version = "2.0.0" +source = { git = "https://github.com/mlxd/pycapnp?branch=mlxd%2Fupdate_gh_actions#ef05121594a58413f223f5b9adc5e931d9d9cfaf" } + [[package]] name = "pygments" version = "2.19.2"