diff --git a/CHANGELOG.md b/CHANGELOG.md index a0c3ef4d7b..172865af1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between Jeff and QCO ([#1479], [#1548], [#1565]) ([**@denialhaag**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588]) ([**@MatthiasReumann**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1602]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -341,6 +341,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1588]: https://github.com/munich-quantum-toolkit/core/pull/1588 [#1583]: https://github.com/munich-quantum-toolkit/core/pull/1583 [#1581]: https://github.com/munich-quantum-toolkit/core/pull/1581 +[#1580]: https://github.com/munich-quantum-toolkit/core/pull/1580 [#1573]: https://github.com/munich-quantum-toolkit/core/pull/1573 [#1572]: https://github.com/munich-quantum-toolkit/core/pull/1572 [#1571]: https://github.com/munich-quantum-toolkit/core/pull/1571 diff --git a/mlir/include/mlir/Compiler/CompilerPipeline.h b/mlir/include/mlir/Compiler/CompilerPipeline.h index ec7473a8cf..f43bac99d7 100644 --- a/mlir/include/mlir/Compiler/CompilerPipeline.h +++ b/mlir/include/mlir/Compiler/CompilerPipeline.h @@ -72,18 +72,18 @@ struct CompilationRecord { * * 1. QC dialect (reference semantics) - imported from * qc::QuantumComputation - * 2. Canonicalization + cleanup + * 2. QC cleanup pipeline * 3. QCO dialect (value semantics) - enables SSA-based optimizations - * 4. Canonicalization + cleanup + * 4. QCO cleanup pipeline * 5. Quantum optimization passes - * 6. Canonicalization + cleanup + * 6. QCO cleanup pipeline * 7. QC dialect - converted back for backend lowering - * 8. Canonicalization + cleanup + * 8. QC cleanup pipeline * 9. QIR (Quantum Intermediate Representation) - optional final lowering - * 10. Canonicalization + cleanup + * 10. QIR cleanup pipeline * - * Following MLIR best practices, canonicalization and dead value removal - * are always run after each major transformation stage. + * Following MLIR best practices, simplification and dead-value cleanup are + * run after each major transformation stage. */ class QuantumCompilerPipeline { public: @@ -111,15 +111,6 @@ class QuantumCompilerPipeline { CompilationRecord* record = nullptr) const; private: - /** - * @brief Add canonicalization and cleanup passes - * - * @details - * Always adds the standard MLIR canonicalization pass followed by common - * sub-expression elimination and dead value removal. - */ - static void addCleanupPasses(PassManager& pm); - /** * @brief Configure PassManager with diagnostic options * diff --git a/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td b/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td index a773b3d5c2..e77e57ad9a 100644 --- a/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td +++ b/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td @@ -16,5 +16,6 @@ def QCOToQC : Pass<"qco-to-qc"> { It handles the transformation of qubit values in QCO to qubit references in QC, ensuring that the semantics of quantum operations are preserved during the conversion process. }]; - let dependentDialects = ["mlir::qc::QCDialect"]; + let dependentDialects = ["mlir::memref::MemRefDialect", + "mlir::qc::QCDialect"]; } diff --git a/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td b/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td index 9ffd33d4f7..8b1f6fadc4 100644 --- a/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td +++ b/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td @@ -16,5 +16,6 @@ def QCToQCO : Pass<"qc-to-qco"> { It handles the transformation of qubit references in QC to qubit values in QCO, ensuring that the semantics of quantum operations are preserved during the conversion process. }]; - let dependentDialects = ["mlir::qco::QCODialect"]; + let dependentDialects = ["mlir::arith::ArithDialect", "mlir::qco::QCODialect", + "mlir::qtensor::QTensorDialect"]; } diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index dfc77dc6e2..6714a978af 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -8,4 +8,5 @@ add_subdirectory(QC) add_subdirectory(QCO) +add_subdirectory(QIR) add_subdirectory(QTensor) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index f602903950..c1804ea22a 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -126,21 +126,20 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Allocate a qubit register * @param size Number of qubits (must be positive) - * @param name Register name (default: "q") * @return Vector of qubit references * * @par Example: * ```c++ - * auto q = builder.allocQubitRegister(3, "q"); + * auto q = builder.allocQubitRegister(3); * ``` * ```mlir - * %q0 = qc.alloc("q", 3, 0) : !qc.qubit - * %q1 = qc.alloc("q", 3, 1) : !qc.qubit - * %q2 = qc.alloc("q", 3, 2) : !qc.qubit + * %memref = memref.alloc() : memref<3x!qc.qubit> + * %q0 = memref.load %memref[%c0] : memref<3x!qc.qubit> + * %q1 = memref.load %memref[%c1] : memref<3x!qc.qubit> + * %q2 = memref.load %memref[%c2] : memref<3x!qc.qubit> * ``` */ - llvm::SmallVector allocQubitRegister(int64_t size, - const std::string& name = "q"); + llvm::SmallVector allocQubitRegister(int64_t size); /** * @brief A small structure representing a single classical bit within a @@ -942,6 +941,9 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /// Track allocated qubits for automatic deallocation llvm::DenseSet allocatedQubits; + /// Track allocated MemRefs for automatic deallocation + llvm::DenseSet allocatedMemrefs; + /// Check if the builder has been finalized void checkFinalized() const; }; diff --git a/mlir/include/mlir/Dialect/QC/CMakeLists.txt b/mlir/include/mlir/Dialect/QC/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/include/mlir/Dialect/QC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QC/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..115aeb67b7 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QC) +add_public_tablegen_target(MLIRQCTransformsIncGen) + +add_mlir_doc(Passes QCTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QC/Transforms/Passes.h b/mlir/include/mlir/Dialect/QC/Transforms/Passes.h new file mode 100644 index 0000000000..435be783d7 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/Passes.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QC/IR/QCDialect.h" + +#include +#include + +namespace mlir::qc { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qc diff --git a/mlir/include/mlir/Dialect/QC/Transforms/Passes.td b/mlir/include/mlir/Dialect/QC/Transforms/Passes.td new file mode 100644 index 0000000000..d717641820 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/Passes.td @@ -0,0 +1,27 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def ShrinkQubitRegistersPass + : Pass<"qc-shrink-qubit-registers", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qc::QCDialect", "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect"]; + let summary = + "Shrink static qc::QubitType memref registers to accessed indices."; + let description = [{ + Shrinks one-dimensional static memref registers with element type + `!qc.qubit` by removing never-read indices and remapping `memref.load` + users accordingly. + }]; +} + +#endif // MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 6520d83702..6e86dd513f 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -134,21 +135,20 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Allocate a qubit register * @param size Number of qubits (must be positive) - * @param name Register name (default: "q") * @return Vector of tracked, valid qubit SSA values * * @par Example: * ```c++ - * auto q = builder.allocQubitRegister(3, "q"); + * auto q = builder.allocQubitRegister(3); * ``` * ```mlir - * %q0 = qco.alloc("q", 3, 0) : !qco.qubit - * %q1 = qco.alloc("q", 3, 1) : !qco.qubit - * %q2 = qco.alloc("q", 3, 2) : !qco.qubit + * %t0 = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * %t1, %q0 = qtensor.extract %t0[%c0]: tensor<3x!qco.qubit> + * %t2, %q1 = qtensor.extract %t1[%c1]: tensor<3x!qco.qubit> + * %t3, %q2 = qtensor.extract %t2[%c2]: tensor<3x!qco.qubit> * ``` */ - llvm::SmallVector allocQubitRegister(int64_t size, - const std::string& name = "q"); + llvm::SmallVector allocQubitRegister(int64_t size); /** * @brief A small structure representing a single classical bit within a @@ -272,36 +272,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * %outTensor, %q0 = qtensor.extract %tensor[%c0]: tensor<3x!qco.qubit> * ``` */ - std::pair - qtensorExtract(Value tensor, const std::variant& index); - - /** - * @brief Extract a qubit slice from a tensor - * - * @details - * Extracts a slice from a one-dimensional tensor of qubits at the given - * offset and size and returns the updated input tensor and the extracted - * tensor. The extracted tensor is added to the qubit tensor tracking and the - * tracking for the input tensor is updated. - * - * @param tensor Source tensor (must be valid/unconsumed) - * @param offset The offset from where the slice is extracted - * @param size The size of the extracted slice - * @return Pair of (outTensor, extractedSlice) - * - * @par Example: - * ```c++ - * auto [outTensor, extractedSlice] = builder.qtensorExtractSlice(tensor, 0, - * 2); - * ``` - * ```mlir - * %outTensor, %extractedSlice = qtensor.extract_slice %tensor[%c0][%c2] - * : tensor<3x!qco.qubit> to tensor<2x!qco.qubit> - * ``` - */ - std::pair - qtensorExtractSlice(Value tensor, const std::variant& offset, - const std::variant& size); + std::pair qtensorExtract(Value tensor, const int64_t index); /** * @brief Insert a qubit into a tensor @@ -328,35 +299,6 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { Value qtensorInsert(Value scalar, Value tensor, const std::variant& index); - /** - * @brief Insert a qubit slice into a tensor - * - * @details - * Inserts a one-dimensional tensor of qubits into another one-dimensional - * tensor of qubits at the given offset and size. The inserted tensor slice is - * consumed and removed from the tracking, while the tracking for the - * destination tensor is updated. - * - * @param sourceTensor The slice that is inserted (must be valid/unconsumed) - * @param destTensor The tensor where the slice is inserted (must be - * valid/unconsumed) - * @param offset The offset into where the slice is inserted - * @param size The size of the inserted slice - * @return The output tensor - * - * @par Example: - * ```c++ - * auto outTensor = builder.qtensorInsertSlice(slicedTensor, tensor, 0, 2); - * ``` - * ```mlir - * %outTensor = qtensor.insert_slice %slicedTensor into %tensor[%c0][%c2] - * : tensor<2x!qco.qubit> into tensor<3x!qco.qubit> - * ``` - */ - Value qtensorInsertSlice(Value sourceTensor, Value destTensor, - const std::variant& offset, - const std::variant& size); - /** * @brief Explicitly deallocate a tensor * @@ -1347,11 +1289,24 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateQubitTracking(Value inputQubit, Value outputQubit); + /// Count unique tensors + int64_t tensorCounter = 0; + + /** + * @brief Information about a qubit + */ + struct QubitInfo { + /// ID of the register the qubit belongs to + int64_t regId = -1; + /// Index of the qubit within its register + int64_t regIndex = -1; + }; + /// Track valid (unconsumed) qubit SSA values for linear type enforcement. - /// Only values present in this set are valid for use in operations. + /// Only values present in this map are valid for use in operations. /// When an operation consumes a qubit and produces a new one, the old value /// is removed and the new output is added. - llvm::DenseSet validQubits; + llvm::DenseMap validQubits; /** * @brief Validate that a tensor value is valid and unconsumed. This also @@ -1369,10 +1324,18 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateTensorTracking(Value inputTensor, Value outputTensor); + /** + * @brief Information about a tensor + */ + struct TensorInfo { + /// ID of the register the tensor corresponds to + int64_t regId = -1; + }; + /// Track valid (unconsumed) tensor SSA values for linear type enforcement. - /// Only values present in this set are valid for use in operations. + /// Only values present in this map are valid for use in operations. /// When an operation consumes a tensor and produces a new one, the old value /// is removed and the new output is added. - llvm::DenseSet validTensors; + llvm::DenseMap validTensors; }; } // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index efee9623e5..20b1190d15 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -180,6 +180,7 @@ def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -524,6 +525,7 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -552,6 +554,7 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -580,6 +583,7 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -608,6 +612,7 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -841,6 +846,7 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -872,6 +878,7 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -903,6 +910,7 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -934,6 +942,7 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -968,6 +977,7 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", "const std::variant&":$theta, "const std::variant&":$beta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -1002,6 +1012,7 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", "const std::variant&":$theta, "const std::variant&":$beta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 337b767fb9..f888509129 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -242,48 +242,4 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, return success(); } -/** - * @brief Remove a trivial one-target, one-parameter operation - * - * @tparam OpType The type of the operation to be checked. - * @param op The operation instance. - * @param rewriter The pattern rewriter. - * @return LogicalResult Success or failure of the removal. - */ -template -mlir::LogicalResult -removeTrivialOneTargetOneParameter(OpType op, PatternRewriter& rewriter) { - const auto param = utils::valueToDouble(op.getOperand(1)); - if (!param || std::abs(*param) > utils::TOLERANCE) { - return failure(); - } - - // Trivialize operation - rewriter.replaceOp(op, op.getInputQubit(0)); - - return success(); -} - -/** - * @brief Remove a trivial two-target, one-parameter operation - * - * @tparam OpType The type of the operation to be checked. - * @param op The operation instance. - * @param rewriter The pattern rewriter. - * @return LogicalResult Success or failure of the removal. - */ -template -mlir::LogicalResult -removeTrivialTwoTargetOneParameter(OpType op, PatternRewriter& rewriter) { - const auto param = utils::valueToDouble(op.getOperand(2)); - if (!param || std::abs(*param) > utils::TOLERANCE) { - return failure(); - } - - // Trivialize operation - rewriter.replaceOp(op, op.getInputQubits()); - - return success(); -} - } // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index d09133fd71..d31d4f0bdb 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -13,8 +13,10 @@ #include "mlir/Dialect/QIR/Utils/QIRMetadata.h" #include +#include #include #include +#include #include #include #include @@ -233,10 +235,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Measure a qubit and record the result (simple version) * * @details - * Performs a Z-basis measurement using __quantum__qis__mz__body. The - * result is tracked for deferred output recording in the output block. - * This version does NOT include register information, so output will - * not be grouped by register. + * Performs a Z-basis measurement using `__quantum__qis__mz__body`. + * + * The output is recorded via `__quantum__rt__result_record_output` during + * `finalize()`. * * @param qubit The qubit to measure * @param resultIndex The classical bit index for result pointer @@ -247,12 +249,17 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * auto result = builder.measure(q0, 0); * ``` * ```mlir + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %r = llvm.call @"@__quantum__rt__result_allocate"(%zero) : !llvm.ptr -> + * !llvm.ptr + * * // In measurements block: - * %c0 = llvm.mlir.constant(0 : i64) : i64 - * %r = llvm.inttoptr %c0 : i64 to !llvm.ptr * llvm.call @__quantum__qis__mz__body(%q0, %r) : (!llvm.ptr, !llvm.ptr) -> () * - * // Output recording deferred to output block + * // In output block: + * llvm.call @__quantum__rt__result_record_output(%r, %label) : (!llvm.ptr, + * !llvm.ptr) -> () * ``` */ Value measure(Value qubit, int64_t resultIndex); @@ -261,12 +268,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Measure a qubit into a classical register * * @details - * Performs a Z-basis measurement using __quantum__qis__mz__body and tracks - * the measurement with register information for array-based output recording. - * Output recording is deferred to the output block during finalize(), where - * measurements are grouped by register and recorded using: - * 1. __quantum__rt__array_record_output for each register - * 2. __quantum__rt__result_record_output for each measurement in the register + * Performs a Z-basis measurement using `__quantum__qis__mz__body`. + * + * The output is recorded via `__quantum__rt__result_array_record_output` + * during `finalize()`. * * @param qubit The qubit to measure * @param bit The classical bit to store the result @@ -276,21 +281,21 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * ```c++ * auto c = builder.allocClassicalBitRegister(2, "c"); * builder.measure(q0, c[0]); - * builder.measure(q1, c[1]); * ``` * ```mlir + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c2 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__result_array_allocate"(%c2, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * %r = llvm.load %alloca : !llvm.ptr -> !llvm.ptr + * * // In measurements block: - * llvm.call @__quantum__qis__mz__body(%q0, %r0) : (!llvm.ptr, !llvm.ptr) -> - * () llvm.call @__quantum__qis__mz__body(%q1, %r1) : (!llvm.ptr, !llvm.ptr) - * -> () + * llvm.call @__quantum__qis__mz__body(%q, %r) : (!llvm.ptr, !llvm.ptr) -> () * - * // In output block (generated during finalize): - * @0 = internal constant [3 x i8] c"c\00" - * @1 = internal constant [5 x i8] c"c0r\00" - * @2 = internal constant [5 x i8] c"c1r\00" - * llvm.call @__quantum__rt__array_record_output(i64 2, ptr @0) - * llvm.call @__quantum__rt__result_record_output(ptr %r0, ptr @1) - * llvm.call @__quantum__rt__result_record_output(ptr %r1, ptr @2) + * // In output block: + * llvm.call @__quantum__rt__result_array_record_output(%c2, %alloca, %label) + * : (i64, !llvm.ptr, !llvm.ptr) -> () * ``` */ QIRProgramBuilder& measure(Value qubit, const Bit& bit); @@ -309,7 +314,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * builder.reset(q); * ``` * ```mlir - * llvm.call @__quantum__qis__reset__body(%q) : (!llvm.ptr) -> () + * llvm.call @__quantum__qis__reset__body(%q) : !llvm.ptr -> () * ``` */ QIRProgramBuilder& reset(Value qubit); @@ -350,7 +355,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * builder.OP_NAME(q); \ * ``` \ * ```mlir \ - * llvm.call @__quantum__qis__##QIR_NAME##__body(%q) : (!llvm.ptr) -> () \ + * llvm.call @__quantum__qis__##QIR_NAME##__body(%q) : !llvm.ptr -> () \ * ``` \ */ \ QIRProgramBuilder& OP_NAME(Value qubit); \ @@ -840,11 +845,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Finalize the program and return the constructed module * * @details - * Automatically deallocates all remaining allocated qubits, generates - * array-based output recording in the output block (grouped by register), - * ensures proper QIR metadata attributes are set, and transfers ownership - * of the module to the caller. The builder should not be used after calling - * this method. + * Automatically deallocates all remaining allocated qubits and result + * pointers, generates output recording in the output block, ensures proper + * QIR metadata attributes are set, and transfers ownership of the module to + * the caller. The builder should not be used after calling this method. * * @return OwningOpRef containing the constructed QIR program module */ @@ -886,11 +890,20 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { /// Exit code constant (created in entry block, used in output block) Value exitCode; - /// Cache static pointers for reuse - llvm::DenseMap ptrCache; + /// Cache static qubit pointers for reuse + llvm::DenseMap staticQubits; + + /// Set of qubit-array pointers + llvm::DenseSet qubitArrays; + + /// Map from register name to result-array pointer + llvm::StringMap resultArrays; + + /// Map from (register name, index) to loaded result + llvm::DenseMap, Value> loadedResults; - /// Map from (register_name, register_index) to result pointer - llvm::DenseMap, Value> registerResultMap; + /// Map from result index to result pointer for non-register results + llvm::DenseMap resultPtrs; /// Track qubit and result counts for QIR metadata QIRMetadata metadata_; @@ -918,10 +931,8 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Generate array-based output recording in the output block * * @details - * Called by finalize() to generate output recording calls for all tracked - * measurements. Groups measurements by register and generates: - * 1. array_record_output for each register - * 2. result_record_output for each measurement in the register + * Called by `finalize()` to generate output recording calls for all tracked + * measurements. */ void generateOutputRecording(); diff --git a/mlir/include/mlir/Dialect/QIR/CMakeLists.txt b/mlir/include/mlir/Dialect/QIR/CMakeLists.txt new file mode 100644 index 0000000000..3c339729a9 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..6d81096520 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QIR) +add_public_tablegen_target(MLIRQIRTransformsIncGen) + +add_mlir_doc(Passes QIRTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h new file mode 100644 index 0000000000..e1d281d87a --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include + +namespace mlir::qir { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qir diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td new file mode 100644 index 0000000000..efb7fea4b1 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td @@ -0,0 +1,23 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def QIRCleanupPass : Pass<"qir-cleanup", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let summary = "Remove redundant QIR runtime bookkeeping."; + let description = [{ + Removes redundant QIR runtime qubit-array allocation/release pairs that do + not contribute to observable behavior, and keeps QIR modules compact. + }]; +} + +#endif // MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h index 33190e535e..6a28995c83 100644 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h +++ b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h @@ -34,11 +34,27 @@ namespace mlir::qir { // QIR function names +inline constexpr auto QIR_QUBIT_ARRAY_ALLOC = + "__quantum__rt__qubit_array_allocate"; +inline constexpr auto QIR_QUBIT_ARRAY_RELEASE = + "__quantum__rt__qubit_array_release"; + +inline constexpr auto QIR_QUBIT_ALLOC = "__quantum__rt__qubit_allocate"; +inline constexpr auto QIR_QUBIT_RELEASE = "__quantum__rt__qubit_release"; + +inline constexpr auto QIR_RESULT_ARRAY_ALLOC = + "__quantum__rt__result_array_allocate"; +inline constexpr auto QIR_RESULT_ARRAY_RELEASE = + "__quantum__rt__result_array_release"; + +inline constexpr auto QIR_RESULT_ALLOC = "__quantum__rt__result_allocate"; +inline constexpr auto QIR_RESULT_RELEASE = "__quantum__rt__result_release"; + inline constexpr auto QIR_INITIALIZE = "__quantum__rt__initialize"; inline constexpr auto QIR_MEASURE = "__quantum__qis__mz__body"; inline constexpr auto QIR_RECORD_OUTPUT = "__quantum__rt__result_record_output"; inline constexpr auto QIR_ARRAY_RECORD_OUTPUT = - "__quantum__rt__array_record_output"; + "__quantum__rt__result_array_record_output"; inline constexpr auto QIR_RESET = "__quantum__qis__reset__body"; inline constexpr auto QIR_GPHASE = "__quantum__qis__gphase__body"; @@ -164,7 +180,7 @@ LLVM::LLVMFuncOp getMainFunction(Operation* op); * - `required_num_qubits`: Number of qubits used * - `required_num_results`: Number of measurement results * - `qir_major_version`: 2 - * - `qir_minor_version`: 0 + * - `qir_minor_version`: 1 * - `dynamic_qubit_management`: true/false * - `dynamic_result_management`: true/false * diff --git a/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt b/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index f421ee2dc6..1ac3aa1885 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -134,48 +134,6 @@ def ExtractOp let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, QubitType:$result); let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)"; - let hasFolder = 1; - let hasVerifier = 1; -} - -def ExtractSliceOp - : QTensorOp<"extract_slice", - [Pure, - TypesMatchWith<"returned tensor type matches input tensor", - "tensor", "out_tensor", "$_self">]> { - let summary = "Extract slice from tensor"; - let description = [{ - The `qtensor.extract_slice` operation is the modified version of the standard `tensor.extract_slice` - operation of the tensor dialect. It reads a one-dimensional qubit tensor and returns the extracted tensor of qubits specified by the - offset and size argument. In addition, it also returns the updated input tensor as result. - - The extract_slice operation supports the following arguments: - - - tensor: the "base" tensor from which to extract a slice. - - offset: the starting position in the base tensor from which the slice is - extracted. - - size: the length of the slice to extract from the base tensor. - - Example: - ```mlir - %outTensor, %extractedSlice = qtensor.extract_slice %tensor[%c0][%c2] : tensor<3x!qco.qubit> to tensor<2x!qco.qubit> - ``` - }]; - - let arguments = (ins 1DTensorOf<[QubitType]>:$tensor, Index:$offset, - Index:$size); - let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, - 1DTensorOf<[QubitType]>:$result); - - let assemblyFormat = [{ - $tensor `[`$offset `]` `[`$size`]` - attr-dict `:` type($tensor) `to` type($result) - }]; - - let builders = [OpBuilder<(ins "Value":$tensor, "Value":$offset, - "Value":$size, CArg<"ArrayRef", "{}">:$attrs)>]; - - let hasFolder = 1; let hasVerifier = 1; } @@ -207,42 +165,7 @@ def InsertOp $scalar `into` $dest `[` $index `]` attr-dict `:` type($dest) }]; - let hasFolder = 1; - let hasVerifier = 1; -} - -def InsertSliceOp - : QTensorOp<"insert_slice", - [Pure, TypesMatchWith<"expected result type to match dest type", - "dest", "result", "$_self">]> { - let summary = "Insert slice into tensor"; - let description = [{ - The `qtensor.insert_slice` operation is a modified version of the `tensor.insert_slice` operation - of the tensor dialect. The operation inserts a tensor `source` into another tensor `dest` as specified by the - operation's offset and size arguments. This insertion consumes the `source` tensor of qubits. - - The insert_slice operation supports the following arguments: - - - source: the tensor that is inserted. The source is consumed after the insertion. - - dest: the tensor into which the source tensor is inserted. - - offset: the starting index in the `dest` tensor where the slice is inserted. - - size: the number of elements in the slice, which must match the size of the - source tensor type. - - Example: - ```mlir - %outTensor = qtensor.insert_slice %slicedTensor into %tensor[%c0][%c2] : tensor<2x!qco.qubit> into tensor<3x!qco.qubit> - ``` - }]; - - let arguments = (ins 1DTensorOf<[QubitType]>:$source, - 1DTensorOf<[QubitType]>:$dest, Index:$offset, Index:$size); - let results = (outs 1DTensorOf<[QubitType]>:$result); - let assemblyFormat = [{ - $source `into` $dest `[`$offset `]` `[`$size`]` - attr-dict `:` type($source) `into` type($dest) - }]; - + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h new file mode 100644 index 0000000000..594be34481 --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" + +#include +#include +#include + +namespace mlir::qtensor { + +/** + * @brief Checks whether two index values are equivalent. + * + * @details This is a conservative check that returns true if both indices are + * constant integers with the same value. It returns false if either index is + * non-constant or if they have different constant values. Note that this means + * that some equivalent indices may be considered non-equivalent by this + * function, but no non-equivalent indices will be considered equivalent. + */ +inline bool areEquivalentIndices(Value lhs, Value rhs) { + auto lhsValue = getConstantIntValue(lhs); + auto rhsValue = getConstantIntValue(rhs); + if (!lhsValue || !rhsValue) { + return false; + } + return *lhsValue == *rhsValue; +} + +/** + * @brief Tensor-transforming ops in a scalar extract/insert chain. + */ +inline bool isTensorChainOp(Operation* op) { + return llvm::isa(op); +} + +/** + * @brief Returns the tensor input of a tensor-transforming op. + */ +inline Value getTensorChainInput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getDest(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getTensor(); + } + return nullptr; +} + +/** + * @brief Returns the tensor output of a tensor-transforming op. + */ +inline Value getTensorChainOutput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getResult(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getOutTensor(); + } + return nullptr; +} + +/** + * @brief Rewire the tensor input of a tensor-transforming op. + */ +inline void setTensorChainInput(Operation* op, Value tensor) { + if (llvm::isa(op)) { + op->setOperand(1, tensor); + return; + } + if (llvm::isa(op)) { + op->setOperand(0, tensor); + } +} + +} // namespace mlir::qtensor diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..de5795040d --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QTensor) +add_public_tablegen_target(MLIRQTensorTransformsIncGen) + +add_mlir_doc(Passes QTensorTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h new file mode 100644 index 0000000000..e33924b5a5 --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" + +#include +#include + +namespace mlir::qtensor { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qtensor diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td new file mode 100644 index 0000000000..cfc322d3cf --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def ShrinkQTensorToFitPass : Pass<"qtensor-shrink-to-fit", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qtensor::QTensorDialect", + "mlir::arith::ArithDialect"]; + let summary = "Shrink static qtensors to their actually accessed indices."; + let description = [{ + Shrinks one-dimensional static qtensors by tracing linear tensor chains from + `qtensor.dealloc` to `qtensor.alloc` and rebuilding the chain on a compact + allocation that only keeps accessed indices. + }]; +} + +#endif // MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Support/Passes.h b/mlir/include/mlir/Support/Passes.h index c671ca6970..f780be87f9 100644 --- a/mlir/include/mlir/Support/Passes.h +++ b/mlir/include/mlir/Support/Passes.h @@ -10,11 +10,42 @@ #pragma once +#include "mlir/Support/LogicalResult.h" + namespace mlir { class ModuleOp; -} +class PassManager; +} // namespace mlir + +/** + * @brief Populate a QC-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and QC qubit-register shrinking. + */ +void populateQCCleanupPipeline(mlir::PassManager& pm); + +/** + * @brief Populate a QCO-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and qtensor shrink-to-fit. + */ +void populateQCOCleanupPipeline(mlir::PassManager& pm); + +/** + * @brief Populate a QIR-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and QIR-specific simplifications. + */ +void populateQIRCleanupPipeline(mlir::PassManager& pm); + +/** + * @brief Run the QC-oriented cleanup pipeline on a module. + */ +[[nodiscard]] mlir::LogicalResult runQCCleanupPipeline(mlir::ModuleOp module); + +/** + * @brief Run the QCO-oriented cleanup pipeline on a module. + */ +[[nodiscard]] mlir::LogicalResult runQCOCleanupPipeline(mlir::ModuleOp module); /** - * @brief Run canonicalization and dead value removal on the given module. + * @brief Run the QIR-oriented cleanup pipeline on a module. */ -void runCanonicalizationPasses(mlir::ModuleOp module); +[[nodiscard]] mlir::LogicalResult runQIRCleanupPipeline(mlir::ModuleOp module); diff --git a/mlir/lib/Compiler/CompilerPipeline.cpp b/mlir/lib/Compiler/CompilerPipeline.cpp index 119e343755..0bab524e53 100644 --- a/mlir/lib/Compiler/CompilerPipeline.cpp +++ b/mlir/lib/Compiler/CompilerPipeline.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/QCOToQC/QCOToQC.h" #include "mlir/Conversion/QCToQCO/QCToQCO.h" #include "mlir/Conversion/QCToQIR/QCToQIR.h" +#include "mlir/Support/Passes.h" #include "mlir/Support/PrettyPrinting.h" #include @@ -20,7 +21,6 @@ #include #include #include -#include #include @@ -45,14 +45,6 @@ static void prettyPrintStage(ModuleOp module, const llvm::StringRef stageName, printProgram(module, stageHeader, llvm::errs()); } -void QuantumCompilerPipeline::addCleanupPasses(PassManager& pm) { - // Always run canonicalization, common sub-expression elimination, and dead - // value removal - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createRemoveDeadValuesPass()); -} - void QuantumCompilerPipeline::configurePassManager(PassManager& pm) const { // Enable timing statistics if requested if (config_.enableTiming) { @@ -85,15 +77,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, // Determine total number of stages for progress indication // 1. QC import - // 2. QC canonicalization + // 2. QC cleanup // 3. QC-to-QCO conversion - // 4. QCO canonicalization + // 4. QCO cleanup // 5. Optimization passes - // 6. QCO canonicalization + // 6. QCO cleanup // 7. QCO-to-QC conversion - // 8. QC canonicalization + // 8. QC cleanup // 9. QC-to-QIR conversion (optional) - // 10. QIR canonicalization (optional) + // 10. QIR cleanup (optional) auto totalStages = 8; if (config_.convertToQIR) { totalStages += 2; @@ -108,14 +100,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, } } - // Stage 2: QC canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 2: QC cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterInitialCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Initial QC Canonicalization", ++currentStage, + prettyPrintStage(module, "Initial QC Cleanup", ++currentStage, totalStages); } } @@ -130,20 +123,22 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 4: QCO canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 4: QCO cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQCOCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Initial QCO Canonicalization", ++currentStage, + prettyPrintStage(module, "Initial QCO Cleanup", ++currentStage, totalStages); } } // Stage 5: Optimization passes // TODO: Add optimization passes - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { @@ -153,14 +148,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 6: QCO canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 6: QCO cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterOptimizationCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Final QCO Canonicalization", ++currentStage, + prettyPrintStage(module, "Final QCO Cleanup", ++currentStage, totalStages); } } @@ -175,15 +171,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 8: QC canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 8: QC cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQCCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Final QC Canonicalization", ++currentStage, - totalStages); + prettyPrintStage(module, "Final QC Cleanup", ++currentStage, totalStages); } } // Stage 9: QC-to-QIR conversion (optional) @@ -199,15 +195,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 10: QIR canonicalization (optional) - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 10: QIR cleanup (optional) + if (failed(runStage( + [&](PassManager& pm) { populateQIRCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQIRCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "QIR Canonicalization", ++currentStage, - totalStages); + prettyPrintStage(module, "QIR Cleanup", ++currentStage, totalStages); } } } diff --git a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp index 24e467b99c..1c958b346b 100644 --- a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp +++ b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include @@ -381,6 +383,111 @@ static LogicalResult cleanUp(Operation* op) { namespace { +/** + * @brief Converts jeff.qureg_alloc to qtensor.alloc + * + * @par Example: + * ```mlir + * %qureg = jeff.qureg_alloc(%c3) : !jeff.qureg + * ``` + * is converted to + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + */ +struct ConvertJeffQuregAllocOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto size = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getNumQubits()); + rewriter.replaceOpWithNewOp(op, size.getResult()); + return success(); + } +}; + +/** + * @brief Converts jeff.qureg_extract_index to qtensor.extract + * + * @par Example: + * ```mlir + * %qureg_out, %q = jeff.qureg_extract_index(%c0) %qureg_in : !jeff.qureg, + * !jeff.qubit + * ``` + * is converted to + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + */ +struct ConvertJeffQuregExtractIndexOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregExtractIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto index = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getIndex()); + rewriter.replaceOpWithNewOp(op, adaptor.getInQreg(), + index.getResult()); + return success(); + } +}; + +/** + * @brief Converts jeff.qureg_insert_index to qtensor.insert + * + * @par Example: + * ```mlir + * %qureg_out = jeff.qureg_insert_index(%c0) %qureg_in %q : !jeff.qureg + * ``` + * is converted to + * ```mlir + * %tensor_out = qtensor.insert %q into %tensor_in[%c0] : tensor<3x!qco.qubit> + * ``` + */ +struct ConvertJeffQuregInsertIndexOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregInsertIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto index = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getIndex()); + rewriter.replaceOpWithNewOp( + op, adaptor.getInQubit(), adaptor.getInQreg(), index.getResult()); + return success(); + } +}; + +/** + * @brief Converts jeff.qureg_free_zero to qtensor.dealloc + * + * @par Example: + * ```mlir + * jeff.qureg_free_zero %qureg : !jeff.qureg + * ``` + * is converted to + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + */ +struct ConvertJeffQuregFreeZeroOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregFreeZeroOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getQreg()); + return success(); + } +}; + /** * @brief Converts jeff.qubit_alloc to qco.alloc * @@ -893,7 +1000,8 @@ struct ConvertJeffMainToQCO final : OpConversionPattern { * @brief Type converter for Jeff-to-QCO conversion * * @details - * Converts `!jeff.qubit` to `!qco.qubit`. + * Converts `!jeff.qubit` to `!qco.qubit` and `!jeff.qureg` to + * `!tensor`. */ class JeffToQCOTypeConverter final : public TypeConverter { public: @@ -904,6 +1012,11 @@ class JeffToQCOTypeConverter final : public TypeConverter { addConversion([ctx](jeff::QubitType /*type*/) -> Type { return qco::QubitType::get(ctx); }); + + addConversion([ctx](jeff::QuregType /*type*/) -> Type { + return RankedTensorType::get({ShapedType::kDynamic}, + qco::QubitType::get(ctx)); + }); } }; @@ -924,7 +1037,8 @@ struct JeffToQCO final : impl::JeffToQCOBase { // Configure conversion target target.addIllegalDialect(); - target.addLegalDialect(); target.addDynamicallyLegalOp([&](func::FuncOp op) { @@ -936,6 +1050,8 @@ struct JeffToQCO final : impl::JeffToQCOBase { // Register operation conversion patterns jeff::populateJeffToNativeConversionPatterns(patterns); patterns.add< + ConvertJeffQuregAllocOpToQCO, ConvertJeffQuregExtractIndexOpToQCO, + ConvertJeffQuregInsertIndexOpToQCO, ConvertJeffQuregFreeZeroOpToQCO, ConvertJeffQubitAllocOpToQCO, ConvertJeffQubitFreeOpToQCO, ConvertJeffQubitFreeZeroOpToQCO, ConvertJeffQubitMeasureOpToQCO, ConvertJeffQubitMeasureNDOpToQCO, ConvertJeffQubitResetOpToQCO, diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 01ed0f82a8..d28c4f7596 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include @@ -23,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -244,6 +247,129 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) { namespace { +/** + * @brief Converts qtensor.alloc to jeff.qureg_alloc + * + * @par Example: + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg = jeff.qureg_alloc(%c3) : !jeff.qureg + * ``` + */ +struct ConvertQTensorAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto sizeValue = getConstantIntValue(adaptor.getSize()); + Value size; + if (sizeValue.has_value()) { + size = jeff::IntConst32Op::create(rewriter, op.getLoc(), *sizeValue); + } else { + size = adaptor.getSize(); + } + rewriter.replaceOpWithNewOp(op, size); + return success(); + } +}; + +/** + * @brief Converts qtensor.extract to jeff.qureg_extract_index + * + * @par Example: + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg_out, %q = jeff.qureg_extract_index(%c0) %qureg_in : !jeff.qureg, + * !jeff.qubit + * ``` + */ +struct ConvertQTensorExtractOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto indexValue = getConstantIntValue(adaptor.getIndex()); + Value index; + if (indexValue.has_value()) { + index = jeff::IntConst32Op::create(rewriter, op.getLoc(), *indexValue); + } else { + index = adaptor.getIndex(); + } + rewriter.replaceOpWithNewOp( + op, adaptor.getTensor(), index); + return success(); + } +}; + +/** + * @brief Converts qtensor.insert to jeff.qureg_insert_index + * + * @par Example: + * ```mlir + * %tensor_out = qtensor.insert %q into %tensor_in[%c0] : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg_out = jeff.qureg_insert_index(%c0) %qureg_in %q : !jeff.qureg + * ``` + */ +struct ConvertQTensorInsertOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto indexValue = getConstantIntValue(adaptor.getIndex()); + Value index; + if (indexValue.has_value()) { + index = jeff::IntConst32Op::create(rewriter, op.getLoc(), *indexValue); + } else { + index = adaptor.getIndex(); + } + rewriter.replaceOpWithNewOp( + op, adaptor.getDest(), index, adaptor.getScalar()); + return success(); + } +}; + +/** + * @brief Converts qtensor.dealloc to jeff.qureg_free_zero + * + * @par Example: + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * jeff.qureg_free_zero %qureg : !jeff.qureg + * ``` + */ +struct ConvertQTensorDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor()); + return success(); + } +}; + /** * @brief Converts qco.alloc to jeff.qubit_alloc * @@ -1320,7 +1446,8 @@ struct ConvertQCOMainToJeff final : StatefulOpConversionPattern { * @brief Type converter for QCO-to-Jeff conversion * * @details - * Converts `!qco.qubit` to `!jeff.qubit`. + * Converts `!qco.qubit` to `!jeff.qubit` and `tensor` to + * `!jeff.qureg`. */ class QCOToJeffTypeConverter final : public TypeConverter { public: @@ -1331,6 +1458,13 @@ class QCOToJeffTypeConverter final : public TypeConverter { addConversion([ctx](qco::QubitType /*type*/) -> Type { return jeff::QubitType::get(ctx); }); + + addConversion([ctx](RankedTensorType type) -> Type { + if (llvm::isa(type.getElementType())) { + return jeff::QuregType::get(ctx); + } + return type; + }); } }; @@ -1352,7 +1486,8 @@ struct QCOToJeff final : impl::QCOToJeffBase { LoweringState state; // Configure conversion target - target.addIllegalDialect(); target.addLegalDialect(); @@ -1363,9 +1498,10 @@ struct QCOToJeff final : impl::QCOToJeffBase { // Register operation conversion patterns jeff::populateNativeToJeffConversionPatterns(patterns); patterns.add< - ConvertQCOAllocOpToJeff, ConvertQCOSinkOpToJeff, - ConvertQCOMeasureOpToJeff, ConvertQCOResetOpToJeff, - ConvertQCOGPhaseOpToJeff, + ConvertQTensorAllocOp, ConvertQTensorExtractOp, ConvertQTensorInsertOp, + ConvertQTensorDeallocOp, ConvertQCOAllocOpToJeff, + ConvertQCOSinkOpToJeff, ConvertQCOMeasureOpToJeff, + ConvertQCOResetOpToJeff, ConvertQCOGPhaseOpToJeff, ConvertQCOOneTargetZeroParameterToJeff, ConvertQCOOneTargetZeroParameterToJeff, ConvertQCOOneTargetZeroParameterToJeff, diff --git a/mlir/lib/Conversion/QCOToQC/CMakeLists.txt b/mlir/lib/Conversion/QCOToQC/CMakeLists.txt index f7b35a11b3..b3a1b1cf18 100644 --- a/mlir/lib/Conversion/QCOToQC/CMakeLists.txt +++ b/mlir/lib/Conversion/QCOToQC/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_conversion_library( MLIRQTensorDialect MLIRArithDialect MLIRFuncDialect + MLIRMemRefDialect MLIRTransforms MLIRFuncTransforms) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 7c1dc2baca..a9f83c3e88 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -14,13 +14,19 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include +#include #include +#include #include #include #include +#include #include #include #include @@ -94,6 +100,10 @@ namespace { * The primary conversion is from !qco.qubit to !qc.qubit, which * represents the semantic shift from value types to reference types. * + * Qubit tensor types preserve their shape during conversion: a statically + * shaped `tensor` becomes `memref`, while a + * dynamically shaped `tensor` becomes `memref`. + * * Other types (integers, booleans, etc.) pass through unchanged via * the identity conversion. */ @@ -107,6 +117,109 @@ class QCOToQCTypeConverter final : public TypeConverter { addConversion([ctx](qco::QubitType /*type*/) -> Type { return qc::QubitType::get(ctx); }); + + addConversion([ctx](RankedTensorType type) -> Type { + if (llvm::isa(type.getElementType())) { + return MemRefType::get(type.getShape(), qc::QubitType::get(ctx)); + } + return type; + }); + } +}; + +/** + * @brief Converts qtensor.alloc to memref.alloc + * + * @par Example: + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %memref = memref.alloc(%c3) : memref<3x!qc.qubit> + * ``` + */ +struct ConvertQTensorAllocOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::AllocOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto qubitType = qc::QubitType::get(op.getContext()); + auto tensorType = llvm::cast(op.getResult().getType()); + auto memrefType = MemRefType::get(tensorType.getShape(), qubitType); + + if (tensorType.hasStaticShape()) { + // Static size: no dynamic size operand needed + rewriter.replaceOpWithNewOp(op, memrefType); + } else { + // Dynamic size: forward the runtime size operand + rewriter.replaceOpWithNewOp(op, memrefType, + op.getSize()); + } + return success(); + } +}; + +/** + * @brief Converts qtensor.extract to memref.load + * + * @par Example: + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + */ +struct ConvertQTensorExtractOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto load = memref::LoadOp::create(rewriter, op.getLoc(), + adaptor.getTensor(), adaptor.getIndex()); + rewriter.replaceOp(op, {adaptor.getTensor(), load.getResult()}); + return success(); + } +}; + +/** + * @brief Removes qtensor.insert operations + */ +struct ConvertQTensorInsertOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOp(op, adaptor.getDest()); + return success(); + } +}; + +/** + * @brief Converts qtensor.dealloc to memref.dealloc + * + * @par Example: + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + */ +struct ConvertQTensorDeallocOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor()); + return success(); } }; @@ -828,13 +941,14 @@ struct QCOToQC final : impl::QCOToQCBase { RewritePatternSet patterns(context); QCOToQCTypeConverter typeConverter(context); - // Configure conversion target: QCO illegal, QC legal - target.addIllegalDialect(); - target.addLegalDialect(); + // Configure conversion target + target.addIllegalDialect(); + target.addLegalDialect(); // Register operation conversion patterns that do not need state tracking patterns.add< - ConvertQCOMeasureOp, ConvertQCOResetOp, + ConvertQTensorAllocOp, ConvertQTensorExtractOp, ConvertQTensorInsertOp, + ConvertQTensorDeallocOp, ConvertQCOMeasureOp, ConvertQCOResetOp, ConvertQCOZeroTargetOneParameterToQC, ConvertQCOOneTargetZeroParameterToQC, ConvertQCOOneTargetZeroParameterToQC, diff --git a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt index 335d64d3e1..c448b8f765 100644 --- a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt +++ b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_conversion_library( MLIRQTensorDialect MLIRArithDialect MLIRFuncDialect + MLIRMemRefDialect MLIRTransforms MLIRFuncTransforms) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 23976d6760..6323da023c 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -14,16 +14,24 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include #include #include +#include +#include #include #include +#include +#include +#include #include #include #include +#include #include #include #include @@ -43,6 +51,17 @@ using namespace qc; #include "mlir/Conversion/QCToQCO/QCToQCO.h.inc" namespace { + +/** + * @brief Information about a qubit + */ +struct QubitInfo { + /// Register the qubit belongs to + Value reg; + /// Index of the qubit within its register + Value index; +}; + /** * @brief State object for tracking qubit value flow during conversion * @@ -81,13 +100,21 @@ struct LoweringState { llvm::DenseMap currentQubits; }; - /// Per-region map from QC qubit references to latest QCO SSA values. + /// Per-region map from original QC qubit reference to its latest QCO SSA + /// value. /// /// @details Keys are `Operation::getParentRegion()` for ops being converted - /// (typically a `func.func` body, or a `qc.ctrl` / `qc.inv` region). + /// (typically a `func.func` body or a modifier region). llvm::DenseMap> qubitMap; - /// Stack of active modifier regions (`qc.ctrl` / `qc.inv`). + /// Per-region map from original QC register to its latest QTensor SSA value + llvm::DenseMap> tensorMap; + + /// Per-region map from original QC qubit reference to its register + /// information + llvm::DenseMap> qubitInfoMap; + + /// Stack of active modifier regions SmallVector modifierFrames; }; @@ -122,23 +149,6 @@ class StatefulOpConversionPattern : public OpConversionPattern { }; } // namespace -/** - * @brief Helper function to look up the latest QCO qubit value for a given QC - * qubit reference - * - * @param qubitMap The mapping from QC qubits to QCO qubits for the current - * region - * @param qcQubit The QC qubit reference to look up - * @return The latest QCO qubit value corresponding to the given QC qubit - * reference - */ -[[nodiscard]] static Value -lookupMappedQubit(llvm::DenseMap& qubitMap, Value qcQubit) { - auto it = qubitMap.find(qcQubit); - assert(it != qubitMap.end() && "QC qubit not found"); - return it->second; -} - /** @brief Returns whether lowering currently processes a modifier body. */ [[nodiscard]] static bool isInsideModifier(const LoweringState& state) { return !state.modifierFrames.empty(); @@ -151,17 +161,26 @@ currentModifierFrame(LoweringState& state) { return state.modifierFrames.back(); } -/** @brief Finds the nearest region-local qubit map containing @p qcQubit. */ -[[nodiscard]] static llvm::DenseMap* -findMappedQubitMap(LoweringState& state, Operation* anchor, Value qcQubit) { - for (Region* current = anchor->getParentRegion(); current != nullptr; +/** + * @brief Finds the nearest region-local map containing @p reference and + * returns the pair containing the map and a mutable reference to the value in + * the map. + */ +[[nodiscard]] static std::pair*, Value*> +findRegionLocalMap(llvm::DenseMap>& map, + Operation* anchor, Value reference) { + for (auto* current = anchor->getParentRegion(); current != nullptr; current = current->getParentRegion()) { - auto mapIt = state.qubitMap.find(current); - if (mapIt != state.qubitMap.end() && mapIt->second.contains(qcQubit)) { - return &mapIt->second; + if (auto it = map.find(current); it != map.end()) { + auto& regionMap = it->second; + if (auto valueIt = regionMap.find(reference); + valueIt != regionMap.end()) { + return {®ionMap, &valueIt->second}; + } + return {®ionMap, nullptr}; } } - return nullptr; + return {nullptr, nullptr}; } /** @brief Resolves the latest QCO SSA value for a QC qubit reference. */ @@ -175,9 +194,20 @@ findMappedQubitMap(LoweringState& state, Operation* anchor, Value qcQubit) { } } - auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit); - assert(qubitMap != nullptr && "QC qubit not found"); - return lookupMappedQubit(*qubitMap, qcQubit); + const auto& [qubitMap, qubitValue] = + findRegionLocalMap(state.qubitMap, anchor, qcQubit); + assert(qubitMap != nullptr && qubitValue != nullptr && "QC qubit not found"); + return *qubitValue; +} + +/** @brief Resolves the latest QTensor SSA value for a QC register. */ +[[nodiscard]] static Value lookupMappedTensor(LoweringState& state, + Operation* anchor, Value memref) { + const auto& [tensorMap, tensorValue] = + findRegionLocalMap(state.tensorMap, anchor, memref); + assert(tensorMap != nullptr && tensorValue != nullptr && + "QC register not found"); + return *tensorValue; } /** @brief Updates the latest QCO SSA value for a QC qubit reference. */ @@ -192,14 +222,36 @@ static void assignMappedQubit(LoweringState& state, Operation* anchor, } } - if (auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit)) { + auto [qubitMap, qubitValue] = + findRegionLocalMap(state.qubitMap, anchor, qcQubit); + if (qubitValue != nullptr) { + *qubitValue = qcoQubit; + return; + } + if (qubitMap != nullptr) { (*qubitMap)[qcQubit] = qcoQubit; return; } - state.qubitMap[anchor->getParentRegion()][qcQubit] = qcoQubit; } +/** @brief Updates the latest QTensor SSA value for a QC register. */ +static void assignMappedTensor(LoweringState& state, Operation* anchor, + Value memref, Value tensor) { + auto [tensorMap, tensorValue] = + findRegionLocalMap(state.tensorMap, anchor, memref); + + if (tensorValue != nullptr) { + *tensorValue = tensor; + return; + } + if (tensorMap != nullptr) { + (*tensorMap)[memref] = tensor; + return; + } + state.tensorMap[anchor->getParentRegion()][memref] = tensor; +} + /** @brief Resolves a range of QC qubits to their latest QCO values. */ template [[nodiscard]] static SmallVector @@ -271,7 +323,7 @@ struct ConvertFuncReturnOp final : StatefulOpConversionPattern { matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - Region* funcRegion = op->getParentRegion(); + auto* funcRegion = op->getParentRegion(); auto& map = state.qubitMap[funcRegion]; // Build return values from qubitMap and collect live qubit information. @@ -292,7 +344,7 @@ struct ConvertFuncReturnOp final : StatefulOpConversionPattern { } // Deallocate dead qubit values - for (Value qcoQubit : llvm::make_second_range(map)) { + for (auto qcoQubit : llvm::make_second_range(map)) { if (!liveQubits.contains(qcoQubit)) { SinkOp::create(rewriter, op.getLoc(), qcoQubit); } @@ -328,6 +380,175 @@ class QCToQCOTypeConverter final : public TypeConverter { } }; +/** + * @brief Converts memref.alloc to qtensor.alloc + * + * @par Example: + * ```mlir + * %memref = memref.alloc(%c3) : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + */ +struct ConvertMemRefAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + if (!llvm::isa(op.getType().getElementType())) { + return failure(); + } + + auto shape = op.getType().getShape(); + if (shape.size() != 1) { + return failure(); + } + + Value qtensor; + if (shape[0] == ShapedType::kDynamic) { + qtensor = rewriter.replaceOpWithNewOp( + op, adaptor.getDynamicSizes()[0]); + } else { + auto size = + arith::ConstantIndexOp::create(rewriter, op.getLoc(), shape[0]); + qtensor = + rewriter.replaceOpWithNewOp(op, size.getResult()); + } + + auto& state = getState(); + auto memref = op.getResult(); + assignMappedTensor(state, qtensor.getDefiningOp(), memref, qtensor); + + return success(); + } +}; + +/** + * @brief Converts memref.load to qtensor.extract + * + * @par Example: + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + */ +struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto memref = op.getMemref(); + if (!llvm::isa(memref.getType().getElementType())) { + return failure(); + } + + auto& state = getState(); + auto& qubitInfoMap = state.qubitInfoMap; + auto* operation = op.getOperation(); + + // Look up latest QTensor value for this QC register + auto qtensor = lookupMappedTensor(state, operation, memref); + + auto index = adaptor.getIndices()[0]; + auto extract = + qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, index); + + auto qcQubit = op.getResult(); + auto qcoQubit = extract.getResult(); + + assignMappedQubit(state, operation, qcQubit, qcoQubit); + assignMappedTensor(state, operation, memref, extract.getOutTensor()); + + QubitInfo info{.reg = memref, .index = index}; + auto* parentRegion = operation->getParentRegion(); + if (auto it = qubitInfoMap.find(parentRegion); it != qubitInfoMap.end()) { + it->second[qcQubit] = info; + } else { + qubitInfoMap[parentRegion][qcQubit] = info; + } + + rewriter.eraseOp(op); + + return success(); + } +}; + +/** + * @brief Converts memref.dealloc to qtensor.dealloc + * + * @details + * Before deallocating the tensor, all qubits are inserted back into it at their + * original location. + * + * @par Example: + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %t1 = qtensor.insert %q0 into %t0[%c0] : tensor<3x!qco.qubit> + * %t2 = qtensor.insert %q1 into %t1[%c1] : tensor<3x!qco.qubit> + * %t3 = qtensor.insert %q2 into %t2[%c2] : tensor<3x!qco.qubit> + * qtensor.dealloc %t3 : tensor<3x!qco.qubit> + * ``` + */ +struct ConvertMemRefDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto memref = op.getMemref(); + if (!llvm::isa(memref.getType().getElementType())) { + return failure(); + } + + auto& state = getState(); + auto& qubitMap = state.qubitMap[op->getParentRegion()]; + auto& tensorMap = state.tensorMap[op->getParentRegion()]; + auto& qubitInfoMap = state.qubitInfoMap[op->getParentRegion()]; + + // Look up latest QTensor value for this QC register + auto qtensor = lookupMappedTensor(state, op.getOperation(), memref); + + // Filter out qubits belonging to this tensor + for (auto it = qubitMap.begin(); it != qubitMap.end();) { + auto current = it++; + auto qcQubit = current->first; + auto qcoQubit = current->second; + + auto infoIt = qubitInfoMap.find(qcQubit); + if (infoIt == qubitInfoMap.end()) { + continue; + } + + auto& [reg, index] = infoIt->second; + if (reg != memref) { + continue; + } + + qtensor = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, + qtensor, index) + .getResult(); + qubitMap.erase(current); + qubitInfoMap.erase(infoIt); + } + tensorMap.erase(memref); + + rewriter.replaceOpWithNewOp(op, qtensor); + return success(); + } +}; + /** * @brief Converts qc.alloc to qco.alloc * @@ -393,11 +614,11 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { matchAndRewrite(qc::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); + auto& qubitMap = state.qubitMap[op->getParentRegion()]; auto* operation = op.getOperation(); - auto* region = operation->getParentRegion(); - auto& qubitMap = state.qubitMap[region]; - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the sink operation rewriter.replaceOpWithNewOp(op, qcoQubit); @@ -431,11 +652,10 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { matchAndRewrite(qc::StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); + auto qcQubit = op.getQubit(); auto qcoOp = rewriter.replaceOpWithNewOp(op, op.getIndex()); - assignMappedQubit(state, operation, qcQubit, qcoOp.getQubit()); + assignMappedQubit(state, qcoOp, qcQubit, qcoOp.getQubit()); return success(); } @@ -472,8 +692,8 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.measure (returns both output qubit and bit result) auto qcoOp = qco::MeasureOp::create( @@ -518,8 +738,8 @@ struct ConvertQCResetOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.reset (consumes input, produces output) auto qcoOp = qco::ResetOp::create(rewriter, op.getLoc(), qcoQubit); @@ -590,8 +810,8 @@ struct ConvertQCOneTargetZeroParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit); @@ -629,8 +849,8 @@ struct ConvertQCOneTargetOneParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = @@ -669,8 +889,8 @@ struct ConvertQCOneTargetTwoParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit, @@ -709,8 +929,8 @@ struct ConvertQCOneTargetThreeParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = @@ -751,10 +971,10 @@ struct ConvertQCTwoTargetZeroParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1); @@ -794,10 +1014,10 @@ struct ConvertQCTwoTargetOneParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, @@ -838,10 +1058,10 @@ struct ConvertQCTwoTargetTwoParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, @@ -877,7 +1097,7 @@ struct ConvertQCBarrierOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto qcQubits = llvm::to_vector(op.getQubits()); + auto qcQubits = op.getQubits(); auto qcoQubits = resolveMappedQubits(state, operation, qcQubits); // Create qco.barrier @@ -1055,12 +1275,23 @@ struct QCToQCO final : impl::QCToQCOBase { RewritePatternSet patterns(context); QCToQCOTypeConverter typeConverter(context); - // Configure conversion target: QC illegal, QCO legal + // Configure conversion target target.addIllegalDialect(); - target.addLegalDialect(); + target.addLegalDialect(); + + target.addDynamicallyLegalDialect([](Operation* op) { + auto isQubitMemref = [](Type t) { + auto mt = llvm::dyn_cast(t); + return mt && llvm::isa(mt.getElementType()); + }; + return llvm::none_of(op->getOperandTypes(), isQubitMemref) && + llvm::none_of(op->getResultTypes(), isQubitMemref); + }); // Register operation conversion patterns with state tracking patterns.add< + ConvertMemRefAllocOp, ConvertMemRefLoadOp, ConvertMemRefDeallocOp, ConvertQCAllocOp, ConvertQCDeallocOp, ConvertQCStaticOp, ConvertQCMeasureOp, ConvertQCResetOp, ConvertQCZeroTargetOneParameterToQCO, diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index cd1d2721b7..03caf30f02 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -18,7 +18,10 @@ #include #include #include +#include +#include #include +#include #include #include #include @@ -30,13 +33,13 @@ #include #include #include +#include #include -#include +#include #include #include #include #include -#include #include #include #include @@ -45,9 +48,9 @@ #include #include +#include #include #include -#include #include #include @@ -63,29 +66,35 @@ namespace { /** * @brief State object for tracking lowering information during QIR conversion - * - * @details - * This struct maintains state during the conversion of QC dialect - * operations to QIR (Quantum Intermediate Representation). It tracks: - * - Qubit and result counts for QIR metadata - * - Pointer value caching for reuse - * - Whether dynamic memory management is needed - * - Sequence of measurements for output recording */ struct LoweringState : QIRMetadata { - /// Map from register name to register start index - DenseMap registerStartIndexMap; + /// Cache static qubit pointers for reuse + DenseMap staticQubits; + + /// Cache MemRef sizes for reuse + DenseMap memrefSizes; - /// Map from index to pointer value for reuse - DenseMap ptrMap; + /// Map from register name to result-array pointer + llvm::StringMap resultArrays; - /// Map from (register_name, register_index) to result pointer - /// This allows caching result pointers for measurements with register info - DenseMap, Value> registerResultMap; + /// Map from (register name, index) to loaded result + llvm::DenseMap, Value> loadedResults; + + /// Map from index to result pointer for non-register results + DenseMap resultPtrs; /// Modifier information int64_t inCtrlOp = 0; DenseMap> controls; + + /// Allocator and StringSaver for stable StringRefs + llvm::BumpPtrAllocator allocator; + llvm::StringSaver stringSaver{allocator}; + + /// Block information + Block* entryBlock{}; + Block* measurementsBlock{}; + Block* outputBlock{}; }; /** @@ -195,25 +204,183 @@ namespace { * * Type conversions: * - `!qc.qubit` -> `!llvm.ptr` (opaque pointer to qubit in QIR) + * - `memref` -> `!llvm.ptr` (opaque pointer to array in QIR) */ struct QCToQIRTypeConverter final : LLVMTypeConverter { explicit QCToQIRTypeConverter(MLIRContext* ctx) : LLVMTypeConverter(ctx) { // Convert QubitType to LLVM pointer (QIR uses opaque pointers for qubits) addConversion( [ctx](QubitType /*type*/) { return LLVM::LLVMPointerType::get(ctx); }); + + addConversion([ctx](MemRefType type) -> Type { + if (llvm::isa(type.getElementType())) { + return LLVM::LLVMPointerType::get(ctx); + } + return type; + }); } }; /** - * @brief Converts qc.alloc operation to static QIR qubit allocations + * @brief Converts memref.alloc to QIR qubit-array allocation * - * @details - * QIR 2.0 does not support dynamic qubit allocation. Therefore, qc.alloc - * operations are converted to static qubit references using inttoptr with a - * constant index. + * @par Example: + * ```mlir + * %memref = memref.alloc() : memref<3x!qc.qubit> + * ``` + * becomes: + * ```mlir + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c3 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__qubit_array_allocate"(%c3, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * ``` + */ +struct ConvertMemRefAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto shape = op.getType().getShape(); + if (shape.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); + } + + auto& state = getState(); + state.useDynamicQubit = true; + + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto fnSig = + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {rewriter.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, + QIR_QUBIT_ARRAY_ALLOC, fnSig); + + Value size; + if (shape[0] == ShapedType::kDynamic) { + size = adaptor.getDynamicSizes()[0]; + } else { + size = LLVM::ConstantOp::create( + rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(static_cast(shape[0]))) + .getResult(); + } + state.memrefSizes.try_emplace(op.getMemref(), size); + + auto array = + LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrType, ptrType, size); + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{size, array.getResult(), zero.getResult()}); + + rewriter.replaceOp(op, array.getResult()); + + return success(); + } +}; + +/** + * @brief Converts memref.load to llvm.load * - * Register metadata (register_name, register_size, register_index) is used to - * provide a reasonable guess for a static qubit index that is still free. + * @par Example: + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %ptr = llvm.getelementptr %alloca[1] : !llvm.ptr -> !llvm.ptr, !llvm.ptr + * %q = llvm.load %ptr : !llvm.ptr -> !llvm.ptr + * ``` + */ +struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto shape = op.getMemref().getType().getShape(); + if (shape.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); + } + + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto array = adaptor.getMemref(); + auto index = adaptor.getIndices()[0]; + auto gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrType, ptrType, + array, index); + auto load = + LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, gep.getResult()); + + rewriter.replaceOp(op, load.getResult()); + + return success(); + } +}; + +/** + * @brief Converts memref.dealloc to QIR qubit-array release + * + * @par Example: + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + * becomes: + * ```mlir + * llvm.call @"@__quantum__rt__qubit_array_release"(%c3, %alloca) : (i64, + * !llvm.ptr) -> () + * ``` + */ +struct ConvertMemRefDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto shape = op.getMemref().getType().getShape(); + if (shape.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); + } + + auto& state = getState(); + auto* ctx = getContext(); + auto i64Type = rewriter.getI64Type(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Release resources in output block + rewriter.setInsertionPoint(state.outputBlock->getTerminator()); + + auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {i64Type, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, + QIR_QUBIT_ARRAY_RELEASE, fnSig); + + auto size = state.memrefSizes.lookup(op.getMemref()); + assert(size != nullptr && "Size not found"); + + // Create the release call + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{size, adaptor.getMemref()}); + rewriter.eraseOp(op); + + return success(); + } +}; + +/** + * @brief Converts qc.alloc to QIR qubit allocation * * @par Example: * ```mlir @@ -221,81 +388,36 @@ struct QCToQIRTypeConverter final : LLVMTypeConverter { * ``` * becomes: * ```mlir - * %c0 = llvm.mlir.constant(0 : i64) : i64 - * %q0 = llvm.inttoptr %c0 : i64 to !llvm.ptr + * %zero = llvm.mlir.zero : !llvm.ptr + * %q = llvm.call @"@__quantum__rt__qubit_allocate"(%zero) : !llvm.ptr -> + * !llvm.ptr * ``` */ -struct ConvertQCAllocQIR final : StatefulOpConversionPattern { +struct ConvertQCAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - const auto numQubits = static_cast(state.numQubits); - auto& ptrMap = state.ptrMap; - auto& registerMap = state.registerStartIndexMap; + state.useDynamicQubit = true; - // Get or create pointer value - if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) { - const auto registerName = op.getRegisterName().value(); - const auto registerSize = - static_cast(op.getRegisterSize().value()); - const auto registerIndex = - static_cast(op.getRegisterIndex().value()); + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); - if (const auto it = registerMap.find(registerName); - it != registerMap.end()) { - // Register is already tracked - // The pointer was created by the step below - const auto globalIndex = it->second + registerIndex; - if (!ptrMap.contains(globalIndex)) { - return op.emitError("Pointer not found"); - } - rewriter.replaceOp(op, ptrMap.at(globalIndex)); - return success(); - } + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_ALLOC, fnSig); - // Allocate the entire register as static qubits - registerMap[registerName] = numQubits; - SmallVector pointers; - pointers.reserve(registerSize); - for (int64_t i = 0; i < registerSize; ++i) { - Value val{}; - if (const auto it = ptrMap.find(numQubits + i); it != ptrMap.end()) { - val = it->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numQubits + i); - ptrMap[numQubits + i] = val; - } - pointers.push_back(val); - } - rewriter.replaceOp(op, pointers[registerIndex]); - state.numQubits += registerSize; - return success(); - } + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + rewriter.replaceOpWithNewOp(op, fnDec, zero.getResult()); - // no register info, check if ptr has already been allocated (as a Result) - Value val{}; - if (const auto it = ptrMap.find(numQubits); it != ptrMap.end()) { - val = it->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numQubits); - ptrMap[numQubits] = val; - } - rewriter.replaceOp(op, val); - state.numQubits++; return success(); } }; /** - * @brief Erases qc.dealloc operations - * - * @details - * Since QIR 2.0 does not support dynamic qubit allocation, dynamic allocations - * are converted to static allocations. Therefore, deallocation operations - * become no-ops and are simply removed from the IR. + * @brief Converts qc.dealloc to QIR qubit release * * @par Example: * ```mlir @@ -303,22 +425,39 @@ struct ConvertQCAllocQIR final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * // (removed) + * llvm.call @"@__quantum__rt__qubit_release"(%q) : !llvm.ptr -> () * ``` */ -struct ConvertQCDeallocQIR final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertQCDeallocOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult - matchAndRewrite(DeallocOp op, OpAdaptor /*adaptor*/, + matchAndRewrite(DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Release resources in output block + rewriter.setInsertionPoint(state.outputBlock->getTerminator()); + + auto fnSig = + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_RELEASE, fnSig); + + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, adaptor.getQubit()); rewriter.eraseOp(op); + return success(); } }; /** - * @brief Converts qc.static operation to QIR inttoptr + * @brief Converts qc.static to llvm.inttoptr * * @details * Converts a static qubit reference to an LLVM pointer by creating a constant @@ -335,7 +474,7 @@ struct ConvertQCDeallocQIR final : OpConversionPattern { * %q0 = llvm.inttoptr %c0 : i64 to !llvm.ptr * ``` */ -struct ConvertQCStaticQIR final : StatefulOpConversionPattern { +struct ConvertQCStaticOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -343,17 +482,19 @@ struct ConvertQCStaticQIR final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { const auto index = static_cast(op.getIndex()); auto& state = getState(); + // Get or create a pointer to the qubit - Value val{}; - if (const auto it = state.ptrMap.find(index); it != state.ptrMap.end()) { + Value qubit; + if (const auto it = state.staticQubits.find(index); + it != state.staticQubits.end()) { // Reuse existing pointer - val = it->second; + qubit = it->second; } else { // Create and cache for reuse - val = createPointerFromIndex(rewriter, op.getLoc(), index); - state.ptrMap.try_emplace(index, val); + qubit = createPointerFromIndex(rewriter, op.getLoc(), index); + state.staticQubits.try_emplace(index, qubit); } - rewriter.replaceOp(op, val); + rewriter.replaceOp(op, qubit); // Track maximum qubit index if (std::cmp_greater_equal(index, state.numQubits)) { @@ -365,17 +506,14 @@ struct ConvertQCStaticQIR final : StatefulOpConversionPattern { }; /** - * @brief Converts qc.measure operation to QIR measurement + * @brief Converts qc.measure to QIR measurement * * @details - * Converts qubit measurement to a QIR call to `__quantum__qis__mz__body`. - * Unlike the previous implementation, this does NOT immediately record output. - * Instead, it tracks measurements in the lowering state for deferred output - * recording in a separate output block, as required by the QIR Base Profile. + * For measurements with register information, a result array is allocated and + * all result pointers are loaded. * - * For measurements with register information, the result pointer is mapped - * to (register_name, register_index) for later retrieval. For measurements - * without register information, a sequential result pointer is assigned. + * For measurements without register information, an individual result pointer + * is allocated. * * @par Example (with register): * ```mlir @@ -383,83 +521,108 @@ struct ConvertQCStaticQIR final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * %c0_i64 = llvm.mlir.constant(0 : i64) : i64 - * %result_ptr = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr - * llvm.call @__quantum__qis__mz__body(%q, %result_ptr) : (!llvm.ptr, !llvm.ptr) - * -> () + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c2 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__result_array_allocate"(%c2, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * %r = llvm.load %alloca : !llvm.ptr -> !llvm.ptr + * + * // In measurements block: + * llvm.call @__quantum__qis__mz__body(%q, %r) : (!llvm.ptr, !llvm.ptr) -> () * ``` */ -struct ConvertQCMeasureQIR final : StatefulOpConversionPattern { +struct ConvertQCMeasureOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(MeasureOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto* ctx = getContext(); - const auto ptrType = LLVM::LLVMPointerType::get(ctx); auto& state = getState(); - const auto numResults = static_cast(state.numResults); - auto& ptrMap = state.ptrMap; - auto& registerResultMap = state.registerResultMap; + state.useDynamicResult = true; - // Get or create result pointer value - Value resultValue; + auto& resultArrays = state.resultArrays; + auto& loadedResults = state.loadedResults; + auto& resultPtrs = state.resultPtrs; + + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Insert allocations and constants in entry block + rewriter.setInsertionPoint(state.entryBlock->getTerminator()); + + // Get result pointer + Value result; if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) { const auto registerName = op.getRegisterName().value(); const auto registerSize = static_cast(op.getRegisterSize().value()); const auto registerIndex = static_cast(op.getRegisterIndex().value()); - const auto key = std::make_pair(registerName, registerIndex); - if (const auto it = registerResultMap.find(key); - it != registerResultMap.end()) { - resultValue = it->second; - } else { - // Allocate the entire register as static results + // Create result register if it does not exist yet + if (!resultArrays.contains(registerName)) { + auto fnSig = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), + {rewriter.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration( + rewriter, op, QIR_RESULT_ARRAY_ALLOC, fnSig); + + auto size = + LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(registerSize)) + .getResult(); + auto array = LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrType, + ptrType, size); + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + LLVM::CallOp::create( + rewriter, op.getLoc(), fnDec, + ValueRange{size, array.getResult(), zero.getResult()}); + resultArrays.try_emplace(registerName, array.getResult()); + for (int64_t i = 0; i < registerSize; ++i) { - Value val{}; - if (const auto ptrIt = ptrMap.find(numResults + i); - ptrIt != ptrMap.end()) { - val = ptrIt->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numResults + i); - ptrMap[numResults + i] = val; - } - registerResultMap.try_emplace({registerName, i}, val); + auto gep = LLVM::GEPOp::create( + rewriter, op.getLoc(), ptrType, ptrType, array.getResult(), + ValueRange{LLVM::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(i))}); + auto load = LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, + gep.getResult()); + loadedResults.try_emplace({state.stringSaver.save(registerName), i}, + load.getResult()); } - state.numResults += registerSize; - resultValue = registerResultMap.at(key); } + + result = loadedResults.at({registerName, registerIndex}); } else { - // Choose a safe default register name - StringRef defaultRegName = "c"; - if (llvm::any_of(registerResultMap, [](const auto& entry) { - return entry.first.first == "c"; - })) { - defaultRegName = "__unnamed__"; - } - // No register info, check if ptr has already been allocated (as a Qubit) - if (const auto it = ptrMap.find(numResults); it != ptrMap.end()) { - resultValue = it->second; - } else { - resultValue = createPointerFromIndex(rewriter, op.getLoc(), numResults); - ptrMap[numResults] = resultValue; - } - registerResultMap.insert({{defaultRegName, numResults}, resultValue}); - state.numResults++; + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_RESULT_ALLOC, fnSig); + + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + result = + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, zero.getResult()) + .getResult(); + + resultPtrs.try_emplace(resultPtrs.size(), result); } - // Declare QIR function - const auto fnSignature = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {ptrType, ptrType}); - const auto fnDecl = - getOrCreateFunctionDeclaration(rewriter, op, QIR_MEASURE, fnSignature); + // Switch to measurements block + rewriter.setInsertionPoint(state.measurementsBlock->getTerminator()); + + // Create measure call + auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_MEASURE, fnSig); + + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{adaptor.getQubit(), result}); + + rewriter.replaceOp(op, result); - // Create CallOp and replace qc.measure with result pointer - LLVM::CallOp::create(rewriter, op.getLoc(), fnDecl, - ValueRange{adaptor.getQubit(), resultValue}); - rewriter.replaceOp(op, resultValue); return success(); } }; @@ -477,17 +640,24 @@ struct ConvertQCMeasureQIR final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * llvm.call @__quantum__qis__reset__body(%q) : (!llvm.ptr) -> () + * llvm.call @__quantum__qis__reset__body(%q) : !llvm.ptr -> () * ``` */ -struct ConvertQCResetQIR final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertQCResetOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(ResetOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); auto* ctx = getContext(); + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Switch to measurements block + rewriter.setInsertionPoint(state.measurementsBlock->getTerminator()); + // Declare QIR function const auto fnSignature = LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(ctx), LLVM::LLVMPointerType::get(ctx)); @@ -515,7 +685,7 @@ struct ConvertQCResetQIR final : OpConversionPattern { * llvm.call @__quantum__qis__gphase__body(%theta) : (f64) -> () * ``` */ -struct ConvertQCGPhaseOpQIR final : StatefulOpConversionPattern { +struct ConvertQCGPhaseOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -543,11 +713,10 @@ struct ConvertQCGPhaseOpQIR final : StatefulOpConversionPattern { * ``` \ * is converted to \ * ```mlir \ - * llvm.call @__quantum__qis__QIR_NAME__body(%q) : (!llvm.ptr) -> () \ + * llvm.call @__quantum__qis__QIR_NAME__body(%q) : !llvm.ptr -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -594,8 +763,7 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, SXDG, sxdg, sxdg) * -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -635,8 +803,7 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, P, p, p, theta) * (!llvm.ptr, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -674,8 +841,7 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, U2, u2, u2, phi, lambda) * : (!llvm.ptr, f64, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -712,8 +878,7 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, U, u, u3) * !llvm.ptr) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -753,8 +918,7 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ECR, ecr, ecr) * (!llvm.ptr, !llvm.ptr, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -795,8 +959,7 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, RZZ, rzz, rzz, theta) * (!llvm.ptr, !llvm.ptr, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -824,7 +987,7 @@ DEFINE_TWO_TARGET_TWO_PARAMETER(XXMinusYYOp, XXMINUSYY, xx_minus_yy, /** * @brief Erases qc.barrier operation, as it is a no-op in QIR */ -struct ConvertQCBarrierQIR final : StatefulOpConversionPattern { +struct ConvertQCBarrierOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -838,7 +1001,7 @@ struct ConvertQCBarrierQIR final : StatefulOpConversionPattern { /** * @brief Inlines qc.ctrl region removes the operation */ -struct ConvertQCCtrlQIR final : StatefulOpConversionPattern { +struct ConvertQCCtrlOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -862,7 +1025,7 @@ struct ConvertQCCtrlQIR final : StatefulOpConversionPattern { /** * @brief Erases qc.yield operation */ -struct ConvertQCYieldQIR final : StatefulOpConversionPattern { +struct ConvertQCYieldOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -883,17 +1046,17 @@ struct ConvertQCYieldQIR final : StatefulOpConversionPattern { * * Conversion stages: * 1. Convert func dialect to LLVM - * 2. Ensure proper block structure for QIR base profile and add - * initialization - * 3. Convert QC operations to QIR calls - * 4. Set QIR metadata attributes - * 5. Convert arith and cf dialects to LLVM - * 6. Reconcile unrealized casts + * 2. Ensure proper block structure for QIR base profile + * 3. Add QIR initialization call + * 4. Convert QC operations to QIR calls + * 5. Set QIR metadata attributes + * 6. Convert arith and cf dialects to LLVM + * 7. Reconcile unrealized casts * * @pre * The input entry function must consist of a single block. The pass will - * restructure it into four blocks. Multi-block input functions are currently - * not supported. + * restructure it into four blocks. Multi-block input functions are + * currently not supported. */ struct QCToQIR final : impl::QCToQIRBase { using QCToQIRBase::QCToQIRBase; @@ -905,21 +1068,19 @@ struct QCToQIR final : impl::QCToQIRBase { * The QIR base profile requires a specific 4-block structure: * 1. **Entry block**: Contains constant operations and initialization * 2. **Body block**: Contains reversible quantum operations (gates) - * 3. **Measurements block**: Contains irreversible operations (measure, - * reset, dealloc) + * 3. **Measurements block**: Contains irreversible operations (measure and + * reset) * 4. **Output block**: Contains output recording calls * * Blocks are connected with unconditional jumps (entry, body, measurements, * output). This structure ensures proper QIR Base Profile semantics. * - * If the function already has multiple blocks, this function does nothing. - * * @param main The main LLVM function to restructure */ - static void ensureBlocks(LLVM::LLVMFuncOp& main) { - // Return if there are already multiple blocks + static void ensureBlocks(LLVM::LLVMFuncOp& main, LoweringState& state) { if (main.getBlocks().size() > 1) { - return; + llvm::reportFatalInternalError( + "Modules with multiple blocks are not supported yet"); } // Get the existing block @@ -934,24 +1095,23 @@ struct QCToQIR final : impl::QCToQIRBase { Block* measurementsBlock = builder.createBlock(&main.getBody()); Block* outputBlock = builder.createBlock(&main.getBody()); + state.entryBlock = entryBlock; + state.measurementsBlock = measurementsBlock; + state.outputBlock = outputBlock; + auto& bodyBlockOps = bodyBlock->getOperations(); auto& outputBlockOps = outputBlock->getOperations(); - auto& measurementsBlockOps = measurementsBlock->getOperations(); // Move operations to appropriate blocks for (auto it = bodyBlock->begin(); it != bodyBlock->end();) { // Ensure iterator remains valid after potential move - if (auto& op = *it++; - isa(op) || isa(op) || isa(op)) { - // Move irreversible quantum operations to measurements block - measurementsBlockOps.splice(measurementsBlock->end(), bodyBlockOps, - Block::iterator(op)); - } else if (isa(op)) { + if (auto& op = *it++; isa(op)) { // Move return to output block outputBlockOps.splice(outputBlock->end(), bodyBlockOps, Block::iterator(op)); - } else if (op.hasTrait()) { - // Move constant like operations to the entry block + } else if (isa(op) || isa(op) || + isa(op) || op.hasTrait()) { + // Move allocations and constant-like operations to entry block entryBlock->getOperations().splice(entryBlock->end(), bodyBlockOps, Block::iterator(op)); } @@ -973,44 +1133,25 @@ struct QCToQIR final : impl::QCToQIRBase { * @brief Adds QIR initialization call to the entry block * * @details - * Inserts a call to `__quantum__rt__initialize` at the end of the entry - * block (before the jump to main block). This QIR runtime function - * initializes the quantum execution environment and takes a null pointer as - * argument. + * This QIR runtime function initializes the quantum execution environment. * * @param main The main LLVM function * @param ctx The MLIR context + * @param state The lowering state */ - static void addInitialize(LLVM::LLVMFuncOp& main, MLIRContext* ctx) { - auto moduleOp = main->getParentOfType(); - auto& firstBlock = *(main.getBlocks().begin()); - OpBuilder builder(main.getBody()); + static void addInitialize(LLVM::LLVMFuncOp& main, MLIRContext* ctx, + LoweringState& state) { + OpBuilder builder(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); - // Create a zero (null) pointer for the initialize call - builder.setInsertionPointToStart(&firstBlock); - auto zeroOp = LLVM::ZeroOp::create(builder, main->getLoc(), - LLVM::LLVMPointerType::get(ctx)); - - // Insert the initialize call before the jump to main block - const auto insertPoint = std::prev(firstBlock.getOperations().end(), 1); - builder.setInsertionPoint(&*insertPoint); - - // Get or create the initialize function declaration - auto* fnDecl = SymbolTable::lookupNearestSymbolFrom( - main, builder.getStringAttr(QIR_INITIALIZE)); - if (fnDecl == nullptr) { - const PatternRewriter::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(moduleOp.getBody()); - auto fnSignature = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), LLVM::LLVMPointerType::get(ctx)); - fnDecl = LLVM::LLVMFuncOp::create(builder, main->getLoc(), QIR_INITIALIZE, - fnSignature); - } + builder.setInsertionPointToStart(state.entryBlock); - // Create the initialization call - LLVM::CallOp::create(builder, main->getLoc(), - cast(fnDecl), - ValueRange{zeroOp->getResult(0)}); + auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); + auto initDec = + getOrCreateFunctionDeclaration(builder, main, QIR_INITIALIZE, initSig); + auto zero = LLVM::ZeroOp::create(builder, main->getLoc(), ptrType); + LLVM::CallOp::create(builder, main->getLoc(), initDec, zero.getResult()); } /** @@ -1021,29 +1162,12 @@ struct QCToQIR final : impl::QCToQIRBase { * measurements tracked during conversion. Follows the QIR Base Profile * specification for labeled output schema. * - * For each classical register, creates: - * 1. An array_record_output call with the register size and label - * 2. Individual result_record_output calls for each measurement in the - * register - * - * Labels follow the format: "{registerName}{resultIndex}r" - * - registerName: Name of the classical register (e.g., "c") - * - resultIndex: Index within the array - * - 'r' suffix: Indicates this is a result record + * Results that are part of registers are recorded via + * `__quantum__rt__result_array_record_output`. * - * Example output: - * ``` - * @0 = internal constant [3 x i8] c"c\00" - * @1 = internal constant [5 x i8] c"c0r\00" - * @2 = internal constant [5 x i8] c"c1r\00" - * call void @__quantum__rt__array_record_output(i64 2, ptr @0) - * call void @__quantum__rt__result_record_output(ptr %result0, ptr @1) - * call void @__quantum__rt__result_record_output(ptr %result1, ptr @2) - * ``` - * - * Any output recording calls that are not part of registers (i.e., - * measurements without register info) are grouped under a default label - * "c" and recorded similarly. + * Results that are not part of registers (i.e., measurements without register + * info) are grouped under a default `__unnamed__` label recorded via + * `__quantum__rt__result_record_output`. * * @param main The main LLVM function * @param ctx The MLIR context @@ -1051,12 +1175,16 @@ struct QCToQIR final : impl::QCToQIRBase { */ static void addOutputRecording(LLVM::LLVMFuncOp& main, MLIRContext* ctx, LoweringState* state) { - if (state->registerResultMap.empty()) { + auto& resultArrays = state->resultArrays; + auto& resultPtrs = state->resultPtrs; + + if (resultArrays.empty() && resultPtrs.empty()) { return; // No measurements to record } OpBuilder builder(ctx); - const auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); // Find the output block auto& outputBlock = main.getBlocks().back(); @@ -1064,94 +1192,118 @@ struct QCToQIR final : impl::QCToQIRBase { // Insert before the branch to output block builder.setInsertionPoint(&outputBlock.back()); - // Group measurements by register - llvm::StringMap>> registerGroups; - for (const auto& [key, resultPtr] : state->registerResultMap) { - const auto& [registerName, registerIndex] = key; - registerGroups[registerName].emplace_back(registerIndex, resultPtr); - } + if (!resultPtrs.empty()) { + // Sort result pointers for deterministic output + llvm::SmallVector> sortedPtrs; + for (const auto& [index, resultPtr] : resultPtrs) { + sortedPtrs.emplace_back(index, resultPtr); + } + llvm::sort(sortedPtrs, [](const auto& a, const auto& b) { + return a.first < b.first; + }); - // Sort registers by name for deterministic output - SmallVector>>> - sortedRegisters; - for (auto& [name, measurements] : registerGroups) { - sortedRegisters.emplace_back(name, std::move(measurements)); + // Create output recording for each result pointer + auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(builder, main, + QIR_RECORD_OUTPUT, fnSig); + + for (const auto& [index, ptr] : sortedPtrs) { + auto label = createResultLabel(builder, main, + "__unnamed__" + std::to_string(index)) + .getResult(); + LLVM::CallOp::create(builder, main->getLoc(), fnDec, + ValueRange{ptr, label}); + } } - llvm::sort(sortedRegisters, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - // create function declarations for output recording - const auto arrayRecordSig = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {builder.getI64Type(), ptrType}); - const auto arrayRecordDecl = getOrCreateFunctionDeclaration( - builder, main, QIR_ARRAY_RECORD_OUTPUT, arrayRecordSig); - - const auto resultRecordSig = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {ptrType, ptrType}); - const auto resultRecordDecl = getOrCreateFunctionDeclaration( - builder, main, QIR_RECORD_OUTPUT, resultRecordSig); - - // Generate output recording for each register - for (auto& [registerName, measurements] : sortedRegisters) { - // Sort measurements by register index - llvm::sort(measurements, [](const auto& a, const auto& b) { + + if (!resultArrays.empty()) { + // Sort registers by name for deterministic output + SmallVector> sortedRegisters; + for (auto& [name, results] : resultArrays) { + sortedRegisters.emplace_back(name, results); + } + llvm::sort(sortedRegisters, [](const auto& a, const auto& b) { return a.first < b.first; }); - const auto arraySize = measurements.size(); - auto arrayLabelOp = createResultLabel(builder, main, registerName); - auto arraySizeConst = LLVM::ConstantOp::create( - builder, main->getLoc(), - builder.getI64IntegerAttr(static_cast(arraySize))); - - LLVM::CallOp::create( - builder, main->getLoc(), arrayRecordDecl, - ValueRange{arraySizeConst.getResult(), arrayLabelOp.getResult()}); - - // Create result_record_output calls for each measurement - for (auto [regIdx, resultPtr] : measurements) { - // Create label for result: "{arrayCounter+1+i}_{registerName}{i}r" - const std::string resultLabel = - registerName.str() + std::to_string(regIdx) + "r"; - auto resultLabelOp = createResultLabel(builder, main, resultLabel); - - LLVM::CallOp::create(builder, main->getLoc(), resultRecordDecl, - ValueRange{resultPtr, resultLabelOp.getResult()}); + auto fnSig = LLVM::LLVMFunctionType::get( + voidType, {builder.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration( + builder, main, QIR_ARRAY_RECORD_OUTPUT, fnSig); + + // Generate output recording for each register + for (auto& [name, results] : sortedRegisters) { + auto size = results.getDefiningOp().getArraySize(); + auto label = createResultLabel(builder, main, name).getResult(); + + LLVM::CallOp::create(builder, main->getLoc(), fnDec, + ValueRange{size, results, label}); } } } + /** + * @brief Iterates through all result pointers and releases them + */ + static void releaseResults(LLVM::LLVMFuncOp& main, MLIRContext* ctx, + LoweringState* state) { + OpBuilder builder(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); + + // Release resources in output block + builder.setInsertionPoint(state->outputBlock->getTerminator()); + + for (auto& [_, ptr] : state->resultPtrs) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); + auto dec = getOrCreateFunctionDeclaration(builder, main, + QIR_RESULT_RELEASE, sig); + LLVM::CallOp::create(builder, main->getLoc(), dec, ptr); + } + + for (auto& [_, array] : state->resultArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, + {builder.getI64Type(), ptrType}); + auto dec = getOrCreateFunctionDeclaration(builder, main, + QIR_RESULT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(builder, main->getLoc(), dec, + ValueRange{size, array}); + } + } + protected: /** * @brief Executes the QC to QIR conversion pass * * @details - * Performs the conversion in six stages: + * Performs the conversion in seven stages: * * **Stage 1: Func to LLVM** * Convert func dialect operations (main function) to LLVM dialect * equivalents. * - * **Stage 2: Block structure and initialization** + * **Stage 2: Block structure** * Create proper 4-block structure for QIR base profile (entry, main, - * irreversible, output) and insert the `__quantum__rt__initialize` call - * in the entry block. + * irreversible, output). * - * **Stage 3: QC to LLVM** - * Convert QC dialect operations to QIR calls and add output recording to - * the output block. + * **Stage 3: Initialization** + * Insert the `__quantum__rt__initialize` call. * - * **Stage 4: QIR attributes** - * Add QIR base profile metadata to the main function, including - * qubit/result counts and version information. + * **Stage 4: QC to LLVM** + * Convert QC dialect operations to QIR calls and add output recording to the + * output block. * - * **Stage 5: Standard dialects to LLVM** + * **Stage 5: QIR attributes** + * Add QIR base profile metadata to the main function, including qubit/result + * counts and version information. + * + * **Stage 6: Standard dialects to LLVM** * Convert arith and control flow dialects to LLVM (for index arithmetic and * function control flow). * - * **Stage 6: Reconcile casts** - * Clean up any unrealized cast operations introduced during type - * conversion. + * **Stage 7: Reconcile casts** + * Clean up any unrealized cast operations introduced during type conversion. */ void runOnOperation() override { MLIRContext* ctx = &getContext(); @@ -1174,7 +1326,6 @@ struct QCToQIR final : impl::QCToQIRBase { } } - // Stage 2: Ensure proper block structure and add initialization auto main = getMainFunction(moduleOp); if (!main) { moduleOp->emitError("No main function with entry_point attribute found"); @@ -1182,68 +1333,48 @@ struct QCToQIR final : impl::QCToQIRBase { return; } - ensureBlocks(main); - addInitialize(main, ctx); - LoweringState state; - // Stage 3: Convert QC dialect to LLVM (QIR calls) + // Stage 2: Create block structure + ensureBlocks(main, state); + + // Stage 3: Insert initialize call + addInitialize(main, ctx, state); + + // Stage 4: Convert QC dialect to LLVM (QIR calls) { - RewritePatternSet qcPatterns(ctx); - target.addIllegalDialect(); - - // Add conversion patterns for QC operations - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - - if (applyPartialConversion(moduleOp, target, std::move(qcPatterns)) + RewritePatternSet patterns(ctx); + target.addIllegalDialect(); + + patterns.add( + typeConverter, ctx, &state); + + if (applyPartialConversion(moduleOp, target, std::move(patterns)) .failed()) { signalPassFailure(); return; } addOutputRecording(main, ctx, &state); + + releaseResults(main, ctx, &state); } - // Stage 4: Set QIR metadata attributes + // Stage 5: Set QIR metadata attributes setQIRAttributes(main, state); - // Stage 5: Convert standard dialects to LLVM + // Stage 6: Convert standard dialects to LLVM { RewritePatternSet stdPatterns(ctx); target.addIllegalDialect(); @@ -1260,7 +1391,7 @@ struct QCToQIR final : impl::QCToQIRBase { } } - // Stage 6: Reconcile unrealized casts + // Stage 7: Reconcile unrealized casts PassManager passManager(ctx); passManager.addPass(createReconcileUnrealizedCastsPass()); if (passManager.run(moduleOp).failed()) { diff --git a/mlir/lib/Dialect/QC/Builder/CMakeLists.txt b/mlir/lib/Dialect/QC/Builder/CMakeLists.txt index 868674da35..d5f7f51caa 100644 --- a/mlir/lib/Dialect/QC/Builder/CMakeLists.txt +++ b/mlir/lib/Dialect/QC/Builder/CMakeLists.txt @@ -13,7 +13,7 @@ add_mlir_library( PUBLIC MLIRArithDialect MLIRFuncDialect - MLIRSCFDialect + MLIRMemRefDialect MLIRQCDialect) mqt_mlir_target_use_project_options(MLIRQCProgramBuilder) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 991dba1c1b..33e7459aab 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -14,12 +14,13 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" -#include #include #include +#include #include #include #include +#include #include #include #include @@ -87,29 +88,25 @@ Value QCProgramBuilder::staticQubit(const uint64_t index) { } llvm::SmallVector -QCProgramBuilder::allocQubitRegister(const int64_t size, - const std::string& name) { +QCProgramBuilder::allocQubitRegister(const int64_t size) { checkFinalized(); if (size <= 0) { llvm::reportFatalUsageError("Size must be positive"); } - // Allocate a sequence of qubits with register metadata + auto memrefType = MemRefType::get({size}, QubitType::get(ctx)); + auto memref = memref::AllocOp::create(*this, memrefType); + allocatedMemrefs.insert(memref); + llvm::SmallVector qubits; qubits.reserve(size); - - auto nameAttr = getStringAttr(name); - auto sizeAttr = getI64IntegerAttr(size); - for (int64_t i = 0; i < size; ++i) { - auto indexAttr = getI64IntegerAttr(i); - auto allocOp = AllocOp::create(*this, nameAttr, sizeAttr, indexAttr); - const auto& qubit = qubits.emplace_back(allocOp.getResult()); - // Track the allocated qubit for automatic deallocation + auto index = arith::ConstantIndexOp::create(*this, i); + auto load = memref::LoadOp::create(*this, memref, index.getResult()); + const auto& qubit = qubits.emplace_back(load.getResult()); allocatedQubits.insert(qubit); } - return qubits; } @@ -449,12 +446,14 @@ QCProgramBuilder::inv(const llvm::function_ref& body) { QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { checkFinalized(); + if (llvm::isa_and_nonnull(qubit.getDefiningOp())) { + llvm::reportFatalUsageError( + "Register-backed qubits cannot be deallocated manually"); + } + // Check if the qubit is in the tracking set if (!allocatedQubits.erase(qubit)) { - // Qubit was not found in the set - either never allocated or already - // deallocated - llvm::reportFatalUsageError( - "Double deallocation or invalid qubit deallocation"); + llvm::reportFatalUsageError("Invalid qubit deallocation"); } // Create the DeallocOp @@ -494,25 +493,18 @@ OwningOpRef QCProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - // Automatically deallocate all still-allocated qubits - // Sort qubits for deterministic output - llvm::SmallVector sortedQubits(allocatedQubits.begin(), - allocatedQubits.end()); - llvm::sort(sortedQubits, [](Value a, Value b) { - auto* opA = a.getDefiningOp(); - auto* opB = b.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + for (auto qubit : allocatedQubits) { + if (!llvm::isa(qubit.getDefiningOp())) { + DeallocOp::create(*this, qubit); } - return opA->isBeforeInBlock(opB); - }); - for (auto qubit : sortedQubits) { - DeallocOp::create(*this, qubit); } - - // Clear the tracking set allocatedQubits.clear(); + for (auto memref : allocatedMemrefs) { + memref::DeallocOp::create(*this, memref); + } + allocatedMemrefs.clear(); + // Create constant 0 for successful exit code auto exitCode = intConstant(0); diff --git a/mlir/lib/Dialect/QC/CMakeLists.txt b/mlir/lib/Dialect/QC/CMakeLists.txt index 49d4a2a9fc..c4d5cdbc80 100644 --- a/mlir/lib/Dialect/QC/CMakeLists.txt +++ b/mlir/lib/Dialect/QC/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(IR) add_subdirectory(Builder) add_subdirectory(Translation) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp b/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp index b842ab042c..db17a9ab8b 100644 --- a/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp +++ b/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp @@ -21,8 +21,7 @@ using namespace mlir::qc; namespace { /** - * @brief Remove matching allocation and deallocation pairs without operations - * between them. + * @brief Remove matching allocation-deallocation pairs. */ struct RemoveAllocDeallocPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..e64d323c01 --- /dev/null +++ b/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQCTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRQCDialect + DEPENDS + MLIRQCTransformsIncGen) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QC/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QC/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQCTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp new file mode 100644 index 0000000000..48ce84a42c --- /dev/null +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/QC/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir::qc { + +#define GEN_PASS_DEF_SHRINKQUBITREGISTERSPASS +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" + +/** + * @brief Return the constant index of a one-dimensional memref load. + */ +[[nodiscard]] static std::optional +getLoadIndex(memref::LoadOp loadOp) { + if (loadOp.getIndices().size() != 1) { + return std::nullopt; + } + return getConstantIntValue(loadOp.getIndices().front()); +} + +namespace { +/** + * @brief Shrink static qubit registers to actually read indices. + */ +struct ShrinkQubitRegister final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DeallocOp op, + PatternRewriter& rewriter) const override { + auto allocOp = op.getMemref().getDefiningOp(); + if (!allocOp) { + return failure(); + } + + auto memRefType = llvm::dyn_cast(op.getMemref().getType()); + if (!memRefType || memRefType.getRank() != 1 || + !memRefType.hasStaticShape()) { + return failure(); + } + if (!llvm::isa(memRefType.getElementType())) { + return failure(); + } + if (!memRefType.getLayout().isIdentity()) { + return failure(); + } + if (memRefType.getMemorySpace() != nullptr) { + return failure(); + } + + llvm::SmallVector loadOps; + llvm::SmallVector liveIndices; + llvm::DenseMap newIndexByOldIndex; + + for (auto* user : op.getMemref().getUsers()) { + if (user == op.getOperation()) { + continue; + } + auto loadOp = llvm::dyn_cast(user); + if (!loadOp) { + return failure(); + } + auto index = getLoadIndex(loadOp); + if (!index || *index < 0 || *index >= memRefType.getDimSize(0)) { + return failure(); + } + loadOps.push_back(loadOp); + if (!loadOp.getResult().use_empty() && + !newIndexByOldIndex.contains(*index)) { + newIndexByOldIndex.try_emplace(*index, 0U); + liveIndices.push_back(*index); + } + } + + if (liveIndices.empty()) { + for (auto loadOp : loadOps) { + rewriter.eraseOp(loadOp); + } + rewriter.eraseOp(op); + rewriter.eraseOp(allocOp); + return success(); + } + + llvm::sort(liveIndices); + if (static_cast(liveIndices.size()) == memRefType.getDimSize(0) && + llvm::all_of(llvm::enumerate(liveIndices), [](const auto& indexed) { + return static_cast(indexed.index()) == indexed.value(); + })) { + return failure(); + } + + newIndexByOldIndex.clear(); + for (size_t i = 0; i < liveIndices.size(); ++i) { + newIndexByOldIndex.try_emplace(liveIndices[i], i); + } + + rewriter.setInsertionPoint(allocOp); + auto newMemRefType = + MemRefType::get({static_cast(liveIndices.size())}, + memRefType.getElementType()); + auto newAlloc = + memref::AllocOp::create(rewriter, allocOp.getLoc(), newMemRefType); + + for (auto loadOp : loadOps) { + if (loadOp.getResult().use_empty()) { + rewriter.eraseOp(loadOp); + continue; + } + + const auto oldIndex = *getLoadIndex(loadOp); + const auto newIndex = + static_cast(newIndexByOldIndex.lookup(oldIndex)); + rewriter.setInsertionPoint(loadOp); + auto indexConst = + arith::ConstantIndexOp::create(rewriter, loadOp.getLoc(), newIndex); + auto newLoad = memref::LoadOp::create(rewriter, loadOp.getLoc(), + newAlloc.getResult(), + ValueRange{indexConst.getResult()}); + rewriter.replaceOp(loadOp, newLoad); + } + + rewriter.setInsertionPoint(op); + memref::DeallocOp::create(rewriter, op.getLoc(), newAlloc.getResult()); + rewriter.eraseOp(op); + rewriter.eraseOp(allocOp); + return success(); + } +}; + +struct ShrinkQubitRegistersPass final + : impl::ShrinkQubitRegistersPassBase { +protected: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace +} // namespace mlir::qc diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index ee94512d5b..f250201cc4 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -97,8 +97,8 @@ allocateQregs(QCProgramBuilder& builder, // Allocate quantum registers using the builder SmallVector qregs; for (const auto* qregPtr : qregPtrs) { - auto qubits = builder.allocQubitRegister( - static_cast(qregPtr->getSize()), qregPtr->getName()); + auto qubits = + builder.allocQubitRegister(static_cast(qregPtr->getSize())); qregs.emplace_back(qregPtr, std::move(qubits)); } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 372498cb0c..d0d2abb9f3 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -16,6 +16,8 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/Utils/Utils.h" +#include +#include #include #include #include @@ -32,7 +34,6 @@ #include #include -#include #include #include #include @@ -75,10 +76,10 @@ Value QCOProgramBuilder::allocQubit() { checkFinalized(); auto allocOp = AllocOp::create(*this); - const auto qubit = allocOp.getResult(); + auto qubit = allocOp.getResult(); // Track the allocated qubit as valid - validQubits.insert(qubit); + validQubits.try_emplace(qubit, QubitInfo{}); return qubit; } @@ -90,34 +91,28 @@ Value QCOProgramBuilder::staticQubit(const uint64_t index) { const auto qubit = staticOp.getQubit(); // Track the static qubit as valid - validQubits.insert(qubit); + validQubits.try_emplace(qubit, QubitInfo{}); return qubit; } llvm::SmallVector -QCOProgramBuilder::allocQubitRegister(const int64_t size, - const std::string& name) { +QCOProgramBuilder::allocQubitRegister(const int64_t size) { checkFinalized(); if (size <= 0) { llvm::reportFatalUsageError("Size must be positive"); } - llvm::SmallVector qubits; - qubits.reserve(static_cast(size)); - - auto nameAttr = getStringAttr(name); - auto sizeAttr = getI64IntegerAttr(size); + auto qtensor = qtensorAlloc(size); + llvm::SmallVector qubits; + qubits.reserve(size); for (int64_t i = 0; i < size; ++i) { - const auto indexAttr = getI64IntegerAttr(i); - auto allocOp = AllocOp::create(*this, nameAttr, sizeAttr, indexAttr); - const auto& qubit = qubits.emplace_back(allocOp.getResult()); - // Track the allocated qubit as valid - validQubits.insert(qubit); + auto [qtensorOut, qubit] = qtensorExtract(qtensor, i); + qtensor = qtensorOut; + qubits.emplace_back(qubit); } - return qubits; } @@ -152,11 +147,14 @@ void QCOProgramBuilder::updateQubitTracking(Value inputQubit, // Validate the input qubit validateQubitValue(inputQubit); + auto it = validQubits.find(inputQubit); + auto info = it->second; + // Remove the input (consumed) value from tracking - validQubits.erase(inputQubit); + validQubits.erase(it); // Add the output (new) value to tracking - validQubits.insert(outputQubit); + validQubits.try_emplace(outputQubit, info); } void QCOProgramBuilder::validateTensorValue(Value tensor) const { @@ -182,11 +180,14 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, // Validate the input tensor validateTensorValue(inputTensor); + auto it = validTensors.find(inputTensor); + auto info = it->second; + // Remove the input (consumed) value from tracking - validTensors.erase(inputTensor); + validTensors.erase(it); // Add the output (new) value to tracking - validTensors.insert(outputTensor); + validTensors.try_emplace(outputTensor, info); } //===----------------------------------------------------------------------===// @@ -196,11 +197,13 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, Value QCOProgramBuilder::qtensorAlloc( const std::variant& size) { checkFinalized(); - auto sizeValue = utils::variantToValue(*this, getLoc(), size); + auto sizeValue = variantToValue(*this, getLoc(), size); auto allocOp = qtensor::AllocOp::create(*this, sizeValue); + auto result = allocOp.getResult(); - validTensors.insert(result); + validTensors.try_emplace(result, TensorInfo{tensorCounter++}); + return result; } @@ -221,49 +224,33 @@ Value QCOProgramBuilder::qtensorFromElements(ValueRange elements) { auto fromElementsOp = qtensor::FromElementsOp::create(*this, elements); auto result = fromElementsOp.getResult(); - validTensors.insert(result); + validTensors.try_emplace(result, TensorInfo{tensorCounter++}); return result; } -std::pair -QCOProgramBuilder::qtensorExtract(Value tensor, - const std::variant& index) { +std::pair QCOProgramBuilder::qtensorExtract(Value tensor, + const int64_t index) { checkFinalized(); - auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto indexValue = arith::ConstantIndexOp::create(*this, index).getResult(); auto extractOp = qtensor::ExtractOp::create(*this, tensor, indexValue); auto qubit = extractOp.getResult(); auto outTensor = extractOp.getOutTensor(); - validQubits.insert(qubit); - updateTensorTracking(tensor, outTensor); - - return {outTensor, qubit}; -} - -std::pair QCOProgramBuilder::qtensorExtractSlice( - Value tensor, const std::variant& offset, - const std::variant& size) { - checkFinalized(); - - auto offsetValue = utils::variantToValue(*this, getLoc(), offset); - auto sizesValue = utils::variantToValue(*this, getLoc(), size); - auto extractSliceOp = - qtensor::ExtractSliceOp::create(*this, tensor, offsetValue, sizesValue); - auto slicedTensor = extractSliceOp.getResult(); - auto outTensor = extractSliceOp.getOutTensor(); + validateTensorValue(tensor); + const auto regId = validTensors[tensor].regId; - validTensors.insert(slicedTensor); + validQubits.try_emplace(qubit, QubitInfo{.regId = regId, .regIndex = index}); updateTensorTracking(tensor, outTensor); - return {outTensor, slicedTensor}; + return {outTensor, qubit}; } Value QCOProgramBuilder::qtensorInsert( Value scalar, Value tensor, const std::variant& index) { checkFinalized(); - auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto indexValue = variantToValue(*this, getLoc(), index); auto insertOp = qtensor::InsertOp::create(*this, scalar, tensor, indexValue); auto outTensor = insertOp.getResult(); @@ -271,24 +258,6 @@ Value QCOProgramBuilder::qtensorInsert( validateQubitValue(scalar); validQubits.erase(scalar); updateTensorTracking(tensor, outTensor); - return outTensor; -} - -Value QCOProgramBuilder::qtensorInsertSlice( - Value source, Value dest, const std::variant& offset, - const std::variant& size) { - checkFinalized(); - - auto offsetValue = utils::variantToValue(*this, getLoc(), offset); - auto sizeValue = utils::variantToValue(*this, getLoc(), size); - auto insertSliceOp = qtensor::InsertSliceOp::create(*this, source, dest, - offsetValue, sizeValue); - - auto outTensor = insertSliceOp.getResult(); - - validateTensorValue(source); - validTensors.erase(source); - updateTensorTracking(dest, outTensor); return outTensor; } @@ -329,7 +298,7 @@ Value QCOProgramBuilder::measure(Value qubit, const Bit& bit) { auto indexAttr = getI64IntegerAttr(bit.registerIndex); auto measureOp = MeasureOp::create(*this, qubit, nameAttr, sizeAttr, indexAttr); - const auto qubitOut = measureOp.getQubitOut(); + auto qubitOut = measureOp.getQubitOut(); // Update tracking updateQubitTracking(qubit, qubitOut); @@ -341,7 +310,7 @@ Value QCOProgramBuilder::reset(Value qubit) { checkFinalized(); auto resetOp = ResetOp::create(*this, qubit); - const auto qubitOut = resetOp.getQubitOut(); + auto qubitOut = resetOp.getQubitOut(); // Update tracking updateQubitTracking(qubit, qubitOut); @@ -397,7 +366,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) Value QCOProgramBuilder::OP_NAME(Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -442,7 +411,7 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -485,7 +454,7 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM1, PARAM2); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -532,7 +501,7 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM1, PARAM2, PARAM3); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -579,8 +548,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -623,8 +592,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) const std::variant&(PARAM), Value qubit0, Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1, PARAM); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -673,8 +642,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1, PARAM1, PARAM2); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -724,7 +693,7 @@ ValueRange QCOProgramBuilder::barrier(ValueRange qubits) { checkFinalized(); auto op = BarrierOp::create(*this, qubits); - const auto& qubitsOut = op.getQubitsOut(); + auto qubitsOut = op.getQubitsOut(); for (const auto& [inputQubit, outputQubit] : llvm::zip(qubits, qubitsOut)) { updateQubitTracking(inputQubit, outputQubit); } @@ -742,7 +711,7 @@ std::pair QCOProgramBuilder::ctrl( auto ctrlOp = CtrlOp::create(*this, controls, targets); auto& block = ctrlOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (const auto target : targets) { const auto arg = block.addArgument(qubitType, getLoc()); updateQubitTracking(target, arg); @@ -781,8 +750,8 @@ ValueRange QCOProgramBuilder::inv( // Add block arguments for all qubits auto& block = invOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); - for (const auto qubit : qubits) { + auto qubitType = QubitType::get(getContext()); + for (auto qubit : qubits) { const auto arg = block.addArgument(qubitType, getLoc()); updateQubitTracking(qubit, arg); } @@ -833,7 +802,7 @@ ValueRange QCOProgramBuilder::qcoIf( llvm::function_ref(ValueRange)> elseBody) { checkFinalized(); - auto conditionValue = utils::variantToValue(*this, getLoc(), condition); + auto conditionValue = variantToValue(*this, getLoc(), condition); auto ifOp = IfOp::create(*this, conditionValue, qubits); // Create the then and else block @@ -844,8 +813,8 @@ ValueRange QCOProgramBuilder::qcoIf( for (auto qubitType : qubits.getTypes()) { const auto thenArg = thenBlock.addArgument(qubitType, getLoc()); const auto elseArg = elseBlock.addArgument(qubitType, getLoc()); - validQubits.insert(thenArg); - validQubits.insert(elseArg); + validQubits.try_emplace(thenArg, QubitInfo{}); + validQubits.try_emplace(elseArg, QubitInfo{}); } // Construct the bodies of the regions @@ -920,15 +889,36 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - // Automatically deallocate all still-allocated qubits - for (auto qubit : validQubits) { - SinkOp::create(*this, qubit); + llvm::DenseSet validTensorIds; + for (const auto& [tensor, info] : validTensors) { + validTensorIds.insert(info.regId); + } + + llvm::DenseMap>> + qubitsByRegister; + for (auto [qubit, info] : validQubits) { + if (info.regId == -1 || !validTensorIds.contains(info.regId)) { + // Automatically deallocate all still-allocated qubits + SinkOp::create(*this, qubit); + } else { + qubitsByRegister[info.regId].emplace_back(qubit, info); + } } - validQubits.clear(); - for (auto tensor : validTensors) { - qtensor::DeallocOp::create(*this, tensor); + // Automatically deallocate all still-allocated tensors + for (auto& [tensor, tensorInfo] : validTensors) { + auto currentTensor = tensor; + // Filter out qubits belonging to this tensor + for (auto& [qubit, qubitInfo] : qubitsByRegister[tensorInfo.regId]) { + auto indexValue = constantFromScalar(*this, getLoc(), qubitInfo.regIndex); + currentTensor = + qtensor::InsertOp::create(*this, qubit, currentTensor, indexValue) + .getResult(); + } + // Deallocate tensor + qtensor::DeallocOp::create(*this, currentTensor); } + validQubits.clear(); validTensors.clear(); // Create constant 0 for successful exit code diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 7ae1a733da..3db78fbd8b 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -9,27 +9,90 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include +#include #include #include #include +#include #include using namespace mlir; using namespace mlir::qco; +/** + * @brief Check if a `qtensor.extract` operation reads from a `qtensor.alloc` + * chain. + * + * @details In QTensor's linear tensor model, reads/writes on different indices + * commute. We can therefore skip over `qtensor.insert` on other indices while + * tracing provenance. A write to the same index invalidates the proof. + */ +static bool originatesFromQTensorAlloc(qtensor::ExtractOp extractOp) { + auto current = extractOp.getTensor(); + + auto extractIndex = extractOp.getIndex(); + if (!getConstantIntValue(extractIndex)) { + return false; + } + + while (auto* definingOp = current.getDefiningOp()) { + if (llvm::isa(definingOp)) { + return true; + } + + if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { + auto nestedExtractIndex = nestedExtractOp.getIndex(); + if (!getConstantIntValue(nestedExtractIndex)) { + return false; + } + if (qtensor::areEquivalentIndices(extractIndex, nestedExtractIndex)) { + return false; + } + current = nestedExtractOp.getTensor(); + continue; + } + + if (auto insertOp = llvm::dyn_cast(definingOp)) { + auto insertIndex = insertOp.getIndex(); + if (!getConstantIntValue(insertIndex)) { + return false; + } + if (qtensor::areEquivalentIndices(extractIndex, insertIndex)) { + return false; + } + current = insertOp.getDest(); + continue; + } + + return false; + } + + return false; +} + namespace { /** - * @brief Remove reset operations that immediately follow an allocation. + * @brief Remove reset operations that immediately follow a `qtensor.extract` + * operation. */ -struct RemoveResetAfterAlloc final : OpRewritePattern { +struct RemoveResetAfterExtract final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ResetOp op, PatternRewriter& rewriter) const override { - // Check if the predecessor is an AllocOp - if (auto allocOp = op.getQubitIn().getDefiningOp(); !allocOp) { + // Check if the predecessor is an ExtractOp + auto extractOp = op.getQubitIn().getDefiningOp(); + if (!extractOp) { + return failure(); + } + + // Check if the tensor originates from an AllocOp + if (!originatesFromQTensorAlloc(extractOp)) { return failure(); } @@ -41,7 +104,15 @@ struct RemoveResetAfterAlloc final : OpRewritePattern { } // namespace +OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { + if (getQubitIn().getDefiningOp()) { + return getQubitIn(); + } + + return {}; +} + void ResetOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp index 709c7ca0bf..08ee6e068e 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -42,18 +43,6 @@ struct MergeSubsequentP final : OpRewritePattern { } }; -/** - * @brief Remove trivial P operations. - */ -struct RemoveTrivialP final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(POp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void POp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -63,9 +52,17 @@ void POp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult POp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void POp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional POp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp index a57ea85105..eef8bf100c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp @@ -43,18 +43,6 @@ struct MergeSubsequentRX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RX operations. - */ -struct RemoveTrivialRX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RXOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -64,9 +52,17 @@ void RXOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RXOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp index 667720a792..ca26cf0c78 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -56,18 +57,6 @@ struct MergeSwappedTargetsRXX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RXX operations. - */ -struct RemoveTrivialRXX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RXXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RXXOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -78,10 +67,20 @@ void RXXOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RXXOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RXXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RXXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp index 634fd106a7..3b9255250a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp @@ -43,18 +43,6 @@ struct MergeSubsequentRY final : OpRewritePattern { } }; -/** - * @brief Remove trivial RY operations. - */ -struct RemoveTrivialRY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RYOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -64,9 +52,17 @@ void RYOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RYOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp index 7709c872e6..6b9d540062 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -56,18 +57,6 @@ struct MergeSwappedTargetsRYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial RYY operations. - */ -struct RemoveTrivialRYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -78,10 +67,20 @@ void RYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RYYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp index d888a36689..cad399846c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -42,18 +43,6 @@ struct MergeSubsequentRZ final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZ operations. - */ -struct RemoveTrivialRZ final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -63,9 +52,17 @@ void RZOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RZOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RZOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RZOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp index a9b9770fff..be40c3de9a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -43,18 +44,6 @@ struct MergeSubsequentRZX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZX operations. - */ -struct RemoveTrivialRZX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZXOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -65,9 +54,20 @@ void RZXOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RZXOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RZXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RZXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp index 85ea44a7da..5fb05c3379 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp @@ -17,8 +17,10 @@ #include #include #include +#include #include +#include #include #include #include @@ -55,18 +57,6 @@ struct MergeSwappedTargetsRZZ final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZZ operations. - */ -struct RemoveTrivialRZZ final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZZOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZZOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -77,10 +67,20 @@ void RZZOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RZZOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RZZOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RZZOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp index 221f8d6d3c..79dc169ae1 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -9,7 +9,6 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" -#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/Utils/Utils.h" #include @@ -20,6 +19,7 @@ #include #include #include +#include #include #include @@ -77,18 +77,6 @@ struct MergeSubsequentXXMinusYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial XXMinusYY operations. - */ -struct RemoveTrivialXXMinusYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(XXMinusYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void XXMinusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -101,9 +89,20 @@ void XXMinusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand, betaOperand); } +LogicalResult XXMinusYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void XXMinusYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional XXMinusYYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp index d920e4777f..03d8853371 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -9,7 +9,6 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" -#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/Utils/Utils.h" #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include @@ -76,18 +76,6 @@ struct MergeSubsequentXXPlusYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial XXPlusYY operations. - */ -struct RemoveTrivialXXPlusYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(XXPlusYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void XXPlusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -100,9 +88,20 @@ void XXPlusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand, betaOperand); } +LogicalResult XXPlusYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void XXPlusYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional XXPlusYYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt index 3b965dce5e..2268584167 100644 --- a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt @@ -15,7 +15,6 @@ add_mlir_library( PRIVATE MLIRQCODialect MLIRQCOUtils - ${dialect_libs} DEPENDS MLIRQCOTransformsIncGen) diff --git a/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt b/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt index 1cdc4b7527..78e8e29c9e 100644 --- a/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt +++ b/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_library( LINK_LIBS PUBLIC MLIRLLVMDialect + MLIRMemRefDialect MLIRQIRUtils MLIRIR MLIRSupport) diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 5ca6cfd050..28eb571b16 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -67,15 +68,15 @@ void QIRProgramBuilder::initialize() { measurementsBlock = mainFuncOp.addBlock(); outputBlock = mainFuncOp.addBlock(); - // Create exit code constant in entry block (where constants belong) and add - // QIR initialization call in entry block (after exit code constant) + // Create exit code constant in entry block setInsertionPointToStart(entryBlock); - auto zeroOp = LLVM::ZeroOp::create(*this, ptrType); exitCode = intConstant(0); - const auto initType = LLVM::LLVMFunctionType::get(voidType, ptrType); - auto initFunc = - getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initType); - LLVM::CallOp::create(*this, initFunc, zeroOp.getResult()); + + auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); + auto initDec = + getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + LLVM::CallOp::create(*this, initDec, zero.getResult()); // Add unconditional branches between blocks setInsertionPointToEnd(entryBlock); @@ -112,14 +113,13 @@ Value QIRProgramBuilder::staticQubit(const int64_t index) { llvm::reportFatalUsageError("Index must be non-negative"); } - // Check cache - Value val{}; - if (const auto it = ptrCache.find(index); it != ptrCache.end()) { - val = it->second; + Value qubit; + if (const auto it = staticQubits.find(index); it != staticQubits.end()) { + qubit = it->second; } else { - val = createPointerFromIndex(*this, getLoc(), index); + qubit = createPointerFromIndex(*this, getLoc(), index); // Cache for reuse - ptrCache[index] = val; + staticQubits[index] = qubit; } // Update qubit count @@ -127,7 +127,7 @@ Value QIRProgramBuilder::staticQubit(const int64_t index) { metadata_.numQubits = static_cast(index) + 1; } - return val; + return qubit; } SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { @@ -137,11 +137,36 @@ SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { llvm::reportFatalUsageError("Size must be positive"); } + metadata_.useDynamicQubit = true; + + // Save current insertion point + const InsertionGuard guard(*this); + + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); + SmallVector qubits; qubits.reserve(size); + auto allocFnSignature = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(getContext()), {getI64Type(), ptrType, ptrType}); + auto allocFnDecl = getOrCreateFunctionDeclaration( + *this, module, QIR_QUBIT_ARRAY_ALLOC, allocFnSignature); + + auto array = + LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + LLVM::CallOp::create( + *this, allocFnDecl, + ValueRange{intConstant(size), array.getResult(), zero.getResult()}); + + qubitArrays.insert(array.getResult()); + for (int64_t i = 0; i < size; ++i) { - qubits.push_back(staticQubit(static_cast(metadata_.numQubits))); + auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); + qubits.push_back(load.getResult()); } return qubits; @@ -156,25 +181,43 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } + if (name.starts_with("__unnamed__")) { + llvm::reportFatalUsageError( + "Classical register names starting with '__unnamed__' are reserved"); + } + if (resultArrays.contains(name)) { + llvm::reportFatalUsageError("Classical register already exists"); + } + + metadata_.useDynamicResult = true; + // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) - setInsertionPoint(measurementsBlock->getTerminator()); + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); + + auto allocFnSignature = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(getContext()), {getI64Type(), ptrType, ptrType}); + auto allocFnDecl = getOrCreateFunctionDeclaration( + *this, module, QIR_RESULT_ARRAY_ALLOC, allocFnSignature); + + auto array = + LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + LLVM::CallOp::create( + *this, allocFnDecl, + ValueRange{intConstant(size), array.getResult(), zero.getResult()}); + + resultArrays.try_emplace(name, array.getResult()); - const auto numResults = static_cast(metadata_.numResults); for (int64_t i = 0; i < size; ++i) { - Value val{}; - if (const auto it = ptrCache.find(numResults + i); it != ptrCache.end()) { - val = it->second; - } else { - val = createPointerFromIndex(*this, getLoc(), numResults + i); - // Cache for reuse - ptrCache[numResults + i] = val; - } - registerResultMap.insert({{stringSaver.save(name), i}, val}); + auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); + loadedResults.try_emplace({stringSaver.save(name), i}, load.getResult()); } - metadata_.numResults += size; + return {.name = name, .size = size}; } @@ -185,76 +228,60 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { llvm::reportFatalUsageError("Result index must be non-negative"); } - // Choose a safe default register name - static constexpr llvm::StringLiteral DEFAULT_REG_NAME = "c"; - StringRef regName{DEFAULT_REG_NAME}; - if (llvm::any_of(registerResultMap, [](const auto& entry) { - return entry.first.first == DEFAULT_REG_NAME; - })) { - static constexpr llvm::StringLiteral FALLBACK_REG_NAME = "__unnamed__"; - regName = FALLBACK_REG_NAME; - } + metadata_.useDynamicResult = true; // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) - setInsertionPoint(measurementsBlock->getTerminator()); - - const auto key = std::make_pair(regName, resultIndex); - if (const auto it = registerResultMap.find(key); - it != registerResultMap.end()) { - return it->second; - } + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); - Value resultValue{}; - if (const auto it = ptrCache.find(resultIndex); it != ptrCache.end()) { - resultValue = it->second; + // Get or create result pointer + Value result; + if (const auto it = resultPtrs.find(resultIndex); it != resultPtrs.end()) { + result = it->second; } else { - resultValue = createPointerFromIndex(*this, getLoc(), resultIndex); - ptrCache[resultIndex] = resultValue; - registerResultMap.try_emplace(key, resultValue); + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_ALLOC, fnSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + result = LLVM::CallOp::create(*this, fnDec, zero.getResult()).getResult(); + resultPtrs.try_emplace(resultIndex, result); } - // Update result count - if (std::cmp_greater_equal(resultIndex, metadata_.numResults)) { - metadata_.numResults = static_cast(resultIndex) + 1; - } + // Switch to measurements block + setInsertionPoint(measurementsBlock->getTerminator()); - // Create mz call - const auto mzSignature = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - auto mzDecl = - getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSignature); - LLVM::CallOp::create(*this, mzDecl, ValueRange{qubit, resultValue}); + // Create measure call + const auto mzSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto mzDec = + getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSig); + LLVM::CallOp::create(*this, mzDec, ValueRange{qubit, result}); - return resultValue; + return result; } QIRProgramBuilder& QIRProgramBuilder::measure(Value qubit, const Bit& bit) { checkFinalized(); + auto it = loadedResults.find({bit.registerName, bit.registerIndex}); + if (it == loadedResults.end()) { + llvm::reportFatalUsageError( + "Bit does not belong to an allocated classical register"); + } + auto result = it->second; + // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) + // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); - // Check if we already have a result pointer for this register slot - const auto& registerName = bit.registerName; - const auto registerIndex = bit.registerIndex; - const auto key = std::make_pair(registerName, registerIndex); - if (!registerResultMap.contains(key)) { - llvm::reportFatalInternalError("Result pointer not found"); - } - const auto resultValue = registerResultMap.at(key); - - // Create mz call - const auto mzSignature = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - auto mzDecl = - getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSignature); - LLVM::CallOp::create(*this, mzDecl, ValueRange{qubit, resultValue}); + // Create measure call + const auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, fnSig); + LLVM::CallOp::create(*this, fnDec, ValueRange{qubit, result}); return *this; } @@ -265,7 +292,7 @@ QIRProgramBuilder& QIRProgramBuilder::reset(Value qubit) { // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) + // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); // Create reset call @@ -578,7 +605,7 @@ void QIRProgramBuilder::checkFinalized() const { } void QIRProgramBuilder::generateOutputRecording() { - if (registerResultMap.empty()) { + if (resultArrays.empty() && resultPtrs.empty()) { return; // No measurements to record } @@ -588,55 +615,48 @@ void QIRProgramBuilder::generateOutputRecording() { // Insert in output block (before return) setInsertionPoint(outputBlock->getTerminator()); - // Group measurements by register - llvm::StringMap>> registerGroups; - for (const auto& [key, resultPtr] : registerResultMap) { - const auto& [regName, regIdx] = key; - registerGroups[regName].emplace_back(regIdx, resultPtr); + if (!resultPtrs.empty()) { + // Sort result pointers for deterministic output + llvm::SmallVector> sortedPtrs; + for (const auto& [index, resultPtr] : resultPtrs) { + sortedPtrs.emplace_back(index, resultPtr); + } + llvm::sort(sortedPtrs, + [](const auto& a, const auto& b) { return a.first < b.first; }); + + // Create output recording for each result pointer + auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_RECORD_OUTPUT, fnSig); + + for (const auto& [index, ptr] : sortedPtrs) { + auto label = createResultLabel(*this, module, + "__unnamed__" + std::to_string(index)) + .getResult(); + LLVM::CallOp::create(*this, fnDec, ValueRange{ptr, label}); + } } - // Sort registers by name for deterministic output - SmallVector>>> - sortedRegisters; - for (auto& [name, measurements] : registerGroups) { - sortedRegisters.emplace_back(name, std::move(measurements)); - } - sort(sortedRegisters, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - // Create array_record_output call - const auto arrayRecordSig = - LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); - const auto arrayRecordDecl = getOrCreateFunctionDeclaration( - *this, module, QIR_ARRAY_RECORD_OUTPUT, arrayRecordSig); - - // Create result_record_output calls for each measurement - const auto resultRecordSig = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - const auto resultRecordDecl = getOrCreateFunctionDeclaration( - *this, module, QIR_RECORD_OUTPUT, resultRecordSig); - - // Generate output recording for each register - for (auto& [registerName, measurements] : sortedRegisters) { - // Sort measurements by register index - sort(measurements, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - const auto arraySize = measurements.size(); - auto arrayLabelOp = createResultLabel(*this, module, registerName); - auto arraySizeConst = intConstant(static_cast(arraySize)); - - LLVM::CallOp::create(*this, arrayRecordDecl, - ValueRange{arraySizeConst, arrayLabelOp.getResult()}); - - for (const auto& [regIdx, resultPtr] : measurements) { - // Create label for result: "{registerName}{regIdx}r" - const std::string resultLabel = - registerName + std::to_string(regIdx) + "r"; - auto resultLabelOp = createResultLabel(*this, module, resultLabel); - - LLVM::CallOp::create(*this, resultRecordDecl, - ValueRange{resultPtr, resultLabelOp.getResult()}); + if (!resultArrays.empty()) { + // Sort registers by name for deterministic output + SmallVector> sortedArrays; + for (auto& [name, results] : resultArrays) { + sortedArrays.emplace_back(name, results); + } + llvm::sort(sortedArrays, + [](const auto& a, const auto& b) { return a.first < b.first; }); + + auto fnSig = + LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(*this, module, + QIR_ARRAY_RECORD_OUTPUT, fnSig); + + // Create output recording for each register + for (auto& [name, results] : sortedArrays) { + auto size = results.getDefiningOp().getArraySize(); + auto label = createResultLabel(*this, module, name).getResult(); + + LLVM::CallOp::create(*this, fnDec, ValueRange{size, results, label}); } } } @@ -644,9 +664,38 @@ void QIRProgramBuilder::generateOutputRecording() { OwningOpRef QIRProgramBuilder::finalize() { checkFinalized(); - // Generate output recording in the output block + // Save current insertion point + const InsertionGuard guard(*this); + + // Release resources in output block + setInsertionPoint(outputBlock->getTerminator()); + + for (auto array : qubitArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); + auto dec = getOrCreateFunctionDeclaration(*this, module, + QIR_QUBIT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(*this, dec, ValueRange{size, array}); + } + + // Generate output recording in output block generateOutputRecording(); + for (auto& [_, ptr] : resultPtrs) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); + auto dec = + getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_RELEASE, sig); + LLVM::CallOp::create(*this, dec, ptr); + } + + for (auto& [_, array] : resultArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); + auto dec = getOrCreateFunctionDeclaration(*this, module, + QIR_RESULT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(*this, dec, ValueRange{size, array}); + } + auto mainFuncOp = llvm::cast(mainFunc); setQIRAttributes(mainFuncOp, metadata_); diff --git a/mlir/lib/Dialect/QIR/CMakeLists.txt b/mlir/lib/Dialect/QIR/CMakeLists.txt index e2d07911c5..2aa54c44ff 100644 --- a/mlir/lib/Dialect/QIR/CMakeLists.txt +++ b/mlir/lib/Dialect/QIR/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(Utils) add_subdirectory(Builder) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..95afa87e3f --- /dev/null +++ b/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQIRTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRLLVMDialect + MLIRQIRUtils + DEPENDS + MLIRQIRTransformsIncGen) + +mqt_mlir_target_use_project_options(MLIRQIRTransforms) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QIR/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QIR/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQIRTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp new file mode 100644 index 0000000000..f1188ce19a --- /dev/null +++ b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QIR/Transforms/Passes.h" +#include "mlir/Dialect/QIR/Utils/QIRUtils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace mlir::qir { + +#define GEN_PASS_DEF_QIRCLEANUPPASS +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" + +[[nodiscard]] static StringAttr getMetadataKey(const Attribute attr) { + auto pair = llvm::dyn_cast(attr); + if (!pair || pair.size() != 2) { + return {}; + } + auto key = llvm::dyn_cast(pair[0]); + if (!key || !llvm::isa(pair[1])) { + return {}; + } + return key; +} + +[[nodiscard]] static llvm::StringRef getCalleeName(LLVM::CallOp callOp) { + auto calleeAttr = callOp.getCalleeAttr(); + auto flatRef = llvm::dyn_cast_or_null(calleeAttr); + if (!flatRef) { + return {}; + } + return flatRef.getValue(); +} + +[[nodiscard]] static bool moduleHasDynamicQubitRuntimeCalls(ModuleOp module) { + return llvm::any_of(module.getOps(), [](LLVM::CallOp callOp) { + const auto callee = getCalleeName(callOp); + return callee == QIR_QUBIT_ALLOC || callee == QIR_QUBIT_ARRAY_ALLOC; + }); +} + +[[nodiscard]] static bool moduleHasDynamicResultRuntimeCalls(ModuleOp module) { + return llvm::any_of(module.getOps(), [](LLVM::CallOp callOp) { + const auto callee = getCalleeName(callOp); + return callee == QIR_RESULT_ALLOC || callee == QIR_RESULT_ARRAY_ALLOC; + }); +} + +static void dropUnusedExternalDeclarations(ModuleOp module) { + for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { + if (!funcOp.isExternal()) { + continue; + } + if (!SymbolTable::symbolKnownUseEmpty(funcOp, module)) { + continue; + } + funcOp.erase(); + } +} + +static void normalizeQIRMetadata(ModuleOp module) { + auto main = getMainFunction(module); + if (!main) { + return; + } + + auto passthroughAttr = main->getAttrOfType("passthrough"); + if (!passthroughAttr) { + return; + } + + const bool hasDynamicQubit = moduleHasDynamicQubitRuntimeCalls(module); + const bool hasDynamicResult = moduleHasDynamicResultRuntimeCalls(module); + if (hasDynamicQubit && hasDynamicResult) { + return; + } + + ArrayAttr requiredNumQubitsAttr = nullptr; + ArrayAttr requiredNumResultsAttr = nullptr; + for (const auto attr : passthroughAttr) { + const auto key = getMetadataKey(attr); + if (!key) { + continue; + } + if (key.getValue() == "required_num_qubits") { + requiredNumQubitsAttr = llvm::cast(attr); + } else if (key.getValue() == "required_num_results") { + requiredNumResultsAttr = llvm::cast(attr); + } + } + + OpBuilder builder(module.getContext()); + SmallVector updatedMetadata; + updatedMetadata.reserve(passthroughAttr.size() + 2); + + for (const auto attr : passthroughAttr) { + const auto key = getMetadataKey(attr); + if (!key) { + updatedMetadata.push_back(attr); + continue; + } + + if (key.getValue() == "dynamic_qubit_management" && !hasDynamicQubit) { + if (requiredNumQubitsAttr) { + updatedMetadata.push_back(requiredNumQubitsAttr); + } + continue; + } + if (key.getValue() == "dynamic_result_management" && !hasDynamicResult) { + if (requiredNumResultsAttr) { + updatedMetadata.push_back(requiredNumResultsAttr); + } + continue; + } + + updatedMetadata.push_back(attr); + } + + main->setAttr("passthrough", builder.getArrayAttr(updatedMetadata)); +} + +namespace { + +/** + * @brief Remove matching allocation-release pairs of qubit arrays. + * @details Matches an unused + * `__quantum__rt__qubit_array_allocate`-`__quantum__rt__qubit_array_release` + * pair on the same stack slot. + */ +struct RemoveDeadQubitArrayPair final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp releaseCall, + PatternRewriter& rewriter) const override { + if (getCalleeName(releaseCall) != QIR_QUBIT_ARRAY_RELEASE || + releaseCall.getNumOperands() < 2) { + return failure(); + } + + auto allocaOp = releaseCall.getOperand(1).getDefiningOp(); + if (!allocaOp) { + return failure(); + } + + LLVM::CallOp allocCall = nullptr; + for (Operation* user : allocaOp.getResult().getUsers()) { + auto callOp = llvm::dyn_cast(user); + if (!callOp) { + return failure(); + } + + if (callOp == releaseCall) { + continue; + } + + if (getCalleeName(callOp) != QIR_QUBIT_ARRAY_ALLOC || + callOp.getNumOperands() < 2 || + callOp.getOperand(1) != allocaOp.getResult()) { + return failure(); + } + if (allocCall != nullptr) { + return failure(); + } + allocCall = callOp; + } + + if (!allocCall) { + return failure(); + } + + rewriter.eraseOp(releaseCall); + rewriter.eraseOp(allocCall); + if (allocaOp->use_empty()) { + rewriter.eraseOp(allocaOp); + } + return success(); + } +}; + +/** + * @brief Clean up QIR. + * @details Removes dead allocation-release pairs of qubit arrays, drops unused + * external declarations, and normalizes QIR metadata. + */ +struct QIRCleanupPass final : impl::QIRCleanupPassBase { +protected: + void runOnOperation() override { + auto module = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + return; + } + + dropUnusedExternalDeclarations(module); + normalizeQIRMetadata(module); + } +}; + +} // namespace + +} // namespace mlir::qir diff --git a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp index 1093921606..b698c06dc7 100644 --- a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp +++ b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp @@ -57,6 +57,11 @@ LLVM::LLVMFuncOp getMainFunction(Operation* op) { } void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata) { + if (metadata.useDynamicQubit && metadata.numQubits != 0) { + llvm::reportFatalUsageError( + "Cannot use dynamic qubit allocation if static qubits are allocated"); + } + OpBuilder builder(main.getBody()); llvm::SmallVector attributes; @@ -73,17 +78,26 @@ void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata) { attributes.emplace_back(builder.getStrArrayAttr( {"required_num_results", std::to_string(metadata.numResults)})); - // QIR version (Base Profile spec requires version 2.0) - attributes.emplace_back(builder.getStrArrayAttr({"qir_major_version", "2"})); - attributes.emplace_back(builder.getStrArrayAttr({"qir_minor_version", "0"})); + // Management model or resource requirements + if (metadata.useDynamicQubit) { + attributes.emplace_back( + builder.getStrArrayAttr({"dynamic_qubit_management", "true"})); + } else { + attributes.emplace_back(builder.getStrArrayAttr( + {"required_num_qubits", std::to_string(metadata.numQubits)})); + } - // Management model - attributes.emplace_back( - builder.getStrArrayAttr({"dynamic_qubit_management", - metadata.useDynamicQubit ? "true" : "false"})); - attributes.emplace_back( - builder.getStrArrayAttr({"dynamic_result_management", - metadata.useDynamicResult ? "true" : "false"})); + if (metadata.useDynamicResult) { + attributes.emplace_back( + builder.getStrArrayAttr({"dynamic_result_management", "true"})); + } else { + attributes.emplace_back(builder.getStrArrayAttr( + {"required_num_results", std::to_string(metadata.numResults)})); + } + + // QIR version (Base Profile spec requires version 2.1) + attributes.emplace_back(builder.getStrArrayAttr({"qir_major_version", "2"})); + attributes.emplace_back(builder.getStrArrayAttr({"qir_minor_version", "1"})); main->setAttr("passthrough", builder.getArrayAttr(attributes)); } diff --git a/mlir/lib/Dialect/QTensor/CMakeLists.txt b/mlir/lib/Dialect/QTensor/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/lib/Dialect/QTensor/CMakeLists.txt +++ b/mlir/lib/Dialect/QTensor/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp index 898b8b6412..05978b6f36 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp @@ -45,16 +45,15 @@ LogicalResult AllocOp::verify() { if (sizeValue && *sizeValue <= 0) { return emitOpError("Constant size operand must be positive"); } - if (sizeValue.has_value() == resultType.isDynamicDim(0)) { - return emitOpError("Size operand and result type must both be static or " - "both be dynamic, but got ") - << (sizeValue ? "static size with dynamic result" - : "dynamic size with static result"); - } - if (sizeValue && resultSize != *sizeValue) { - return emitOpError("Constant size operand (") - << *sizeValue << ") does not match static result size (" - << resultSize << ")"; + if (!resultType.isDynamicDim(0)) { + if (!sizeValue) { + return emitOpError("Static result type requires constant size operand"); + } + if (resultSize != *sizeValue) { + return emitOpError("Constant size operand (") + << *sizeValue << ") does not match static result size (" + << resultSize << ")"; + } } return success(); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp index 90f076ede1..4abddab7cc 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp @@ -20,7 +20,7 @@ using namespace mlir::qtensor; namespace { /** - * @brief Remove matching allocation and deallocation pairs without operations + * @brief Remove matching allocation-deallocation pairs without operations * between them. */ struct RemoveAllocDeallocPair final : OpRewritePattern { @@ -35,7 +35,7 @@ struct RemoveAllocDeallocPair final : OpRewritePattern { return failure(); } - // Remove the AllocOp and the DeallocOp + // Remove the AllocOp and the DeallocOp. rewriter.eraseOp(op); rewriter.eraseOp(allocOp); return success(); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 27e8de6995..a899d3ae41 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -33,34 +33,3 @@ LogicalResult ExtractOp::verify() { } return success(); } - -/** - * @brief If an ExtractOp consumes an InsertOp with the same index, - * return the scalar and the destTensor from the InsertOp directly. - */ -static InsertOp foldExtractAfterInsert(ExtractOp extractOp) { - auto insertOp = extractOp.getTensor().getDefiningOp(); - if (!insertOp) { - return nullptr; - } - - Value insertIndex = insertOp.getIndex(); - Value extractIndex = extractOp.getIndex(); - - if (getAsOpFoldResult(insertIndex) != getAsOpFoldResult(extractIndex)) { - return nullptr; - } - - return insertOp; -} - -LogicalResult ExtractOp::fold(FoldAdaptor /*adaptor*/, - SmallVectorImpl& results) { - if (auto insertOp = foldExtractAfterInsert(*this)) { - results.emplace_back(insertOp.getDest()); - results.emplace_back(insertOp.getScalar()); - return success(); - } - - return failure(); -} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp deleted file mode 100644 index 3c7648b04d..0000000000 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM - * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH - * All rights reserved. - * - * SPDX-License-Identifier: MIT - * - * Licensed under the MIT License - */ - -#include "mlir/Dialect/QTensor/IR/QTensorOps.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::qtensor; - -void ExtractSliceOp::build(OpBuilder& b, OperationState& result, Value tensor, - Value offset, Value size, - ArrayRef attrs) { - auto tensorType = cast(tensor.getType()); - auto sizeValue = getConstantIntValue(size); - auto resultType = RankedTensorType::get( - {sizeValue ? *sizeValue : ShapedType::kDynamic}, - tensorType.getElementType(), tensorType.getEncoding()); - - result.addAttributes(attrs); - build(b, result, {tensor.getType(), resultType}, tensor, offset, size); -} - -LogicalResult ExtractSliceOp::verify() { - auto tensorDim = getTensor().getType().getDimSize(0); - auto resultDim = getResult().getType().getDimSize(0); - auto constOffset = getConstantIntValue(getOffset()); - auto constSize = getConstantIntValue(getSize()); - - if (constOffset && *constOffset < 0) { - return emitOpError("Offset must be non-negative"); - } - - if (constSize && *constSize <= 0) { - return emitOpError("Size must be positive"); - } - - if (constOffset && constSize && !ShapedType::isDynamic(tensorDim)) { - if (*constOffset + *constSize > tensorDim) { - return emitOpError("Offset + Size exceeds source dimension"); - } - } - - if (constSize && !ShapedType::isDynamic(resultDim)) { - if (resultDim != *constSize) { - return emitOpError("Result tensor dimension must match size operand"); - } - } - - return success(); -} - -/** - * @brief If an ExtractSliceOp consumes an InsertSliceOp with the same offset - * and size, return the sourceTensor and the destTensor from the InsertSliceOp - * directly. - */ -static InsertSliceOp -foldExtractAfterInsertSlice(ExtractSliceOp extractSliceOp) { - auto insertSliceOp = - extractSliceOp.getTensor().getDefiningOp(); - if (!insertSliceOp) { - return nullptr; - } - - auto insertOffset = insertSliceOp.getOffset(); - auto extractOffset = extractSliceOp.getOffset(); - auto insertSize = insertSliceOp.getSize(); - auto extractSize = extractSliceOp.getSize(); - - if (getAsOpFoldResult(insertOffset) != getAsOpFoldResult(extractOffset) || - getAsOpFoldResult(insertSize) != getAsOpFoldResult(extractSize)) { - return nullptr; - } - - return insertSliceOp; -} - -LogicalResult ExtractSliceOp::fold(FoldAdaptor /*adaptor*/, - SmallVectorImpl& results) { - if (auto insertOp = foldExtractAfterInsertSlice(*this)) { - results.emplace_back(insertOp.getDest()); - results.emplace_back(insertOp.getSource()); - return success(); - } - - return failure(); -} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 982d4a6335..adeac1cb8d 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -9,60 +9,146 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include #include #include +#include #include +#include +#include #include #include using namespace mlir; using namespace mlir::qtensor; -LogicalResult InsertOp::verify() { - auto dstDim = getDest().getType().getDimSize(0); - auto index = getConstantIntValue(getIndex()); - - if (index) { - if (*index < 0) { - return emitOpError("Index must be non-negative"); - } - if (!ShapedType::isDynamic(dstDim) && *index >= dstDim) { - return emitOpError("Index exceeds tensor dimension"); - } - } - - return success(); +/** + * @brief Checks whether removing an extract-insert pair is linearity-safe. + */ +static bool isRemovableExtractInsertPair(InsertOp insertOp, + ExtractOp extractOp) { + return insertOp.getScalar() == extractOp.getResult() && + areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex()); } /** - * @brief If an InsertOp consumes an ExtractOp with the same index, - * return the tensor from the extractOp directly. + * @brief Folds an insert operation after a matching extract operation into the + * original tensor. */ static Value foldInsertAfterExtract(InsertOp insertOp) { auto extractOp = insertOp.getScalar().getDefiningOp(); - if (!extractOp) { return nullptr; } + if (insertOp.getDest() != extractOp.getOutTensor()) { return nullptr; } + if (!isRemovableExtractInsertPair(insertOp, extractOp)) { + return nullptr; + } + + return extractOp.getTensor(); +} + +/** + * @brief Finds the extract operation corresponding to a given insert operation. + * + * @details The function traverses the tensor chain of the insert operation + * until it finds the matching extract operation. + */ +static ExtractOp findMatchingExtractInTensorChain(InsertOp insertOp) { + auto current = insertOp.getDest(); auto insertIndex = insertOp.getIndex(); - auto extractIndex = extractOp.getIndex(); - if (getAsOpFoldResult(insertIndex) != getAsOpFoldResult(extractIndex)) { + if (!getConstantIntValue(insertIndex)) { return nullptr; } - return extractOp.getTensor(); + while (auto* definingOp = current.getDefiningOp()) { + if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { + auto nestedInsertIndex = nestedInsertOp.getIndex(); + if (!getConstantIntValue(nestedInsertIndex)) { + return nullptr; + } + // A more recent write to the same index shadows all older extracts + if (areEquivalentIndices(nestedInsertIndex, insertIndex)) { + return nullptr; + } + current = nestedInsertOp.getDest(); + continue; + } + if (auto extractOp = llvm::dyn_cast(definingOp)) { + auto extractIndex = extractOp.getIndex(); + if (!getConstantIntValue(extractIndex)) { + return nullptr; + } + if (areEquivalentIndices(extractIndex, insertIndex)) { + return extractOp; + } + current = extractOp.getTensor(); + continue; + } + break; + } + return nullptr; +} + +namespace { + +/** + * @brief Remove matching extract-insert pairs. + */ +struct RemoveExtractInsertPair final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter& rewriter) const override { + auto extractOp = findMatchingExtractInTensorChain(op); + if (!extractOp) { + return failure(); + } + + if (!isRemovableExtractInsertPair(op, extractOp)) { + return failure(); + } + + rewriter.replaceOp(op, op.getDest()); + rewriter.replaceOp(extractOp, {extractOp.getTensor(), nullptr}); + + return success(); + } +}; + +} // namespace + +LogicalResult InsertOp::verify() { + auto dstDim = getDest().getType().getDimSize(0); + auto index = getConstantIntValue(getIndex()); + + if (index) { + if (*index < 0) { + return emitOpError("Index must be non-negative"); + } + if (!ShapedType::isDynamic(dstDim) && *index >= dstDim) { + return emitOpError("Index exceeds tensor dimension"); + } + } + + return success(); } OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { if (auto result = foldInsertAfterExtract(*this)) { return result; } - return {}; } + +void InsertOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp deleted file mode 100644 index 1a7b6526ab..0000000000 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM - * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH - * All rights reserved. - * - * SPDX-License-Identifier: MIT - * - * Licensed under the MIT License - */ - -#include "mlir/Dialect/QTensor/IR/QTensorOps.h" - -#include -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::qtensor; - -LogicalResult InsertSliceOp::verify() { - auto srcDim = getSource().getType().getDimSize(0); - auto dstDim = getDest().getType().getDimSize(0); - auto constOffset = getConstantIntValue(getOffset()); - auto constSize = getConstantIntValue(getSize()); - - if (constOffset && *constOffset < 0) { - return emitOpError("Offset must be non-negative"); - } - - if (constSize && *constSize <= 0) { - return emitOpError("Size must be positive"); - } - - if (constSize && !ShapedType::isDynamic(srcDim)) { - if (*constSize != srcDim) { - return emitOpError("Size must match source dimension"); - } - } - - if (constOffset && constSize && !ShapedType::isDynamic(dstDim)) { - if (*constSize > dstDim || *constOffset > dstDim - *constSize) { - return emitOpError("Offset + Size exceeds destination dimension"); - } - } - - return success(); -} - -/** - * @brief If an InsertSliceOp consumes an ExtractSliceOp with the same offset - * and size, return the sourceTensor from the extractSliceOp directly. - */ -static Value foldInsertAfterExtractSlice(InsertSliceOp insertSliceOp) { - auto extractSliceOp = - insertSliceOp.getSource().getDefiningOp(); - if (!extractSliceOp) { - return nullptr; - } - - if (extractSliceOp.getOutTensor() != insertSliceOp.getDest()) { - return nullptr; - } - - auto insertOffset = insertSliceOp.getOffset(); - auto extractOffset = extractSliceOp.getOffset(); - auto insertSize = insertSliceOp.getSize(); - auto extractSize = extractSliceOp.getSize(); - - if (getAsOpFoldResult(insertOffset) != getAsOpFoldResult(extractOffset) || - getAsOpFoldResult(insertSize) != getAsOpFoldResult(extractSize)) { - return nullptr; - } - - return extractSliceOp.getTensor(); -} - -OpFoldResult InsertSliceOp::fold(FoldAdaptor /*adaptor*/) { - if (auto result = foldInsertAfterExtractSlice(*this)) { - return result; - } - - return {}; -} diff --git a/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..2c110797bd --- /dev/null +++ b/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQTensorTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRQTensorDialect + DEPENDS + MLIRQTensorTransformsIncGen) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QTensor/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QTensor/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQTensorTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp new file mode 100644 index 0000000000..feb64d6e15 --- /dev/null +++ b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir::qtensor { + +#define GEN_PASS_DEF_SHRINKQTENSORTOFITPASS +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" + +/** + * @brief Return the unique user of a linear qtensor value. + */ +[[nodiscard]] static Operation* getLinearTensorUser(const Value tensor) { + assert(tensor.hasOneUse() && "Expected a linear tensor with exactly one use"); + return *tensor.getUsers().begin(); +} + +/** + * @brief Mark a single live index. + */ +[[nodiscard]] static LogicalResult markLiveIndex(const int64_t index, + llvm::BitVector& liveIndices) { + if (index < 0 || std::cmp_greater_equal(index, liveIndices.size())) { + return failure(); + } + liveIndices.set(static_cast(index)); + return success(); +} + +/** + * @brief Redirect the tensor operand from @p from to @p to. + */ +[[nodiscard]] static LogicalResult remapTensorOperand(Operation* op, Value from, + Value to) { + if (auto extractOp = llvm::dyn_cast(op)) { + if (extractOp.getTensor() != from) { + return failure(); + } + extractOp->setOperand(0, to); + return success(); + } + if (auto insertOp = llvm::dyn_cast(op)) { + if (insertOp.getDest() != from) { + return failure(); + } + insertOp->setOperand(1, to); + return success(); + } + if (auto deallocOp = llvm::dyn_cast(op)) { + if (deallocOp.getTensor() != from) { + return failure(); + } + deallocOp->setOperand(0, to); + return success(); + } + return failure(); +} + +/** + * @brief Walk alloc->dealloc and collect all touched indices. + */ +[[nodiscard]] static LogicalResult collectLiveIndices(AllocOp allocOp, + llvm::BitVector& live, + DeallocOp& deallocOp) { + auto tensor = allocOp.getResult(); + while (true) { + auto* user = getLinearTensorUser(tensor); + if (user == nullptr) { + return failure(); + } + + if (auto currentDealloc = llvm::dyn_cast(user)) { + if (currentDealloc.getTensor() != tensor) { + return failure(); + } + deallocOp = currentDealloc; + return success(); + } + + if (auto extractOp = llvm::dyn_cast(user)) { + if (extractOp.getTensor() != tensor) { + return failure(); + } + auto index = getConstantIntValue(extractOp.getIndex()); + if (!index || failed(markLiveIndex(*index, live))) { + return failure(); + } + tensor = extractOp.getOutTensor(); + continue; + } + + if (auto insertOp = llvm::dyn_cast(user)) { + if (insertOp.getDest() != tensor) { + return failure(); + } + auto index = getConstantIntValue(insertOp.getIndex()); + if (!index || failed(markLiveIndex(*index, live))) { + return failure(); + } + tensor = insertOp.getResult(); + continue; + } + + return failure(); + } +} + +namespace { + +/** + * @brief Shrink static qtensors by removing never-accessed indices. + * @details QTensor is linear, so this rewrite follows a single use-def chain. + */ +struct ShrinkStaticQTensor final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocOp allocOp, + PatternRewriter& rewriter) const override { + auto oldSize = getConstantIntValue(allocOp.getSize()); + if (!oldSize || *oldSize <= 0) { + return failure(); + } + + llvm::BitVector live(static_cast(*oldSize), false); + DeallocOp oldDeallocOp{}; + if (failed(collectLiveIndices(allocOp, live, oldDeallocOp))) { + return failure(); + } + + if (!oldDeallocOp) { + return failure(); + } + + llvm::SmallVector newIndexByOldIndex(static_cast(*oldSize), + -1); + int64_t newSize = 0; + for (int64_t index = 0; index < *oldSize; ++index) { + if (live.test(static_cast(index))) { + newIndexByOldIndex[static_cast(index)] = newSize++; + } + } + + if (newSize <= 0 || newSize == *oldSize) { + return failure(); + } + + rewriter.setInsertionPoint(allocOp); + auto size = + arith::ConstantIndexOp::create(rewriter, allocOp.getLoc(), newSize); + auto newAlloc = + AllocOp::create(rewriter, allocOp.getLoc(), size.getResult()); + + auto oldTensor = allocOp.getResult(); + auto currentTensor = newAlloc.getResult(); + while (true) { + Operation* currentOp = getLinearTensorUser(oldTensor); + if (currentOp == nullptr) { + return failure(); + } + + if (auto deallocOp = llvm::dyn_cast(currentOp)) { + if (deallocOp != oldDeallocOp || deallocOp.getTensor() != oldTensor) { + return failure(); + } + rewriter.setInsertionPoint(deallocOp); + DeallocOp::create(rewriter, deallocOp.getLoc(), currentTensor); + rewriter.eraseOp(deallocOp); + break; + } + + if (auto extractOp = llvm::dyn_cast(currentOp)) { + if (extractOp.getTensor() != oldTensor) { + return failure(); + } + const auto oldIndex = *getConstantIntValue(extractOp.getIndex()); + if (oldIndex < 0 || + std::cmp_greater_equal(oldIndex, newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedIndex = + newIndexByOldIndex[static_cast(oldIndex)]; + if (mappedIndex < 0) { + return failure(); + } + auto oldOutTensor = extractOp.getOutTensor(); + auto* nextOp = getLinearTensorUser(oldOutTensor); + if (nextOp == nullptr) { + return failure(); + } + + rewriter.setInsertionPoint(extractOp); + auto index = arith::ConstantIndexOp::create( + rewriter, extractOp.getLoc(), mappedIndex); + auto newExtract = ExtractOp::create(rewriter, extractOp.getLoc(), + currentTensor, index.getResult()); + rewriter.replaceAllUsesWith(extractOp.getResult(), + newExtract.getResult()); + + currentTensor = newExtract.getOutTensor(); + if (failed(remapTensorOperand(nextOp, oldOutTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(extractOp); + continue; + } + + if (auto insertOp = llvm::dyn_cast(currentOp)) { + if (insertOp.getDest() != oldTensor) { + return failure(); + } + const auto oldIndex = *getConstantIntValue(insertOp.getIndex()); + if (oldIndex < 0 || + std::cmp_greater_equal(oldIndex, newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedIndex = + newIndexByOldIndex[static_cast(oldIndex)]; + if (mappedIndex < 0) { + return failure(); + } + auto oldResultTensor = insertOp.getResult(); + auto* nextOp = getLinearTensorUser(oldResultTensor); + if (nextOp == nullptr) { + return failure(); + } + + rewriter.setInsertionPoint(insertOp); + auto index = arith::ConstantIndexOp::create(rewriter, insertOp.getLoc(), + mappedIndex); + auto newInsert = + InsertOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(), + currentTensor, index.getResult()); + + currentTensor = newInsert.getResult(); + if (failed(remapTensorOperand(nextOp, oldResultTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(insertOp); + continue; + } + + return failure(); + } + + rewriter.eraseOp(allocOp); + return success(); + } +}; + +struct ShrinkQTensorToFitPass final + : impl::ShrinkQTensorToFitPassBase { +protected: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qtensor diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt index 0e93033d4a..462d791329 100644 --- a/mlir/lib/Support/CMakeLists.txt +++ b/mlir/lib/Support/CMakeLists.txt @@ -23,7 +23,11 @@ add_mlir_library( MLIRTransformUtils MLIRLLVMDialect MLIRFuncDialect - MLIRArithDialect) + MLIRArithDialect + MLIRQCTransforms + MLIRQIRTransforms + MLIRQTensorTransforms + MLIRQTensorDialect) mqt_mlir_target_use_project_options(MLIRSupportMQT) diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 4c3d470289..db2799a533 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -10,6 +10,8 @@ #include "mlir/Support/IRVerification.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" + #include #include #include @@ -23,6 +25,8 @@ #include #include #include +#include +#include #include #include #include @@ -116,8 +120,261 @@ struct StructuralOperationKey { /// Map to track value equivalence between two modules. using ValueEquivalenceMap = llvm::DenseMap; + +using OperationSet = llvm::DenseSet; + +struct InsertWrite { + Value scalar; + Value index; +}; + +struct InsertChainSummary { + Value baseTensor; + Value finalTensor; + llvm::SmallVector writes; +}; + } // namespace +static bool areValuesEquivalent(Value lhs, Value rhs, + ValueEquivalenceMap& valueMap) { + if (auto it = valueMap.find(lhs); it != valueMap.end()) { + return it->second == rhs; + } + valueMap[lhs] = rhs; + return true; +} + +static bool areIndexValuesEquivalent(Value lhs, Value rhs, + ValueEquivalenceMap& valueMap) { + if (qtensor::areEquivalentIndices(lhs, rhs)) { + return true; + } + return areValuesEquivalent(lhs, rhs, valueMap); +} + +static bool isQTensorInsertOp(Operation* op) { + return llvm::isa(op); +} + +static bool isCommutableQTensorInsertDependency(Operation* dependent, + Operation* dependency) { + auto dependentInsert = llvm::dyn_cast(dependent); + auto dependencyInsert = llvm::dyn_cast(dependency); + if (!dependentInsert || !dependencyInsert) { + return false; + } + if (dependentInsert.getDest() != dependencyInsert.getResult()) { + return false; + } + auto dependentIndex = dependentInsert.getIndex(); + auto dependencyIndex = dependencyInsert.getIndex(); + if (!getConstantIntValue(dependentIndex) || + !getConstantIntValue(dependencyIndex)) { + return false; + } + return !qtensor::areEquivalentIndices(dependentIndex, dependencyIndex); +} + +static Value getInsertChainBaseTensor(Value tensor, const OperationSet& group) { + auto current = tensor; + while (auto insertOp = current.getDefiningOp()) { + if (!group.contains(insertOp.getOperation())) { + break; + } + current = insertOp.getDest(); + } + return current; +} + +static bool +summarizeInsertGroup(llvm::ArrayRef ops, + llvm::SmallVectorImpl& chains) { + OperationSet groupOps; + for (Operation* op : ops) { + groupOps.insert(op); + } + + llvm::DenseSet consumedInsertResults; + for (Operation* op : ops) { + auto insertOp = llvm::cast(op); + if (auto definingInsert = + insertOp.getDest().getDefiningOp()) { + if (groupOps.contains(definingInsert.getOperation())) { + consumedInsertResults.insert(insertOp.getDest()); + } + } + } + + llvm::DenseMap chainByBaseTensor; + for (Operation* op : ops) { + auto insertOp = llvm::cast(op); + const Value baseTensor = + getInsertChainBaseTensor(insertOp.getDest(), groupOps); + + size_t chainIdx = 0; + if (auto it = chainByBaseTensor.find(baseTensor); + it != chainByBaseTensor.end()) { + chainIdx = it->second; + } else { + chainIdx = chains.size(); + chainByBaseTensor[baseTensor] = chainIdx; + InsertChainSummary summary; + summary.baseTensor = baseTensor; + chains.emplace_back(std::move(summary)); + } + + auto& chain = chains[chainIdx]; + chain.writes.push_back(InsertWrite{.scalar = insertOp.getScalar(), + .index = insertOp.getIndex()}); + + if (!consumedInsertResults.contains(insertOp.getResult())) { + if (chain.finalTensor) { + return false; + } + chain.finalTensor = insertOp.getResult(); + } + } + + for (const auto& chain : chains) { + if (!chain.finalTensor) { + return false; + } + + // Reordering writes to the same index is not semantics-preserving. + llvm::SmallVector seenIndices; + for (const auto& write : chain.writes) { + if (llvm::any_of(seenIndices, [&](Value seenIndex) { + return qtensor::areEquivalentIndices(seenIndex, write.index); + })) { + return false; + } + seenIndices.push_back(write.index); + } + } + + return true; +} + +static bool areInsertWritesEquivalentRec(const size_t lhsIdx, + llvm::ArrayRef lhsWrites, + llvm::ArrayRef rhsWrites, + llvm::SmallVectorImpl& rhsUsed, + ValueEquivalenceMap& valueMap) { + if (lhsIdx == lhsWrites.size()) { + return true; + } + + for (size_t rhsIdx = 0; rhsIdx < rhsWrites.size(); ++rhsIdx) { + if (rhsUsed[rhsIdx] != 0) { + continue; + } + + ValueEquivalenceMap tempMap = valueMap; + if (!areValuesEquivalent(lhsWrites[lhsIdx].scalar, rhsWrites[rhsIdx].scalar, + tempMap) || + !areIndexValuesEquivalent(lhsWrites[lhsIdx].index, + rhsWrites[rhsIdx].index, tempMap)) { + continue; + } + + rhsUsed[rhsIdx] = 1; + if (areInsertWritesEquivalentRec(lhsIdx + 1, lhsWrites, rhsWrites, rhsUsed, + tempMap)) { + valueMap = std::move(tempMap); + return true; + } + rhsUsed[rhsIdx] = 0; + } + + return false; +} + +static bool areInsertWritesEquivalent(llvm::ArrayRef lhsWrites, + llvm::ArrayRef rhsWrites, + ValueEquivalenceMap& valueMap) { + if (lhsWrites.size() != rhsWrites.size()) { + return false; + } + llvm::SmallVector rhsUsed(rhsWrites.size(), 0); + return areInsertWritesEquivalentRec(0, lhsWrites, rhsWrites, rhsUsed, + valueMap); +} + +static bool areInsertChainsEquivalent(const InsertChainSummary& lhsChain, + const InsertChainSummary& rhsChain, + ValueEquivalenceMap& valueMap) { + ValueEquivalenceMap tempMap = valueMap; + if (!areValuesEquivalent(lhsChain.baseTensor, rhsChain.baseTensor, tempMap)) { + return false; + } + + if (!areInsertWritesEquivalent(lhsChain.writes, rhsChain.writes, tempMap)) { + return false; + } + + if (!areValuesEquivalent(lhsChain.finalTensor, rhsChain.finalTensor, + tempMap)) { + return false; + } + + valueMap = std::move(tempMap); + return true; +} + +static bool areInsertGroupsEquivalentRec( + const size_t lhsChainIdx, llvm::ArrayRef lhsChains, + llvm::ArrayRef rhsChains, + llvm::SmallVectorImpl& rhsChainUsed, ValueEquivalenceMap& valueMap) { + if (lhsChainIdx == lhsChains.size()) { + return true; + } + + for (size_t rhsChainIdx = 0; rhsChainIdx < rhsChains.size(); ++rhsChainIdx) { + if (rhsChainUsed[rhsChainIdx] != 0) { + continue; + } + + ValueEquivalenceMap tempMap = valueMap; + if (!areInsertChainsEquivalent(lhsChains[lhsChainIdx], + rhsChains[rhsChainIdx], tempMap)) { + continue; + } + + rhsChainUsed[rhsChainIdx] = 1; + if (areInsertGroupsEquivalentRec(lhsChainIdx + 1, lhsChains, rhsChains, + rhsChainUsed, tempMap)) { + valueMap = std::move(tempMap); + return true; + } + rhsChainUsed[rhsChainIdx] = 0; + } + + return false; +} + +static bool areInsertGroupsEquivalent(llvm::ArrayRef lhsOps, + llvm::ArrayRef rhsOps, + ValueEquivalenceMap& valueMap) { + if (lhsOps.size() != rhsOps.size()) { + return false; + } + + llvm::SmallVector lhsChains; + llvm::SmallVector rhsChains; + if (!summarizeInsertGroup(lhsOps, lhsChains) || + !summarizeInsertGroup(rhsOps, rhsChains)) { + return false; + } + if (lhsChains.size() != rhsChains.size()) { + return false; + } + + llvm::SmallVector rhsChainUsed(rhsChains.size(), 0); + return areInsertGroupsEquivalentRec(0, lhsChains, rhsChains, rhsChainUsed, + valueMap); +} + /// DenseMapInfo specialization for StructuralOperationKey template <> struct llvm::DenseMapInfo { static StructuralOperationKey getEmptyKey() { @@ -378,18 +635,21 @@ llvm::SmallVector processed; llvm::SmallVector currentGroup; for (auto* op : ops) { bool dependsOnCurrent = false; // Check if this operation depends on any operation in the current group - for (const auto* groupOp : currentGroup) { - if (dependsOn[op].contains(groupOp)) { - dependsOnCurrent = true; - break; + for (auto* groupOp : currentGroup) { + if (!dependsOn[op].contains(groupOp)) { + continue; } + if (isCommutableQTensorInsertDependency(op, groupOp)) { + continue; + } + dependsOnCurrent = true; + break; } // Check if this operation has ordering constraints @@ -502,6 +762,18 @@ static bool areBlocksEquivalent(Block& lhs, Block& rhs, auto& lhsGroup = lhsGroups[groupIdx]; auto& rhsGroup = rhsGroups[groupIdx]; + const bool lhsInsertGroup = llvm::all_of(lhsGroup, isQTensorInsertOp); + const bool rhsInsertGroup = llvm::all_of(rhsGroup, isQTensorInsertOp); + if (lhsInsertGroup || rhsInsertGroup) { + if (!lhsInsertGroup || !rhsInsertGroup) { + return false; + } + if (!areInsertGroupsEquivalent(lhsGroup, rhsGroup, valueMap)) { + return false; + } + continue; + } + if (!areIndependentGroupsEquivalent(lhsGroup, rhsGroup)) { return false; } diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index 5e998761bc..0369056533 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -10,19 +10,67 @@ #include "mlir/Support/Passes.h" +#include "mlir/Dialect/QC/Transforms/Passes.h" +#include "mlir/Dialect/QIR/Transforms/Passes.h" +#include "mlir/Dialect/QTensor/Transforms/Passes.h" + +#include +#include #include #include #include +#include #include using namespace mlir; -void runCanonicalizationPasses(ModuleOp module) { - PassManager pm(module.getContext()); +static void addSimplificationPasses(PassManager& pm) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(createRemoveDeadValuesPass()); +} + +static LogicalResult +runWithPassManager(ModuleOp module, + const llvm::function_ref populatePasses, + const llvm::StringRef errorMessage) { + PassManager pm(module.getContext()); + populatePasses(pm); if (pm.run(module).failed()) { - llvm::errs() << "Failed to run canonicalization passes.\n"; + llvm::errs() << errorMessage << "\n"; + return failure(); } + return success(); +} + +void populateQCCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qc::createShrinkQubitRegistersPass()); + pm.addPass(createRemoveDeadValuesPass()); +} + +void populateQCOCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qtensor::createShrinkQTensorToFitPass()); + pm.addPass(createRemoveDeadValuesPass()); +} + +void populateQIRCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qir::createQIRCleanupPass()); + pm.addPass(createRemoveDeadValuesPass()); +} + +[[nodiscard]] LogicalResult runQCCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQCCleanupPipeline, + "Failed to run QC cleanup pipeline."); +} + +[[nodiscard]] LogicalResult runQCOCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQCOCleanupPipeline, + "Failed to run QCO cleanup pipeline."); +} + +[[nodiscard]] LogicalResult runQIRCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQIRCleanupPipeline, + "Failed to run QIR cleanup pipeline."); } diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 184bb7d677..ba6f4a18e1 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/QC/Translation/TranslateQuantumComputationToQC.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QIR/Builder/QIRProgramBuilder.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -27,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -85,8 +87,9 @@ class CompilerPipelineTest void SetUp() override { mlir::DialectRegistry registry; registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); @@ -96,7 +99,7 @@ class CompilerPipelineTest [[nodiscard]] mlir::OwningOpRef buildQCReference(const QCProgramBuilderFn builder) const { auto module = mlir::qc::QCProgramBuilder::build(context.get(), builder.fn); - runCanonicalizationPasses(module.get()); + EXPECT_TRUE(runQCCleanupPipeline(module.get()).succeeded()); return module; } @@ -104,7 +107,7 @@ class CompilerPipelineTest buildQIRReference(const QIRProgramBuilderFn builder) const { auto module = mlir::qir::QIRProgramBuilder::build(context.get(), builder.fn); - runCanonicalizationPasses(module.get()); + EXPECT_TRUE(runQIRCleanupPipeline(module.get()).succeeded()); return module; } @@ -220,21 +223,21 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithInv), false}, CompilerPipelineTestCase{"AllocQubit", MQT_NAMED_BUILDER(qc::allocQubit), nullptr, - MQT_NAMED_BUILDER(mlir::qc::allocQubit), - MQT_NAMED_BUILDER(mlir::qir::allocQubit)}, - CompilerPipelineTestCase{ - "AllocQubitRegister", MQT_NAMED_BUILDER(qc::allocQubitRegister), - nullptr, MQT_NAMED_BUILDER(mlir::qc::allocQubitRegister), - MQT_NAMED_BUILDER(mlir::qir::allocQubitRegister)}, + MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, + CompilerPipelineTestCase{"AllocQubitRegister", + MQT_NAMED_BUILDER(qc::allocQubitRegister), + nullptr, MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, CompilerPipelineTestCase{ "AllocMultipleQubitRegisters", MQT_NAMED_BUILDER(qc::allocMultipleQubitRegisters), nullptr, - MQT_NAMED_BUILDER(mlir::qc::allocMultipleQubitRegisters), - MQT_NAMED_BUILDER(mlir::qir::allocMultipleQubitRegisters)}, - CompilerPipelineTestCase{ - "AllocLargeRegister", MQT_NAMED_BUILDER(qc::allocLargeRegister), - nullptr, MQT_NAMED_BUILDER(mlir::qc::allocLargeRegister), - MQT_NAMED_BUILDER(mlir::qir::allocLargeRegister)}, + MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, + CompilerPipelineTestCase{"AllocLargeRegister", + MQT_NAMED_BUILDER(qc::allocLargeRegister), + nullptr, MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, CompilerPipelineTestCase{ "SingleMeasurementToSingleBit", MQT_NAMED_BUILDER(qc::singleMeasurementToSingleBit), nullptr, diff --git a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt index 868d4e5656..43c4143b6d 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt +++ b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt @@ -24,4 +24,7 @@ target_link_libraries( mqt_mlir_configure_unittest_target(${target_name}) -gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) +# TODO(https://github.com/unitaryfoundation/jeff/issues/46): Enable this again once static +# information is preserved + +# gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index 48b0b22926..1afb831cdc 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -93,7 +94,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -101,7 +102,12 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + PassManager pm(context.get()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createRemoveDeadValuesPass()); + EXPECT_TRUE(pm.run(program.get()).succeeded()); + printer.record(program.get(), "Canonicalized Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -109,7 +115,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -119,7 +125,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 86133967cb..c1e5d81d43 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -60,8 +62,9 @@ class QCOToQCTest : public testing::TestWithParam { void SetUp() override { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); @@ -87,7 +90,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -95,7 +98,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -105,7 +108,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index ce017e956f..9454f8309a 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -60,8 +62,9 @@ class QCToQCOTest : public testing::TestWithParam { void SetUp() override { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); @@ -86,7 +89,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -94,7 +97,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -104,7 +107,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index ef938ce4e7..65982d3673 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -61,7 +62,7 @@ class QCToQIRTest : public testing::TestWithParam { void SetUp() override { DialectRegistry registry; registry.insert(); + func::FuncDialect, memref::MemRefDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); @@ -86,7 +87,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -94,7 +95,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -104,7 +105,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 41a55f38fe..2c62a1e81e 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -9,4 +9,5 @@ add_subdirectory(QC) add_subdirectory(QCO) add_subdirectory(QIR) +add_subdirectory(QTensor) add_subdirectory(Utils) diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index d5266d35bf..9c3afe4ecf 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,8 @@ class QCTest : public testing::TestWithParam { void QCTest::SetUp() { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); @@ -76,7 +78,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -85,7 +87,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index e8850c3706..79bcd296c0 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -58,7 +59,7 @@ class QuantumComputationTranslationTest void SetUp() override { mlir::DialectRegistry registry; registry.insert(); + mlir::func::FuncDialect, mlir::memref::MemRefDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); @@ -80,7 +81,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(translated.get(), "Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); - runCanonicalizationPasses(translated.get()); + EXPECT_TRUE(runQCCleanupPipeline(translated.get()).succeeded()); printer.record(translated.get(), "Canonicalized Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); @@ -90,7 +91,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 8c0c54d2ed..5a7c1d8add 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qco_programs.h" @@ -65,7 +66,6 @@ class QCOTest : public testing::TestWithParam { context->loadAllAvailableDialects(); } }; - } // namespace TEST_P(QCOTest, ProgramEquivalence) { @@ -78,7 +78,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -87,7 +87,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); @@ -97,30 +97,35 @@ TEST_P(QCOTest, ProgramEquivalence) { TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly - qco::QCOProgramBuilder builder(context.get()); + QCOProgramBuilder builder(context.get()); builder.initialize(); - auto q0 = AllocOp::create(builder); - auto q1 = HOp::create(builder, q0); + auto c0 = arith::ConstantIndexOp::create(builder, 0); + auto c1 = arith::ConstantIndexOp::create(builder, 1); + auto r0 = qtensor::AllocOp::create(builder, c1); + auto extractOp = qtensor::ExtractOp::create(builder, r0, c0); + auto q1 = HOp::create(builder, extractOp.getResult()); auto measureOp = MeasureOp::create(builder, q1); auto ifOp = IfOp::create(builder, measureOp.getResult(), measureOp.getQubitOut(), [&](ValueRange qubits) -> llvm::SmallVector { auto innerQubit = XOp::create(builder, qubits[0]); - return llvm::SmallVector{innerQubit}; + return llvm::SmallVector{innerQubit}; }); - SinkOp::create(builder, ifOp.getResult(0)); + auto r2 = qtensor::InsertOp::create(builder, ifOp.getResult(0), + extractOp.getOutTensor(), c0); + qtensor::DeallocOp::create(builder, r2); auto directBuilder = builder.finalize(); ASSERT_TRUE(directBuilder); EXPECT_TRUE(verify(*directBuilder).succeeded()); - runCanonicalizationPasses(directBuilder.get()); + EXPECT_TRUE(runQCOCleanupPipeline(directBuilder.get()).succeeded()); EXPECT_TRUE(verify(*directBuilder).succeeded()); auto refBuilder = QCOProgramBuilder::build(context.get(), MQT_NAMED_BUILDER(simpleIf).fn); ASSERT_TRUE(refBuilder); EXPECT_TRUE(verify(*refBuilder).succeeded()); - runCanonicalizationPasses(refBuilder.get()); + EXPECT_TRUE(runQCOCleanupPipeline(refBuilder.get()).succeeded()); EXPECT_TRUE(verify(*refBuilder).succeeded()); EXPECT_TRUE(areModulesEquivalentWithPermutations(directBuilder.get(), @@ -1078,56 +1083,3 @@ INSTANTIATE_TEST_SUITE_P( QCOTestCase{"AllocSinkPair", MQT_NAMED_BUILDER(allocSinkPair), MQT_NAMED_BUILDER(emptyQCO)})); /// @} - -/// \name QTensor/QTensor.cpp -/// @{ -INSTANTIATE_TEST_SUITE_P( - QTensorTest, QCOTest, - testing::Values( - QCOTestCase{"QTensorAlloc", MQT_NAMED_BUILDER(qtensorAlloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorAllocDealloc", MQT_NAMED_BUILDER(qtensorDealloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorFromElements", - MQT_NAMED_BUILDER(qtensorFromElements), - MQT_NAMED_BUILDER(qtensorFromElements)}, - QCOTestCase{"QTensorExtract", MQT_NAMED_BUILDER(qtensorExtract), - MQT_NAMED_BUILDER(qtensorExtract)}, - QCOTestCase{"QTensorInsert", MQT_NAMED_BUILDER(qtensorInsert), - MQT_NAMED_BUILDER(qtensorInsert)}, - QCOTestCase{"QTensorExtractSlice", - MQT_NAMED_BUILDER(qtensorExtractSlice), - MQT_NAMED_BUILDER(qtensorExtractSlice)}, - QCOTestCase{"QTensorInsertSlice", MQT_NAMED_BUILDER(qtensorInsertSlice), - MQT_NAMED_BUILDER(qtensorInsertSlice)}, - QCOTestCase{"QTensorExtractInsertSameIndex", - MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorExtractInsertIndexMismatch", - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, - QCOTestCase{"QTensorInsertExtractSameIndex", - MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), - MQT_NAMED_BUILDER(qtensorInsert)}, - QCOTestCase{"QTensorInsertExtractIndexMismatch", - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)}, - QCOTestCase{"QTensorExtractSliceInsertSliceSameOffset", - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceSameOffset), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{ - "QTensorExtractSliceInsertSliceOffsetMismatch", - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceOffsetMismatch), - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceOffsetMismatch)}, - QCOTestCase{"QTensorInsertSliceExtractSliceSameOffset", - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceSameOffset), - MQT_NAMED_BUILDER(qtensorInsertSlice)}, - QCOTestCase{ - "QTensorInsertSliceExtractSliceOffsetMismatch", - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceOffsetMismatch), - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceOffsetMismatch)}, - QCOTestCase{ - "QTensorExtractSliceExtractInsertInsertSlice", - MQT_NAMED_BUILDER(qtensorExtractSliceExtractInsertInsertSlice), - MQT_NAMED_BUILDER(qtensorAlloc)})); -/// @} diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index 9d4b38f9c8..225a234b84 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -71,7 +71,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -80,7 +80,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QTensor/CMakeLists.txt b/mlir/unittests/Dialect/QTensor/CMakeLists.txt new file mode 100644 index 0000000000..b181a84fed --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(IR) diff --git a/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt b/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt new file mode 100644 index 0000000000..b7e2227682 --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(qtensor_ir_target mqt-core-mlir-unittest-qtensor-ir) +add_executable(${qtensor_ir_target} test_qtensor_ir.cpp) +target_link_libraries(${qtensor_ir_target} PRIVATE GTest::gtest_main MLIRParser MLIRSupportMQT + MLIRQCOProgramBuilder MLIRQCOPrograms) +mqt_mlir_configure_unittest_target(${qtensor_ir_target}) + +gtest_discover_tests(${qtensor_ir_target} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp new file mode 100644 index 0000000000..793f29f726 --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -0,0 +1,517 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +/** + * @file test_qtensor_ir.cpp + * @brief Dedicated unit-test suite for the QTensor MLIR dialect. + */ + +#include "TestCaseUtils.h" +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include "mlir/Support/IRVerification.h" +#include "mlir/Support/Passes.h" +#include "qco_programs.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qtensor; +using namespace mlir::qco; + +namespace { + +class QTensorTest : public ::testing::Test { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } + + /// Build a module using the QCOProgramBuilder and run the cleanup pipeline. + [[nodiscard]] OwningOpRef + buildAndCanonicalize(void (*buildFn)(QCOProgramBuilder&)) const { + auto module = QCOProgramBuilder::build(context.get(), buildFn); + if (!module) { + return {}; + } + if (runQCOCleanupPipeline(module.get()).failed()) { + return {}; + } + return module; + } + + /// Count occurrences of a specific op kind inside a module. + template + [[nodiscard]] static std::size_t countOps(ModuleOp module) { + std::size_t count = 0; + module.walk([&](OpT) { ++count; }); + return count; + } +}; + +// ============================================================================ +// QTensorUtils +// ============================================================================ + +TEST_F(QTensorTest, AreEquivalentIndicesSameValueIsEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto c2 = arith::ConstantIndexOp::create(builder, 2); + EXPECT_TRUE(areEquivalentIndices(c2.getResult(), c2.getResult())); +} + +TEST_F(QTensorTest, AreEquivalentIndicesSameConstantsAreEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto lhs = arith::ConstantIndexOp::create(builder, 2); + auto rhs = arith::ConstantIndexOp::create(builder, 2); + EXPECT_TRUE(areEquivalentIndices(lhs.getResult(), rhs.getResult())); +} + +TEST_F(QTensorTest, AreEquivalentIndicesDifferentConstantsAreNotEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto c0 = arith::ConstantIndexOp::create(builder, 0); + auto c1 = arith::ConstantIndexOp::create(builder, 1); + EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); +} + +// ============================================================================ +// AllocOp +// ============================================================================ + +/// AllocOp with a constant size ≤ 0 must fail verification. +TEST_F(QTensorTest, AllocOpZeroSizeFailsVerification) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + auto qubitType = qco::QubitType::get(context.get()); + auto tensorType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + auto c0 = arith::ConstantIndexOp::create(b, 0); + AllocOp::create(b, tensorType, c0.getResult()); + + EXPECT_TRUE(verify(module).failed()); +} + +/// AllocOp where static result type dim ≠ constant size must fail. +TEST_F(QTensorTest, AllocOpStaticTypeMismatchFailsVerification) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + auto qubitType = qco::QubitType::get(context.get()); + auto tensorType = RankedTensorType::get({3}, qubitType); + auto c2 = arith::ConstantIndexOp::create(b, 2); + AllocOp::create(b, tensorType, c2.getResult()); + + EXPECT_TRUE(verify(module).failed()); +} + +/// AllocOp with a dynamic result type but a constant size operand is valid. +TEST_F(QTensorTest, AllocOpDynamicTypeWithConstantSizeVerifies) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + auto qubitType = qco::QubitType::get(context.get()); + auto tensorType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + auto c3 = arith::ConstantIndexOp::create(b, 3); + AllocOp::create(b, tensorType, c3.getResult()); + + EXPECT_TRUE(verify(module).succeeded()); +} + +/// AllocOp with a static result type but a dynamic size fails verification. +TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + // We need a block argument to act as a non-constant size + // Create a func.func to hold the block argument + auto funcType = + FunctionType::get(context.get(), {IndexType::get(context.get())}, {}); + auto func = func::FuncOp::create(b, "test", funcType); + auto* block = func.addEntryBlock(); + b.setInsertionPointToStart(block); + + auto qubitType = qco::QubitType::get(context.get()); + auto tensorType = RankedTensorType::get({3}, qubitType); + auto size = block->getArgument(0); + AllocOp::create(b, tensorType, size); + func::ReturnOp::create(b); + + EXPECT_TRUE(verify(module).failed()); +} + +// ============================================================================ +// DeallocOp +// ============================================================================ + +/// An alloc immediately followed by dealloc should be eliminated entirely. +TEST_F(QTensorTest, DeallocOpAllocDeallocPairIsRemoved) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto t = b.qtensorAlloc(3); + b.qtensorDealloc(t); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // Both AllocOp and DeallocOp should have been erased. + EXPECT_EQ(countOps(*canonicalized), 0U); + EXPECT_EQ(countOps(*canonicalized), 0U); +} + +// ============================================================================ +// ExtractOp +// ============================================================================ + +/// An extract at a negative constant index fails verification. +TEST_F(QTensorTest, ExtractOpNegativeIndexFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor = builder.qtensorAlloc(3); + auto index = arith::ConstantIndexOp::create(builder, -1); + ExtractOp::create(builder, tensor, index.getResult()); + auto module = builder.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// An extract at an index equal to the tensor dimension fails verification. +TEST_F(QTensorTest, ExtractOpIndexAtDimFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor = builder.qtensorAlloc(3); + auto index = arith::ConstantIndexOp::create(builder, 3); + ExtractOp::create(builder, tensor, index.getResult()); + auto module = builder.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +// ============================================================================ +// InsertOp +// ============================================================================ + +/// An insert at a negative constant index fails verification. +TEST_F(QTensorTest, InsertOpNegativeIndexFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor0 = builder.qtensorAlloc(3); + auto [tensor1, q0] = builder.qtensorExtract(tensor0, 0); + auto index = arith::ConstantIndexOp::create(builder, -1); + InsertOp::create(builder, q0, tensor1, index.getResult()); + auto module = builder.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// An insert at an index equal to the destination dimension fails verification. +TEST_F(QTensorTest, InsertOpIndexAtDimFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor0 = builder.qtensorAlloc(3); + auto [tensor1, q0] = builder.qtensorExtract(tensor0, 0); + auto index = arith::ConstantIndexOp::create(builder, 3); + InsertOp::create(builder, q0, tensor1, index.getResult()); + auto module = builder.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +} // namespace + +// ============================================================================ +// Canonicalization +// ============================================================================ + +static OwningOpRef +buildTwoQubitInsertChainProgram(MLIRContext* context, + const bool reverseInsertOrder, + const bool swapInsertTargets) { + const int64_t q0Target = swapInsertTargets ? 1 : 0; + const int64_t q1Target = swapInsertTargets ? 0 : 1; + + QCOProgramBuilder builder(context); + builder.initialize(); + + Value q0 = nullptr; + Value q1 = nullptr; + + auto tensor = builder.qtensorAlloc(2); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + std::tie(tensor, q1) = builder.qtensorExtract(tensor, 1); + + if (reverseInsertOrder) { + tensor = builder.qtensorInsert(q1, tensor, q1Target); + tensor = builder.qtensorInsert(q0, tensor, q0Target); + } else { + tensor = builder.qtensorInsert(q0, tensor, q0Target); + tensor = builder.qtensorInsert(q1, tensor, q1Target); + } + + builder.qtensorDealloc(tensor); + return builder.finalize(); +} + +static OwningOpRef +buildResetWithCommutingInsertProgram(MLIRContext* context, + const bool withReset) { + QCOProgramBuilder builder(context); + builder.initialize(); + + Value q0 = nullptr; + Value q1 = nullptr; + + auto tensor = builder.qtensorAlloc(2); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + tensor = builder.qtensorInsert(q0, tensor, 0); + std::tie(tensor, q1) = builder.qtensorExtract(tensor, 1); + if (withReset) { + q1 = builder.reset(q1); + } + tensor = builder.qtensorInsert(q1, tensor, 1); + + builder.qtensorDealloc(tensor); + return builder.finalize(); +} + +static OwningOpRef +buildResetWithSameIndexInsertProgram(MLIRContext* context, + const bool withReset) { + QCOProgramBuilder builder(context); + builder.initialize(); + + Value q0 = nullptr; + Value q10 = nullptr; + Value q11 = nullptr; + + auto tensor = builder.qtensorAlloc(2); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + std::tie(tensor, q10) = builder.qtensorExtract(tensor, 1); + q10 = builder.h(q10); + tensor = builder.qtensorInsert(q10, tensor, 1); + std::tie(tensor, q11) = builder.qtensorExtract(tensor, 1); + if (withReset) { + q11 = builder.reset(q11); + } + tensor = builder.qtensorInsert(q11, tensor, 1); + tensor = builder.qtensorInsert(q0, tensor, 0); + + builder.qtensorDealloc(tensor); + return builder.finalize(); +} + +namespace { + +TEST_F(QTensorTest, InsertChainPermutationEquivalence) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { + auto program = buildResetWithCommutingInsertProgram(context.get(), true); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithCommutingInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { + auto program = buildResetWithSameIndexInsertProgram(context.get(), true); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +// ============================================================================ +// Integration +// ============================================================================ + +struct QTensorIntegrationTestCase { + std::string name; + mqt::test::NamedBuilder programBuilder; + mqt::test::NamedBuilder referenceBuilder; + + friend std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info); +}; + +// NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) +std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info) { + return os << "QTensor{" << info.name << "}"; +} + +class QTensorIntegrationTest + : public testing::TestWithParam { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } +}; + +TEST_P(QTensorIntegrationTest, ProgramEquivalence) { + const auto& [_, programBuilder, referenceBuilder] = GetParam(); + const auto name = " (" + GetParam().name + ")"; + mqt::test::DeferredPrinter printer; + + auto program = QCOProgramBuilder::build(context.get(), programBuilder.fn); + ASSERT_TRUE(program); + printer.record(program.get(), "Original QTensor IR" + name); + EXPECT_TRUE(verify(*program).succeeded()); + + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + printer.record(program.get(), "Canonicalized QTensor IR" + name); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = QCOProgramBuilder::build(context.get(), referenceBuilder.fn); + ASSERT_TRUE(reference); + printer.record(reference.get(), "Reference QTensor IR" + name); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + printer.record(reference.get(), "Canonicalized Reference QTensor IR" + name); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +/// @name QTensor/QTensor.cpp (relocated from QCO test suite) +/// @{ +INSTANTIATE_TEST_SUITE_P( + QTensorOpsTest, QTensorIntegrationTest, + testing::Values( + QTensorIntegrationTestCase{"QTensorAlloc", + MQT_NAMED_BUILDER(qtensorAlloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorAllocDealloc", + MQT_NAMED_BUILDER(qtensorDealloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorFromElements", + MQT_NAMED_BUILDER(qtensorFromElements), + MQT_NAMED_BUILDER(qtensorFromElements)}, + QTensorIntegrationTestCase{"QTensorExtract", + MQT_NAMED_BUILDER(qtensorExtract), + MQT_NAMED_BUILDER(qtensorExtract)}, + QTensorIntegrationTestCase{"QTensorInsert", + MQT_NAMED_BUILDER(qtensorInsert), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertSameIndex", + MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertIndexMismatch", + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractSameIndex", + MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractIndexMismatch", + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)})); +/// @} + +} // namespace diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 6b55fbe957..2357ce2440 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -23,8 +23,8 @@ void allocQubit(QCProgramBuilder& b) { b.allocQubit(); } void allocQubitRegister(QCProgramBuilder& b) { b.allocQubitRegister(2); } void allocMultipleQubitRegisters(QCProgramBuilder& b) { - b.allocQubitRegister(2, "reg0"); - b.allocQubitRegister(3, "reg1"); + b.allocQubitRegister(2); + b.allocQubitRegister(3); } void allocLargeRegister(QCProgramBuilder& b) { b.allocQubitRegister(100); } @@ -126,8 +126,8 @@ void multipleClassicalRegistersAndMeasurements(QCProgramBuilder& b) { } void resetQubitWithoutOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.reset(q[0]); } void resetMultipleQubitsWithoutOp(QCProgramBuilder& b) { @@ -137,16 +137,16 @@ void resetMultipleQubitsWithoutOp(QCProgramBuilder& b) { } void repeatedResetWithoutOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.reset(q); - b.reset(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.reset(q[0]); + b.reset(q[0]); + b.reset(q[0]); } void resetQubitAfterSingleOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.h(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.h(q[0]); + b.reset(q[0]); } void resetMultipleQubitsAfterSingleOp(QCProgramBuilder& b) { @@ -158,11 +158,11 @@ void resetMultipleQubitsAfterSingleOp(QCProgramBuilder& b) { } void repeatedResetAfterSingleOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.h(q); - b.reset(q); - b.reset(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.h(q[0]); + b.reset(q[0]); + b.reset(q[0]); + b.reset(q[0]); } void globalPhase(QCProgramBuilder& b) { b.gphase(0.123); } @@ -178,8 +178,8 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b) { } void nestedControlledGlobalPhase(QCProgramBuilder& b) { - auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cgphase(0.123, reg[1]); }); + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); } void trivialControlledGlobalPhase(QCProgramBuilder& b) { @@ -212,8 +212,8 @@ void multipleControlledIdentity(QCProgramBuilder& b) { } void nestedControlledIdentity(QCProgramBuilder& b) { - auto reg = b.allocQubitRegister(3); - b.ctrl(reg[2], [&] { b.cid(reg[1], reg[0]); }); + auto q = b.allocQubitRegister(3); + b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); } void trivialControlledIdentity(QCProgramBuilder& b) { diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index e7695e3cac..0f86e41595 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -27,8 +27,8 @@ void allocQubit(QCOProgramBuilder& b) { b.allocQubit(); } void allocQubitRegister(QCOProgramBuilder& b) { b.allocQubitRegister(2); } void allocMultipleQubitRegisters(QCOProgramBuilder& b) { - b.allocQubitRegister(2, "reg0"); - b.allocQubitRegister(3, "reg1"); + b.allocQubitRegister(2); + b.allocQubitRegister(3); } void allocLargeRegister(QCOProgramBuilder& b) { b.allocQubitRegister(100); } @@ -126,9 +126,9 @@ void repeatedResetWithoutOp(QCOProgramBuilder& b) { } void resetQubitAfterSingleOp(QCOProgramBuilder& b) { - auto q = b.allocQubit(); - q = b.h(q); - q = b.reset(q); + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.reset(q[0]); } void resetMultipleQubitsAfterSingleOp(QCOProgramBuilder& b) { @@ -140,11 +140,11 @@ void resetMultipleQubitsAfterSingleOp(QCOProgramBuilder& b) { } void repeatedResetAfterSingleOp(QCOProgramBuilder& b) { - auto q = b.allocQubit(); - q = b.h(q); - q = b.reset(q); - q = b.reset(q); - q = b.reset(q); + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.reset(q[0]); + q[0] = b.reset(q[0]); + q[0] = b.reset(q[0]); } void globalPhase(QCOProgramBuilder& b) { b.gphase(0.123); } @@ -2193,21 +2193,6 @@ void qtensorInsert(QCOProgramBuilder& b) { b.qtensorInsert(q1, extractOutTensor, 0); } -void qtensorExtractSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - b.qtensorExtractSlice(qtensor, 0, 2); -} - -void qtensorInsertSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); -} - void qtensorExtractInsertIndexMismatch(QCOProgramBuilder& b) { auto qtensor = b.qtensorAlloc(3); auto [extractOutTensor, q0] = b.qtensorExtract(qtensor, 0); @@ -2220,20 +2205,6 @@ void qtensorExtractInsertSameIndex(QCOProgramBuilder& b) { b.qtensorInsert(q0, extractOutTensor, 0); } -void qtensorExtractSliceInsertSliceOffsetMismatch(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - b.qtensorInsertSlice(slicedTensor, extractSliceOutTensor, 1, 2); -} - -void qtensorExtractSliceInsertSliceSameOffset(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - b.qtensorInsertSlice(slicedTensor, extractSliceOutTensor, 0, 2); -} - void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b) { auto qtensor = b.qtensorAlloc(3); auto [extractOutTensor, q0] = b.qtensorExtract(qtensor, 0); @@ -2252,41 +2223,4 @@ void qtensorInsertExtractSameIndex(QCOProgramBuilder& b) { b.qtensorInsert(q2, extractOutTensor1, 0); } -void qtensorInsertSliceExtractSliceOffsetMismatch(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - auto insertSliceOutTensor = - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); - auto [extractSliceOutTensor1, slicedTensor1] = - b.qtensorExtractSlice(insertSliceOutTensor, 1, 2); - b.qtensorInsertSlice(slicedTensor1, extractSliceOutTensor1, 0, 2); -} - -void qtensorInsertSliceExtractSliceSameOffset(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - auto insertSliceOutTensor = - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); - auto [extractSliceOutTensor1, slicedTensor1] = - b.qtensorExtractSlice(insertSliceOutTensor, 0, 2); - b.qtensorInsertSlice(slicedTensor1, extractSliceOutTensor1, 0, 2); -} - -void qtensorExtractSliceExtractInsertInsertSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto insertOutTensor = b.qtensorInsert(q0, extractOutTensor, 0); - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); -} - } // namespace mlir::qco diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index a14100102c..3d503bc472 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -1007,12 +1007,6 @@ void qtensorExtract(QCOProgramBuilder& b); /// Inserts a qubit into a tensor. void qtensorInsert(QCOProgramBuilder& b); -/// Extracts a slice from a tensor. -void qtensorExtractSlice(QCOProgramBuilder& b); - -/// Inserts a slice into a tensor. -void qtensorInsertSlice(QCOProgramBuilder& b); - /// Extracts a qubit from a tensor and inserts it immediately at a different /// index. void qtensorExtractInsertIndexMismatch(QCOProgramBuilder& b); @@ -1027,25 +1021,4 @@ void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b); /// Inserts a qubit into a tensor and extracts it immediately at the same index. void qtensorInsertExtractSameIndex(QCOProgramBuilder& b); -/// Extracts a slice of qubits from a tensor and inserts it immediately at a -/// different offset. -void qtensorExtractSliceInsertSliceOffsetMismatch(QCOProgramBuilder& b); - -/// Extracts a slice of qubits from a tensor and inserts it immediately at the -/// same offset. -void qtensorExtractSliceInsertSliceSameOffset(QCOProgramBuilder& b); - -/// Inserts a slice of qubits into a tensor and extracts it immediately at a -/// different offset. -void qtensorInsertSliceExtractSliceOffsetMismatch(QCOProgramBuilder& b); - -/// Inserts a slice of qubits into a tensor and extracts it immediately at the -/// same offset. -void qtensorInsertSliceExtractSliceSameOffset(QCOProgramBuilder& b); - -/// Extracts a slice of qubits, extracts a qubit from the slice, inserts the -/// qubit back into the slice, and inserts the slice back into the tensor -/// immediately at the same index and offset. -void qtensorExtractSliceExtractInsertInsertSlice(QCOProgramBuilder& b); - } // namespace mlir::qco