From 3df877781e7556f822c86dcf21cea81a102d4322 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Thu, 19 Mar 2026 21:26:41 +0100 Subject: [PATCH 01/71] Replace behavior of allocQubitRegister() methods --- .../mlir/Conversion/QCOToQC/QCOToQC.td | 3 +- .../mlir/Conversion/QCToQCO/QCToQCO.td | 3 +- .../Dialect/QC/Builder/QCProgramBuilder.h | 2 + mlir/include/mlir/Dialect/QC/IR/QCTypes.td | 11 +- mlir/lib/Conversion/QCOToQC/CMakeLists.txt | 1 + mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 85 ++++++++++- mlir/lib/Conversion/QCToQCO/CMakeLists.txt | 1 + mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 140 +++++++++++++++++- mlir/lib/Dialect/QC/Builder/CMakeLists.txt | 2 +- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 53 +++++-- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 25 ++-- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 7 +- .../Conversion/QCToQCO/test_qc_to_qco.cpp | 7 +- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 4 +- .../test_quantum_computation_translation.cpp | 3 +- 15 files changed, 298 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td b/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td index a773b3d5c2..e77e57ad9a 100644 --- a/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td +++ b/mlir/include/mlir/Conversion/QCOToQC/QCOToQC.td @@ -16,5 +16,6 @@ def QCOToQC : Pass<"qco-to-qc"> { It handles the transformation of qubit values in QCO to qubit references in QC, ensuring that the semantics of quantum operations are preserved during the conversion process. }]; - let dependentDialects = ["mlir::qc::QCDialect"]; + let dependentDialects = ["mlir::memref::MemRefDialect", + "mlir::qc::QCDialect"]; } diff --git a/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td b/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td index 9ffd33d4f7..8b1f6fadc4 100644 --- a/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td +++ b/mlir/include/mlir/Conversion/QCToQCO/QCToQCO.td @@ -16,5 +16,6 @@ def QCToQCO : Pass<"qc-to-qco"> { It handles the transformation of qubit references in QC to qubit values in QCO, ensuring that the semantics of quantum operations are preserved during the conversion process. }]; - let dependentDialects = ["mlir::qco::QCODialect"]; + let dependentDialects = ["mlir::arith::ArithDialect", "mlir::qco::QCODialect", + "mlir::qtensor::QTensorDialect"]; } diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index fdf5ab7310..867ef62515 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -942,6 +942,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /// Track allocated qubits for automatic deallocation llvm::DenseSet allocatedQubits; + llvm::DenseSet allocatedMemrefs; + /// Check if the builder has been finalized void checkFinalized() const; }; diff --git a/mlir/include/mlir/Dialect/QC/IR/QCTypes.td b/mlir/include/mlir/Dialect/QC/IR/QCTypes.td index 85d36311fa..8296fef23c 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCTypes.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCTypes.td @@ -12,20 +12,19 @@ include "mlir/Dialect/QC/IR/QCDialect.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" class QCType traits = []> : TypeDef { let mnemonic = typeMnemonic; } -def QubitType : QCType<"Qubit", "qubit"> { +def QubitType : QCType<"Qubit", "qubit", [MemRefElementTypeInterface]> { let summary = "QC qubit reference type"; let description = [{ - The `!qc.qubit` type represents a reference to a quantum bit in the - QC dialect. Operations using this type modify qubits in place using - reference semantics, similar to how classical imperative languages handle - mutable references. - }]; + The `!qc.qubit` type represents a reference to a quantum bit in the QC dialect. + Operations using this type modify qubits in place using reference semantics, similar to how classical imperative languages handle mutable references. + }]; } #endif // MLIR_DIALECT_QC_IR_QCTYPES_TD diff --git a/mlir/lib/Conversion/QCOToQC/CMakeLists.txt b/mlir/lib/Conversion/QCOToQC/CMakeLists.txt index eba206ceda..c2fa03c520 100644 --- a/mlir/lib/Conversion/QCOToQC/CMakeLists.txt +++ b/mlir/lib/Conversion/QCOToQC/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_conversion_library( MLIRQTensorDialect MLIRArithDialect MLIRFuncDialect + MLIRMemRefDialect MLIRTransforms MLIRFuncTransforms DISABLE_INSTALL) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 381486b02b..5855d73070 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -14,9 +14,12 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include +#include #include #include #include @@ -59,6 +62,77 @@ class QCOToQCTypeConverter final : public TypeConverter { addConversion([ctx](qco::QubitType /*type*/) -> Type { return qc::QubitType::get(ctx); }); + + addConversion([ctx](RankedTensorType type) -> Type { + if (llvm::isa(type.getElementType())) { + // TODO: Can we make it work with type.getShape()? + return MemRefType::get({ShapedType::kDynamic}, qc::QubitType::get(ctx)); + } + return type; + }); + } +}; + +struct ConvertQTensorAllocOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::AllocOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto qubitType = qc::QubitType::get(op.getContext()); + auto memrefType = mlir::MemRefType::get({ShapedType::kDynamic}, qubitType); + rewriter.replaceOpWithNewOp(op, memrefType, op.getSize()); + return success(); + } +}; + +struct ConvertQTensorExtractOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto load = memref::LoadOp::create(rewriter, op.getLoc(), + adaptor.getTensor(), adaptor.getIndex()); + rewriter.replaceOp(op, {adaptor.getTensor(), load.getResult()}); + return success(); + } +}; + +// struct ConvertQTensorInsertOp final : OpConversionPattern +// { +// using OpConversionPattern::OpConversionPattern; + +// LogicalResult +// matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, +// ConversionPatternRewriter& rewriter) const override { +// auto store = +// memref::StoreOp::create(rewriter, op.getLoc(), adaptor.getScalar(), +// adaptor.getDest(), adaptor.getIndex()); +// rewriter.replaceOp(op, adaptor.getDest()); +// return success(); +// } +// }; + +struct ConvertQTensorInsertOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertQTensorDeallocOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor()); + return success(); } }; @@ -751,15 +825,16 @@ struct QCOToQC final : impl::QCOToQCBase { RewritePatternSet patterns(context); QCOToQCTypeConverter typeConverter(context); - // Configure conversion target: QCO illegal, QC legal - target.addIllegalDialect(); - target.addLegalDialect(); + // Configure conversion target + target.addIllegalDialect(); + target.addLegalDialect(); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns.add< - ConvertQCOAllocOp, ConvertQCODeallocOp, ConvertQCOStaticOp, - ConvertQCOMeasureOp, ConvertQCOResetOp, + ConvertQTensorAllocOp, ConvertQTensorExtractOp, ConvertQTensorInsertOp, + ConvertQTensorDeallocOp, ConvertQCOAllocOp, ConvertQCODeallocOp, + ConvertQCOStaticOp, ConvertQCOMeasureOp, ConvertQCOResetOp, ConvertQCOZeroTargetOneParameterToQC, ConvertQCOOneTargetZeroParameterToQC, ConvertQCOOneTargetZeroParameterToQC, diff --git a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt index 0474683bfd..157a47593a 100644 --- a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt +++ b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_conversion_library( MLIRQTensorDialect MLIRArithDialect MLIRFuncDialect + MLIRMemRefDialect MLIRTransforms MLIRFuncTransforms DISABLE_INSTALL) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 474f80aacf..299c980b8c 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -14,11 +14,14 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include #include #include +#include #include #include #include @@ -75,6 +78,8 @@ struct LoweringState { /// Map from original QC qubit references to their latest QCO SSA values llvm::DenseMap qubitMap; + llvm::DenseMap qtensorMap; + /// Modifier information int64_t inNestedRegion = 0; DenseMap> targetsIn; @@ -135,6 +140,130 @@ class QCToQCOTypeConverter final : public TypeConverter { } }; +struct ConvertMemRefAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto& qtensorMap = getState().qtensorMap; + + auto shape = op.getType().getShape(); + if (shape.size() != 1) { + return failure(); + } + + Value qtensor; + if (shape[0] == ShapedType::kDynamic) { + qtensor = rewriter.replaceOpWithNewOp( + op, adaptor.getDynamicSizes()[0]); + } else { + auto size = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(shape[0])); + qtensor = + rewriter.replaceOpWithNewOp(op, size.getResult()); + } + + qtensorMap.try_emplace(op.getResult(), qtensor); + + return success(); + } +}; + +struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap; + auto& qtensorMap = getState().qtensorMap; + + // Look up the latest QTensor value for this QC register + auto memref = op.getMemref(); + assert(qtensorMap.contains(memref) && "QC register not found"); + auto qtensor = qtensorMap[memref]; + + auto extract = qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, + adaptor.getIndices()[0]); + + qubitMap.try_emplace(op.getResult(), extract.getResult()); + qtensorMap[memref] = extract.getOutTensor(); + + rewriter.eraseOp(op); + + return success(); + } +}; + +// struct ConvertMemRefStoreOp final +// : StatefulOpConversionPattern { +// using StatefulOpConversionPattern::StatefulOpConversionPattern; + +// LogicalResult +// matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, +// ConversionPatternRewriter& rewriter) const override { +// auto& qubitMap = getState().qubitMap; +// auto& qtensorMap = getState().qtensorMap; + +// // Look up the latest QCO value for this QC qubit +// auto qcQubit = op.getValue(); +// assert(qubitMap.contains(qcQubit) && "QC qubit not found"); +// auto qcoQubit = qubitMap[qcQubit]; + +// // Look up the latest QTensor value for this QC register +// auto memref = op.getMemref(); +// assert(qtensorMap.contains(memref) && "QC register not found"); +// auto qtensor = qtensorMap[memref]; + +// auto store = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, +// qtensor, adaptor.getIndices()[0]); + +// qubitMap.erase(qcQubit); +// qtensorMap[memref] = store.getResult(); + +// rewriter.eraseOp(op); + +// return success(); +// } +// }; + +struct ConvertMemRefDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap; + auto& qtensorMap = getState().qtensorMap; + + // Look up the latest QTensor value for this QC register + auto memref = op.getMemref(); + assert(qtensorMap.contains(memref) && "QC register not found"); + auto qtensor = qtensorMap[memref]; + + // Insert all qubits + // TODO: Use dedicated map + int64_t i = 0; + for (auto [_, qcoQubit] : qubitMap) { + auto index = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(i)); + auto insert = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, + qtensor, index.getResult()); + qtensor = insert.getResult(); + ++i; + } + + rewriter.replaceOpWithNewOp(op, qtensor); + + qtensorMap.erase(memref); + + return success(); + } +}; + /** * @brief Converts qc.alloc to qco.alloc * @@ -987,7 +1116,7 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::InvOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& [qubitMap, inNestedRegion, targetsIn, targetsOut] = getState(); + auto& [qubitMap, _, inNestedRegion, targetsIn, targetsOut] = getState(); // Get QCO targets from state map const auto numTargets = op.getNumTargets(); @@ -1108,12 +1237,15 @@ struct QCToQCO final : impl::QCToQCOBase { RewritePatternSet patterns(context); QCToQCOTypeConverter typeConverter(context); - // Configure conversion target: QC illegal, QCO legal - target.addIllegalDialect(); - target.addLegalDialect(); + // Configure conversion target + // TODO: Do not blanket-illegalize memref + target.addIllegalDialect(); + target.addLegalDialect(); // Register operation conversion patterns with state tracking patterns.add< + ConvertMemRefAllocOp, ConvertMemRefLoadOp, ConvertMemRefDeallocOp, ConvertQCAllocOp, ConvertQCDeallocOp, ConvertQCStaticOp, ConvertQCMeasureOp, ConvertQCResetOp, ConvertQCZeroTargetOneParameterToQCO, diff --git a/mlir/lib/Dialect/QC/Builder/CMakeLists.txt b/mlir/lib/Dialect/QC/Builder/CMakeLists.txt index 71ae043b39..58563c2ddd 100644 --- a/mlir/lib/Dialect/QC/Builder/CMakeLists.txt +++ b/mlir/lib/Dialect/QC/Builder/CMakeLists.txt @@ -13,7 +13,7 @@ add_mlir_library( PUBLIC MLIRArithDialect MLIRFuncDialect - MLIRSCFDialect + MLIRMemRefDialect MLIRQCDialect DISABLE_INSTALL) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index da5cf312dd..e138e0794b 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -101,21 +102,23 @@ QCProgramBuilder::allocQubitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } - // Allocate a sequence of qubits with register metadata + auto qubitType = QubitType::get(ctx); + auto memrefType = mlir::MemRefType::get({size}, qubitType); + auto memref = memref::AllocOp::create(*this, memrefType); + llvm::SmallVector qubits; qubits.reserve(size); - auto nameAttr = getStringAttr(name); - auto sizeAttr = getI64IntegerAttr(size); - for (int64_t i = 0; i < size; ++i) { - auto indexAttr = getI64IntegerAttr(i); - auto allocOp = AllocOp::create(*this, nameAttr, sizeAttr, indexAttr); - const auto& qubit = qubits.emplace_back(allocOp.getResult()); + auto index = arith::ConstantOp::create(*this, getIndexAttr(i)); + auto load = memref::LoadOp::create(*this, memref, index.getResult()); + const auto& qubit = qubits.emplace_back(load.getResult()); // Track the allocated qubit for automatic deallocation allocatedQubits.insert(qubit); } + allocatedMemrefs.insert(memref); + return qubits; } @@ -500,25 +503,47 @@ OwningOpRef QCProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - // Automatically deallocate all still-allocated qubits - // Sort qubits for deterministic output - llvm::SmallVector sortedQubits(allocatedQubits.begin(), - allocatedQubits.end()); - llvm::sort(sortedQubits, [](Value a, Value b) { + // llvm::SmallVector freeQubits; + // for (auto qubit : allocatedQubits) { + // auto memref = qubit.getDefiningOp(); + // if (!memref) { + // freeQubits.emplace_back(qubit); + // } + // } + + auto blockOrderComparator = [](Value a, Value b) { auto* opA = a.getDefiningOp(); auto* opB = b.getDefiningOp(); if (!opA || !opB || opA->getBlock() != opB->getBlock()) { return a.getAsOpaquePointer() < b.getAsOpaquePointer(); } return opA->isBeforeInBlock(opB); - }); + }; + + // Automatically deallocate all still-allocated qubits + // Sort qubits for deterministic output + llvm::SmallVector sortedQubits(allocatedQubits.begin(), + allocatedQubits.end()); + llvm::sort(sortedQubits, blockOrderComparator); + for (auto qubit : sortedQubits) { DeallocOp::create(*this, qubit); } - // Clear the tracking set allocatedQubits.clear(); + // Automatically deallocate all still-allocated memrefs + // Sort memrefs for deterministic output + llvm::SmallVector sortedMemrefs(allocatedMemrefs.begin(), + allocatedMemrefs.end()); + llvm::sort(sortedMemrefs, blockOrderComparator); + + for (auto memref : sortedMemrefs) { + memref::DeallocOp::create(*this, memref); + } + + allocatedMemrefs.clear(); + // Create constant 0 for successful exit code auto exitCode = intConstant(0); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 05b55ff6dc..92a8f36114 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -109,20 +109,18 @@ QCOProgramBuilder::allocQubitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } - llvm::SmallVector qubits; - qubits.reserve(static_cast(size)); + auto qtensor = qtensorAlloc(size); - auto nameAttr = getStringAttr(name); - auto sizeAttr = getI64IntegerAttr(size); + llvm::SmallVector qubits; + qubits.reserve(size); for (int64_t i = 0; i < size; ++i) { - const auto indexAttr = getI64IntegerAttr(i); - auto allocOp = AllocOp::create(*this, nameAttr, sizeAttr, indexAttr); - const auto& qubit = qubits.emplace_back(allocOp.getResult()); - // Track the allocated qubit as valid - validQubits.insert(qubit); + auto [qtensorOut, qubit] = qtensorExtract(qtensor, i); + qtensor = qtensorOut; + qubits.emplace_back(qubit); } + // TODO: Return qtensor return qubits; } @@ -201,11 +199,13 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, Value QCOProgramBuilder::qtensorAlloc( const std::variant& size) { checkFinalized(); - auto sizeValue = utils::variantToValue(*this, getLoc(), size); + auto sizeValue = utils::variantToValue(*this, getLoc(), size); auto allocOp = qtensor::AllocOp::create(*this, sizeValue); + auto result = allocOp.getResult(); validTensors.insert(result); + return result; } @@ -924,6 +924,8 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } + // TODO: Determine "free" qubits? + auto blockOrderComparator = [](Value a, Value b) { auto* opA = a.getDefiningOp(); auto* opB = b.getDefiningOp(); @@ -942,6 +944,8 @@ OwningOpRef QCOProgramBuilder::finalize() { DeallocOp::create(*this, qubit); } + validQubits.clear(); + // Automatically deallocate all still-allocated tensors // Sort tensors for deterministic output llvm::SmallVector sortedTensors(validTensors.begin(), @@ -952,7 +956,6 @@ OwningOpRef QCOProgramBuilder::finalize() { qtensor::DeallocOp::create(*this, tensor); } - validQubits.clear(); validTensors.clear(); // Create constant 0 for successful exit code diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index bb81b94ecc..139219f79d 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -60,8 +62,9 @@ class QCOToQCTest : public testing::TestWithParam { void SetUp() override { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 98d6c9e012..7b99692785 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -60,8 +62,9 @@ class QCToQCOTest : public testing::TestWithParam { void SetUp() override { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 221b750f7e..3153062411 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,8 @@ class QCTest : public testing::TestWithParam { void QCTest::SetUp() { // Register all necessary dialects DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index e8850c3706..ef41b04bb5 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -58,7 +59,7 @@ class QuantumComputationTranslationTest void SetUp() override { mlir::DialectRegistry registry; registry.insert(); + mlir::func::FuncDialect, mlir::memref::MemRefDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); From d318cba0977b1b9007cc9d93672743c994024f48 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:10:51 +0100 Subject: [PATCH 02/71] Update conversion from QC to QIR --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 19 +- .../Dialect/QIR/Builder/QIRProgramBuilder.h | 10 +- .../include/mlir/Dialect/QIR/Utils/QIRUtils.h | 18 +- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 2 +- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 75 +- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 640 ++++++++++-------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 19 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 113 +++- mlir/lib/Dialect/QIR/Builder/CMakeLists.txt | 1 + .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 293 ++++---- .../Compiler/test_compiler_pipeline.cpp | 7 +- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 3 +- 12 files changed, 697 insertions(+), 503 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index a09c49ea7a..553b07a584 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -272,8 +273,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * %outTensor, %q0 = qtensor.extract %tensor[%c0]: tensor<3x!qco.qubit> * ``` */ - std::pair - qtensorExtract(Value tensor, const std::variant& index); + std::pair qtensorExtract(Value tensor, const int64_t index); /** * @brief Extract a qubit slice from a tensor @@ -1347,11 +1347,18 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateQubitTracking(Value inputQubit, Value outputQubit); + int64_t tensorCounter = 0; + + struct QubitInfo { + int64_t regId = -1; + int64_t regIndex = -1; + }; + /// Track valid (unconsumed) qubit SSA values for linear type enforcement. /// Only values present in this set are valid for use in operations. /// When an operation consumes a qubit and produces a new one, the old value /// is removed and the new output is added. - llvm::DenseSet validQubits; + llvm::DenseMap validQubits; /** * @brief Validate that a tensor value is valid and unconsumed. This also @@ -1369,10 +1376,14 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateTensorTracking(Value inputTensor, Value outputTensor); + struct TensorInfo { + int64_t regId = -1; + }; + /// Track valid (unconsumed) tensor SSA values for linear type enforcement. /// Only values present in this set are valid for use in operations. /// When an operation consumes a tensor and produces a new one, the old value /// is removed and the new output is added. - llvm::DenseSet validTensors; + llvm::DenseMap validTensors; }; } // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index d09133fd71..b7809d872e 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -889,8 +889,14 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { /// Cache static pointers for reuse llvm::DenseMap ptrCache; - /// Map from (register_name, register_index) to result pointer - llvm::DenseMap, Value> registerResultMap; + /// Set of qubit-array pointers + llvm::DenseSet qubitArrays; + + /// Map from register name to result-array pointer + llvm::StringMap resultArrays; + + /// Map from result index to result pointer + llvm::DenseMap resultPtrs; /// Track qubit and result counts for QIR metadata QIRMetadata metadata_; diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h index 33190e535e..0bc326fcbb 100644 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h +++ b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h @@ -34,11 +34,27 @@ namespace mlir::qir { // QIR function names +inline constexpr auto QIR_QUBIT_ARRAY_ALLOC = + "@__quantum__rt__qubit_array_allocate"; +inline constexpr auto QIR_QUBIT_ARRAY_RELEASE = + "@__quantum__rt__qubit_array_release"; + +inline constexpr auto QIR_QUBIT_ALLOC = "@__quantum__rt__qubit_allocate"; +inline constexpr auto QIR_QUBIT_RELEASE = "@__quantum__rt__qubit_release"; + +inline constexpr auto QIR_RESULT_ARRAY_ALLOC = + "@__quantum__rt__result_array_allocate"; +inline constexpr auto QIR_RESULT_ARRAY_RELEASE = + "@__quantum__rt__result_array_release"; + +inline constexpr auto QIR_RESULT_ALLOC = "@__quantum__rt__result_allocate"; +inline constexpr auto QIR_RESULT_RELEASE = "@__quantum__rt__result_release"; + inline constexpr auto QIR_INITIALIZE = "__quantum__rt__initialize"; inline constexpr auto QIR_MEASURE = "__quantum__qis__mz__body"; inline constexpr auto QIR_RECORD_OUTPUT = "__quantum__rt__result_record_output"; inline constexpr auto QIR_ARRAY_RECORD_OUTPUT = - "__quantum__rt__array_record_output"; + "__quantum__rt__result_array_record_output"; inline constexpr auto QIR_RESET = "__quantum__qis__reset__body"; inline constexpr auto QIR_GPHASE = "__quantum__qis__gphase__body"; diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 5855d73070..0bfeeaea8e 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -120,7 +120,7 @@ struct ConvertQTensorInsertOp final : OpConversionPattern { LogicalResult matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - rewriter.eraseOp(op); + rewriter.replaceOp(op, adaptor.getDest()); return success(); } }; diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 299c980b8c..0f04d621d4 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -46,6 +46,11 @@ using namespace qc; namespace { +struct QubitInfo { + Value reg; + Value index; +}; + /** * @brief State object for tracking qubit value flow during conversion * @@ -80,6 +85,8 @@ struct LoweringState { llvm::DenseMap qtensorMap; + llvm::DenseMap qubitInfos; + /// Modifier information int64_t inNestedRegion = 0; DenseMap> targetsIn; @@ -178,17 +185,21 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap; + auto& qubitInfos = getState().qubitInfos; auto& qtensorMap = getState().qtensorMap; - // Look up the latest QTensor value for this QC register + // Look up latest QTensor value for this QC register auto memref = op.getMemref(); assert(qtensorMap.contains(memref) && "QC register not found"); auto qtensor = qtensorMap[memref]; - auto extract = qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, - adaptor.getIndices()[0]); + auto index = adaptor.getIndices()[0]; + + auto extract = + qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, index); qubitMap.try_emplace(op.getResult(), extract.getResult()); + qubitInfos.try_emplace(op.getResult(), QubitInfo{memref, index}); qtensorMap[memref] = extract.getOutTensor(); rewriter.eraseOp(op); @@ -207,12 +218,12 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { // auto& qubitMap = getState().qubitMap; // auto& qtensorMap = getState().qtensorMap; -// // Look up the latest QCO value for this QC qubit +// // Look up latest QCO value for this QC qubit // auto qcQubit = op.getValue(); // assert(qubitMap.contains(qcQubit) && "QC qubit not found"); // auto qcoQubit = qubitMap[qcQubit]; -// // Look up the latest QTensor value for this QC register +// // Look up latest QTensor value for this QC register // auto memref = op.getMemref(); // assert(qtensorMap.contains(memref) && "QC register not found"); // auto qtensor = qtensorMap[memref]; @@ -237,23 +248,43 @@ struct ConvertMemRefDeallocOp final matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap; + auto& qubitInfos = getState().qubitInfos; auto& qtensorMap = getState().qtensorMap; - // Look up the latest QTensor value for this QC register + // Look up latest QTensor value for this QC register auto memref = op.getMemref(); assert(qtensorMap.contains(memref) && "QC register not found"); auto qtensor = qtensorMap[memref]; - // Insert all qubits - // TODO: Use dedicated map - int64_t i = 0; - for (auto [_, qcoQubit] : qubitMap) { - auto index = arith::ConstantOp::create(rewriter, op.getLoc(), - rewriter.getIndexAttr(i)); + // Filter out qubits belonging to this tensor + llvm::SmallVector> toInsert; + toInsert.reserve(qubitMap.size()); + for (auto [qcQubit, qcoQubit] : qubitMap) { + auto& info = qubitInfos[qcQubit]; + if (info.reg != memref) { + continue; + } + toInsert.emplace_back(qcQubit, qcoQubit); + } + + // Sort qubits for deterministic output + llvm::sort(toInsert, [](const auto& a, const auto& b) { + auto* opA = a.first.getDefiningOp(); + auto* opB = b.first.getDefiningOp(); + if (!opA || !opB || opA->getBlock() != opB->getBlock()) { + return a.first.getAsOpaquePointer() < b.first.getAsOpaquePointer(); + } + return opA->isBeforeInBlock(opB); + }); + + // Insert qubits + for (auto [qcQubit, qcoQubit] : toInsert) { + auto& info = qubitInfos[qcQubit]; + auto index = info.index; auto insert = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, - qtensor, index.getResult()); + qtensor, index); qtensor = insert.getResult(); - ++i; + qubitInfos.erase(qcQubit); } rewriter.replaceOpWithNewOp(op, qtensor); @@ -330,7 +361,7 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { auto& qubitMap = getState().qubitMap; auto qcQubit = op.getQubit(); - // Look up the latest QCO value for this QC qubit + // Look up latest QCO value for this QC qubit assert(qubitMap.contains(qcQubit) && "QC qubit not found"); auto qcoQubit = qubitMap[qcQubit]; @@ -1027,6 +1058,8 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto& qubitMap = state.qubitMap; + auto& inNestedRegion = state.inNestedRegion; + auto& targetsIn = state.targetsIn; // Get QCO controls from state map auto qcControls = op.getControls(); @@ -1054,7 +1087,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { // Update the state map if this is a top-level CtrlOp // Nested CtrlOps are managed via the targetsIn and targetsOut maps - if (state.inNestedRegion == 0) { + if (inNestedRegion == 0) { for (const auto& [qcControl, qcoControl] : llvm::zip(qcControls, qcoOp.getControlsOut())) { qubitMap[qcControl] = qcoControl; @@ -1067,7 +1100,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { } // Update modifier information - state.inNestedRegion++; + inNestedRegion++; // Clone body region from QC to QCO auto& dstRegion = qcoOp.getRegion(); @@ -1086,7 +1119,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { qcoTargetAliases.emplace_back(entryBlock.addArgument(qubitType, opLoc)); } }); - state.targetsIn[state.inNestedRegion] = std::move(qcoTargetAliases); + targetsIn[inNestedRegion] = std::move(qcoTargetAliases); rewriter.eraseOp(op); return success(); @@ -1116,7 +1149,11 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::InvOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& [qubitMap, _, inNestedRegion, targetsIn, targetsOut] = getState(); + auto& state = getState(); + auto& qubitMap = state.qubitMap; + auto& inNestedRegion = state.inNestedRegion; + auto& targetsIn = state.targetsIn; + auto& targetsOut = state.targetsOut; // Get QCO targets from state map const auto numTargets = op.getNumTargets(); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index cd1d2721b7..a9b3b12cbc 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -73,19 +74,22 @@ namespace { * - Sequence of measurements for output recording */ struct LoweringState : QIRMetadata { - /// Map from register name to register start index - DenseMap registerStartIndexMap; - - /// Map from index to pointer value for reuse + /// Map from index to qubit pointer DenseMap ptrMap; - /// Map from (register_name, register_index) to result pointer - /// This allows caching result pointers for measurements with register info - DenseMap, Value> registerResultMap; + /// Map from register name to result-array pointer + llvm::StringMap resultArrays; + + /// Map from index to result pointer + DenseMap resultPtrs; /// Modifier information int64_t inCtrlOp = 0; DenseMap> controls; + + // Block information + Block* entryBlock; + Block* measurementsBlock; }; /** @@ -201,6 +205,121 @@ struct QCToQIRTypeConverter final : LLVMTypeConverter { // Convert QubitType to LLVM pointer (QIR uses opaque pointers for qubits) addConversion( [ctx](QubitType /*type*/) { return LLVM::LLVMPointerType::get(ctx); }); + + addConversion( + [ctx](MemRefType /*type*/) { return LLVM::LLVMPointerType::get(ctx); }); + } +}; + +struct ConvertMemRefAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto fnSig = + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {rewriter.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, + QIR_QUBIT_ARRAY_ALLOC, fnSig); + + auto shape = op.getType().getShape(); + if (shape.size() != 1) { + return failure(); + } + + Value size; + if (shape[0] == ShapedType::kDynamic) { + size = adaptor.getDynamicSizes()[0]; + } else { + size = LLVM::ConstantOp::create( + rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(static_cast(shape[0]))) + .getResult(); + } + + auto array = + LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrType, ptrType, size); + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{size, array.getResult(), zero.getResult()}); + + rewriter.replaceOp(op, array.getResult()); + + return success(); + } +}; + +struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto array = adaptor.getMemref(); + auto index = adaptor.getIndices()[0]; + auto gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrType, ptrType, + array, index); + auto load = + LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, gep.getResult()); + + rewriter.replaceOp(op, load.getResult()); + + return success(); + } +}; + +struct ConvertMemRefDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto* ctx = getContext(); + auto i64Type = rewriter.getI64Type(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto shape = op.getMemref().getType().getShape(); + if (shape.size() != 1) { + return failure(); + } + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Switch to measurements block + rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + + Value size; + if (shape[0] == ShapedType::kDynamic) { + llvm::errs() << "I do not know yet\n"; + return failure(); + } else { + size = LLVM::ConstantOp::create( + rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(static_cast(shape[0]))) + .getResult(); + } + + auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {i64Type, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, + QIR_QUBIT_ARRAY_RELEASE, fnSig); + + // Create the release call + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{size, adaptor.getMemref()}); + rewriter.eraseOp(op); + + return success(); } }; @@ -225,66 +344,22 @@ struct QCToQIRTypeConverter final : LLVMTypeConverter { * %q0 = llvm.inttoptr %c0 : i64 to !llvm.ptr * ``` */ -struct ConvertQCAllocQIR final : StatefulOpConversionPattern { +struct ConvertQCAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& state = getState(); - const auto numQubits = static_cast(state.numQubits); - auto& ptrMap = state.ptrMap; - auto& registerMap = state.registerStartIndexMap; - - // Get or create pointer value - if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) { - const auto registerName = op.getRegisterName().value(); - const auto registerSize = - static_cast(op.getRegisterSize().value()); - const auto registerIndex = - static_cast(op.getRegisterIndex().value()); + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); - if (const auto it = registerMap.find(registerName); - it != registerMap.end()) { - // Register is already tracked - // The pointer was created by the step below - const auto globalIndex = it->second + registerIndex; - if (!ptrMap.contains(globalIndex)) { - return op.emitError("Pointer not found"); - } - rewriter.replaceOp(op, ptrMap.at(globalIndex)); - return success(); - } + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_ALLOC, fnSig); - // Allocate the entire register as static qubits - registerMap[registerName] = numQubits; - SmallVector pointers; - pointers.reserve(registerSize); - for (int64_t i = 0; i < registerSize; ++i) { - Value val{}; - if (const auto it = ptrMap.find(numQubits + i); it != ptrMap.end()) { - val = it->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numQubits + i); - ptrMap[numQubits + i] = val; - } - pointers.push_back(val); - } - rewriter.replaceOp(op, pointers[registerIndex]); - state.numQubits += registerSize; - return success(); - } + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + rewriter.replaceOpWithNewOp(op, fnDec, zero.getResult()); - // no register info, check if ptr has already been allocated (as a Result) - Value val{}; - if (const auto it = ptrMap.find(numQubits); it != ptrMap.end()) { - val = it->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numQubits); - ptrMap[numQubits] = val; - } - rewriter.replaceOp(op, val); - state.numQubits++; return success(); } }; @@ -293,9 +368,9 @@ struct ConvertQCAllocQIR final : StatefulOpConversionPattern { * @brief Erases qc.dealloc operations * * @details - * Since QIR 2.0 does not support dynamic qubit allocation, dynamic allocations - * are converted to static allocations. Therefore, deallocation operations - * become no-ops and are simply removed from the IR. + * Since QIR 2.0 does not support dynamic qubit allocation, dynamic + * allocations are converted to static allocations. Therefore, deallocation + * operations become no-ops and are simply removed from the IR. * * @par Example: * ```mlir @@ -306,13 +381,29 @@ struct ConvertQCAllocQIR final : StatefulOpConversionPattern { * // (removed) * ``` */ -struct ConvertQCDeallocQIR final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertQCDeallocOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult - matchAndRewrite(DeallocOp op, OpAdaptor /*adaptor*/, + matchAndRewrite(DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Switch to measurements block + rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + + auto fnSig = + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_RELEASE, fnSig); + + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, adaptor.getQubit()); rewriter.eraseOp(op); + return success(); } }; @@ -335,32 +426,14 @@ struct ConvertQCDeallocQIR final : OpConversionPattern { * %q0 = llvm.inttoptr %c0 : i64 to !llvm.ptr * ``` */ -struct ConvertQCStaticQIR final : StatefulOpConversionPattern { +struct ConvertQCStaticOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto index = static_cast(op.getIndex()); - auto& state = getState(); - // Get or create a pointer to the qubit - Value val{}; - if (const auto it = state.ptrMap.find(index); it != state.ptrMap.end()) { - // Reuse existing pointer - val = it->second; - } else { - // Create and cache for reuse - val = createPointerFromIndex(rewriter, op.getLoc(), index); - state.ptrMap.try_emplace(index, val); - } - rewriter.replaceOp(op, val); - - // Track maximum qubit index - if (std::cmp_greater_equal(index, state.numQubits)) { - state.numQubits = index + 1; - } - - return success(); + // TODO: Figure this out + return failure(); } }; @@ -369,9 +442,10 @@ struct ConvertQCStaticQIR final : StatefulOpConversionPattern { * * @details * Converts qubit measurement to a QIR call to `__quantum__qis__mz__body`. - * Unlike the previous implementation, this does NOT immediately record output. - * Instead, it tracks measurements in the lowering state for deferred output - * recording in a separate output block, as required by the QIR Base Profile. + * Unlike the previous implementation, this does NOT immediately record + * output. Instead, it tracks measurements in the lowering state for deferred + * output recording in a separate output block, as required by the QIR Base + * Profile. * * For measurements with register information, the result pointer is mapped * to (register_name, register_index) for later retrieval. For measurements @@ -385,81 +459,98 @@ struct ConvertQCStaticQIR final : StatefulOpConversionPattern { * ```mlir * %c0_i64 = llvm.mlir.constant(0 : i64) : i64 * %result_ptr = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr - * llvm.call @__quantum__qis__mz__body(%q, %result_ptr) : (!llvm.ptr, !llvm.ptr) + * llvm.call @__quantum__qis__mz__body(%q, %result_ptr) : (!llvm.ptr, + * !llvm.ptr) * -> () * ``` */ -struct ConvertQCMeasureQIR final : StatefulOpConversionPattern { +struct ConvertQCMeasureOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(MeasureOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto* ctx = getContext(); - const auto ptrType = LLVM::LLVMPointerType::get(ctx); auto& state = getState(); - const auto numResults = static_cast(state.numResults); - auto& ptrMap = state.ptrMap; - auto& registerResultMap = state.registerResultMap; + auto& resultArrays = state.resultArrays; + auto& resultPtrs = state.resultPtrs; + + auto* ctx = getContext(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); - // Get or create result pointer value - Value resultValue; + // Insert allocations and constants in entry block + rewriter.setInsertionPoint(state.entryBlock->getTerminator()); + + // Get result pointer + Value result; if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) { const auto registerName = op.getRegisterName().value(); const auto registerSize = static_cast(op.getRegisterSize().value()); const auto registerIndex = static_cast(op.getRegisterIndex().value()); - const auto key = std::make_pair(registerName, registerIndex); - - if (const auto it = registerResultMap.find(key); - it != registerResultMap.end()) { - resultValue = it->second; - } else { - // Allocate the entire register as static results - for (int64_t i = 0; i < registerSize; ++i) { - Value val{}; - if (const auto ptrIt = ptrMap.find(numResults + i); - ptrIt != ptrMap.end()) { - val = ptrIt->second; - } else { - val = createPointerFromIndex(rewriter, op.getLoc(), numResults + i); - ptrMap[numResults + i] = val; - } - registerResultMap.try_emplace({registerName, i}, val); - } - state.numResults += registerSize; - resultValue = registerResultMap.at(key); + + // Create result register if it does not exist yet + if (resultArrays.find(registerName) == resultArrays.end()) { + auto fnSig = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), + {rewriter.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration( + rewriter, op, QIR_RESULT_ARRAY_ALLOC, fnSig); + + auto size = + LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(registerSize)) + .getResult(); + auto array = LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrType, + ptrType, size); + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + LLVM::CallOp::create( + rewriter, op.getLoc(), fnDec, + ValueRange{size, array.getResult(), zero.getResult()}); + resultArrays.try_emplace(registerName, array.getResult()); } + + auto array = resultArrays[registerName]; + auto index = + LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(registerIndex)) + .getResult(); + auto gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrType, ptrType, + array, index); + auto load = + LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, gep.getResult()); + result = load.getResult(); } else { - // Choose a safe default register name - StringRef defaultRegName = "c"; - if (llvm::any_of(registerResultMap, [](const auto& entry) { - return entry.first.first == "c"; - })) { - defaultRegName = "__unnamed__"; - } - // No register info, check if ptr has already been allocated (as a Qubit) - if (const auto it = ptrMap.find(numResults); it != ptrMap.end()) { - resultValue = it->second; - } else { - resultValue = createPointerFromIndex(rewriter, op.getLoc(), numResults); - ptrMap[numResults] = resultValue; - } - registerResultMap.insert({{defaultRegName, numResults}, resultValue}); - state.numResults++; + auto fnSig = + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_RESULT_ALLOC, fnSig); + + auto zero = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrType); + result = + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, zero.getResult()) + .getResult(); + + resultPtrs.try_emplace(resultPtrs.size(), result); } - // Declare QIR function - const auto fnSignature = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {ptrType, ptrType}); - const auto fnDecl = - getOrCreateFunctionDeclaration(rewriter, op, QIR_MEASURE, fnSignature); + // Switch to measurements block + rewriter.setInsertionPoint(state.measurementsBlock->getTerminator()); + + // Create measure call + auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), + {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(rewriter, op, QIR_MEASURE, fnSig); + + LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, + ValueRange{adaptor.getQubit(), result}); + + rewriter.replaceOp(op, result); - // Create CallOp and replace qc.measure with result pointer - LLVM::CallOp::create(rewriter, op.getLoc(), fnDecl, - ValueRange{adaptor.getQubit(), resultValue}); - rewriter.replaceOp(op, resultValue); return success(); } }; @@ -480,14 +571,20 @@ struct ConvertQCMeasureQIR final : StatefulOpConversionPattern { * llvm.call @__quantum__qis__reset__body(%q) : (!llvm.ptr) -> () * ``` */ -struct ConvertQCResetQIR final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertQCResetOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(ResetOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto* ctx = getContext(); + // Save current insertion point + const OpBuilder::InsertionGuard guard(rewriter); + + // Switch to measurements block + rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + // Declare QIR function const auto fnSignature = LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(ctx), LLVM::LLVMPointerType::get(ctx)); @@ -515,7 +612,7 @@ struct ConvertQCResetQIR final : OpConversionPattern { * llvm.call @__quantum__qis__gphase__body(%theta) : (f64) -> () * ``` */ -struct ConvertQCGPhaseOpQIR final : StatefulOpConversionPattern { +struct ConvertQCGPhaseOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -546,8 +643,7 @@ struct ConvertQCGPhaseOpQIR final : StatefulOpConversionPattern { * llvm.call @__quantum__qis__QIR_NAME__body(%q) : (!llvm.ptr) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -594,8 +690,7 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, SXDG, sxdg, sxdg) * -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -635,8 +730,7 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, P, p, p, theta) * (!llvm.ptr, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -674,8 +768,7 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, U2, u2, u2, phi, lambda) * : (!llvm.ptr, f64, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -712,8 +805,7 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, U, u, u3) * !llvm.ptr) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -753,8 +845,7 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ECR, ecr, ecr) * (!llvm.ptr, !llvm.ptr, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -795,8 +886,7 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, RZZ, rzz, rzz, theta) * (!llvm.ptr, !llvm.ptr, f64, f64) -> () \ * ``` \ */ \ - struct ConvertQC##OP_CLASS##QIR final \ - : StatefulOpConversionPattern { \ + struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ using StatefulOpConversionPattern::StatefulOpConversionPattern; \ \ LogicalResult \ @@ -824,7 +914,7 @@ DEFINE_TWO_TARGET_TWO_PARAMETER(XXMinusYYOp, XXMINUSYY, xx_minus_yy, /** * @brief Erases qc.barrier operation, as it is a no-op in QIR */ -struct ConvertQCBarrierQIR final : StatefulOpConversionPattern { +struct ConvertQCBarrierOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -838,7 +928,7 @@ struct ConvertQCBarrierQIR final : StatefulOpConversionPattern { /** * @brief Inlines qc.ctrl region removes the operation */ -struct ConvertQCCtrlQIR final : StatefulOpConversionPattern { +struct ConvertQCCtrlOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -862,7 +952,7 @@ struct ConvertQCCtrlQIR final : StatefulOpConversionPattern { /** * @brief Erases qc.yield operation */ -struct ConvertQCYieldQIR final : StatefulOpConversionPattern { +struct ConvertQCYieldOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -883,8 +973,7 @@ struct ConvertQCYieldQIR final : StatefulOpConversionPattern { * * Conversion stages: * 1. Convert func dialect to LLVM - * 2. Ensure proper block structure for QIR base profile and add - * initialization + * 2. Ensure proper block structure for QIR base profile and add initialization * 3. Convert QC operations to QIR calls * 4. Set QIR metadata attributes * 5. Convert arith and cf dialects to LLVM @@ -892,8 +981,8 @@ struct ConvertQCYieldQIR final : StatefulOpConversionPattern { * * @pre * The input entry function must consist of a single block. The pass will - * restructure it into four blocks. Multi-block input functions are currently - * not supported. + * restructure it into four blocks. Multi-block input functions are + * currently not supported. */ struct QCToQIR final : impl::QCToQIRBase { using QCToQIRBase::QCToQIRBase; @@ -906,7 +995,7 @@ struct QCToQIR final : impl::QCToQIRBase { * 1. **Entry block**: Contains constant operations and initialization * 2. **Body block**: Contains reversible quantum operations (gates) * 3. **Measurements block**: Contains irreversible operations (measure, - * reset, dealloc) + * reset, dealloc) * 4. **Output block**: Contains output recording calls * * Blocks are connected with unconditional jumps (entry, body, measurements, @@ -916,7 +1005,7 @@ struct QCToQIR final : impl::QCToQIRBase { * * @param main The main LLVM function to restructure */ - static void ensureBlocks(LLVM::LLVMFuncOp& main) { + static void ensureBlocks(LLVM::LLVMFuncOp& main, LoweringState& state) { // Return if there are already multiple blocks if (main.getBlocks().size() > 1) { return; @@ -934,24 +1023,22 @@ struct QCToQIR final : impl::QCToQIRBase { Block* measurementsBlock = builder.createBlock(&main.getBody()); Block* outputBlock = builder.createBlock(&main.getBody()); + state.entryBlock = entryBlock; + state.measurementsBlock = measurementsBlock; + auto& bodyBlockOps = bodyBlock->getOperations(); auto& outputBlockOps = outputBlock->getOperations(); - auto& measurementsBlockOps = measurementsBlock->getOperations(); // Move operations to appropriate blocks for (auto it = bodyBlock->begin(); it != bodyBlock->end();) { // Ensure iterator remains valid after potential move - if (auto& op = *it++; - isa(op) || isa(op) || isa(op)) { - // Move irreversible quantum operations to measurements block - measurementsBlockOps.splice(measurementsBlock->end(), bodyBlockOps, - Block::iterator(op)); - } else if (isa(op)) { + if (auto& op = *it++; isa(op)) { // Move return to output block outputBlockOps.splice(outputBlock->end(), bodyBlockOps, Block::iterator(op)); - } else if (op.hasTrait()) { - // Move constant like operations to the entry block + } else if (isa(op) || isa(op) || + isa(op) || op.hasTrait()) { + // Move allocations and constant-like operations to entry block entryBlock->getOperations().splice(entryBlock->end(), bodyBlockOps, Block::iterator(op)); } @@ -973,10 +1060,9 @@ struct QCToQIR final : impl::QCToQIRBase { * @brief Adds QIR initialization call to the entry block * * @details - * Inserts a call to `__quantum__rt__initialize` at the end of the entry - * block (before the jump to main block). This QIR runtime function - * initializes the quantum execution environment and takes a null pointer as - * argument. + * Inserts a call to `__quantum__rt__initialize` at the end of the entry block + * (before the jump to main block). This QIR runtime function initializes the + * quantum execution environment and takes a null pointer as argument. * * @param main The main LLVM function * @param ctx The MLIR context @@ -1009,8 +1095,7 @@ struct QCToQIR final : impl::QCToQIRBase { // Create the initialization call LLVM::CallOp::create(builder, main->getLoc(), - cast(fnDecl), - ValueRange{zeroOp->getResult(0)}); + cast(fnDecl), zeroOp->getResult(0)); } /** @@ -1042,8 +1127,8 @@ struct QCToQIR final : impl::QCToQIRBase { * ``` * * Any output recording calls that are not part of registers (i.e., - * measurements without register info) are grouped under a default label - * "c" and recorded similarly. + * measurements without register info) are grouped under a default label "c" + * and recorded similarly. * * @param main The main LLVM function * @param ctx The MLIR context @@ -1051,12 +1136,16 @@ struct QCToQIR final : impl::QCToQIRBase { */ static void addOutputRecording(LLVM::LLVMFuncOp& main, MLIRContext* ctx, LoweringState* state) { - if (state->registerResultMap.empty()) { + auto& resultArrays = state->resultArrays; + auto& resultPtrs = state->resultPtrs; + + if (resultArrays.empty() && resultPtrs.empty()) { return; // No measurements to record } OpBuilder builder(ctx); - const auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); // Find the output block auto& outputBlock = main.getBlocks().back(); @@ -1064,59 +1153,52 @@ struct QCToQIR final : impl::QCToQIRBase { // Insert before the branch to output block builder.setInsertionPoint(&outputBlock.back()); - // Group measurements by register - llvm::StringMap>> registerGroups; - for (const auto& [key, resultPtr] : state->registerResultMap) { - const auto& [registerName, registerIndex] = key; - registerGroups[registerName].emplace_back(registerIndex, resultPtr); - } + if (!resultPtrs.empty()) { + // Sort result pointers for deterministic output + llvm::SmallVector> sortedPtrs; + for (const auto& [index, resultPtr] : resultPtrs) { + sortedPtrs.emplace_back(index, resultPtr); + } + llvm::sort(sortedPtrs, [](const auto& a, const auto& b) { + return a.first < b.first; + }); - // Sort registers by name for deterministic output - SmallVector>>> - sortedRegisters; - for (auto& [name, measurements] : registerGroups) { - sortedRegisters.emplace_back(name, std::move(measurements)); + // Create output recording for each result pointer + auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(builder, main, + QIR_RECORD_OUTPUT, fnSig); + + for (const auto& [index, ptr] : sortedPtrs) { + auto label = createResultLabel(builder, main, + "__unnamed__" + std::to_string(index)) + .getResult(); + LLVM::CallOp::create(builder, main->getLoc(), fnDec, + ValueRange{ptr, label}); + } } - llvm::sort(sortedRegisters, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - // create function declarations for output recording - const auto arrayRecordSig = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {builder.getI64Type(), ptrType}); - const auto arrayRecordDecl = getOrCreateFunctionDeclaration( - builder, main, QIR_ARRAY_RECORD_OUTPUT, arrayRecordSig); - - const auto resultRecordSig = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), {ptrType, ptrType}); - const auto resultRecordDecl = getOrCreateFunctionDeclaration( - builder, main, QIR_RECORD_OUTPUT, resultRecordSig); - - // Generate output recording for each register - for (auto& [registerName, measurements] : sortedRegisters) { - // Sort measurements by register index - llvm::sort(measurements, [](const auto& a, const auto& b) { + + if (!resultArrays.empty()) { + // Sort registers by name for deterministic output + SmallVector> sortedRegisters; + for (auto& [name, results] : resultArrays) { + sortedRegisters.emplace_back(name, std::move(results)); + } + llvm::sort(sortedRegisters, [](const auto& a, const auto& b) { return a.first < b.first; }); - const auto arraySize = measurements.size(); - auto arrayLabelOp = createResultLabel(builder, main, registerName); - auto arraySizeConst = LLVM::ConstantOp::create( - builder, main->getLoc(), - builder.getI64IntegerAttr(static_cast(arraySize))); - - LLVM::CallOp::create( - builder, main->getLoc(), arrayRecordDecl, - ValueRange{arraySizeConst.getResult(), arrayLabelOp.getResult()}); - - // Create result_record_output calls for each measurement - for (auto [regIdx, resultPtr] : measurements) { - // Create label for result: "{arrayCounter+1+i}_{registerName}{i}r" - const std::string resultLabel = - registerName.str() + std::to_string(regIdx) + "r"; - auto resultLabelOp = createResultLabel(builder, main, resultLabel); - - LLVM::CallOp::create(builder, main->getLoc(), resultRecordDecl, - ValueRange{resultPtr, resultLabelOp.getResult()}); + auto fnSig = LLVM::LLVMFunctionType::get( + voidType, {builder.getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration( + builder, main, QIR_ARRAY_RECORD_OUTPUT, fnSig); + + // Generate output recording for each register + for (auto& [name, results] : sortedRegisters) { + auto size = results.getDefiningOp().getArraySize(); + auto label = createResultLabel(builder, main, name).getResult(); + + LLVM::CallOp::create(builder, main->getLoc(), fnDec, + ValueRange{size, results, label}); } } } @@ -1134,24 +1216,23 @@ struct QCToQIR final : impl::QCToQIRBase { * * **Stage 2: Block structure and initialization** * Create proper 4-block structure for QIR base profile (entry, main, - * irreversible, output) and insert the `__quantum__rt__initialize` call - * in the entry block. + * irreversible, output) and insert the `__quantum__rt__initialize` call in + * the entry block. * * **Stage 3: QC to LLVM** - * Convert QC dialect operations to QIR calls and add output recording to - * the output block. + * Convert QC dialect operations to QIR calls and add output recording to the + * output block. * * **Stage 4: QIR attributes** - * Add QIR base profile metadata to the main function, including - * qubit/result counts and version information. + * Add QIR base profile metadata to the main function, including qubit/result + * counts and version information. * * **Stage 5: Standard dialects to LLVM** * Convert arith and control flow dialects to LLVM (for index arithmetic and * function control flow). * * **Stage 6: Reconcile casts** - * Clean up any unrealized cast operations introduced during type - * conversion. + * Clean up any unrealized cast operations introduced during type conversion. */ void runOnOperation() override { MLIRContext* ctx = &getContext(); @@ -1174,7 +1255,6 @@ struct QCToQIR final : impl::QCToQIRBase { } } - // Stage 2: Ensure proper block structure and add initialization auto main = getMainFunction(moduleOp); if (!main) { moduleOp->emitError("No main function with entry_point attribute found"); @@ -1182,56 +1262,31 @@ struct QCToQIR final : impl::QCToQIRBase { return; } - ensureBlocks(main); - addInitialize(main, ctx); - LoweringState state; + // Stage 2: Create block structure + ensureBlocks(main, state); + // Stage 3: Convert QC dialect to LLVM (QIR calls) { - RewritePatternSet qcPatterns(ctx); - target.addIllegalDialect(); - - // Add conversion patterns for QC operations - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - qcPatterns.add(typeConverter, ctx, &state); - - if (applyPartialConversion(moduleOp, target, std::move(qcPatterns)) + RewritePatternSet patterns(ctx); + target.addIllegalDialect(); + + patterns.add( + typeConverter, ctx, &state); + + if (applyPartialConversion(moduleOp, target, std::move(patterns)) .failed()) { signalPassFailure(); return; @@ -1240,6 +1295,9 @@ struct QCToQIR final : impl::QCToQIRBase { addOutputRecording(main, ctx, &state); } + // Stage ?: Insert initialize call + addInitialize(main, ctx); + // Stage 4: Set QIR metadata attributes setQIRAttributes(main, state); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index e138e0794b..3dfad8e92a 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -503,13 +503,12 @@ OwningOpRef QCProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - // llvm::SmallVector freeQubits; - // for (auto qubit : allocatedQubits) { - // auto memref = qubit.getDefiningOp(); - // if (!memref) { - // freeQubits.emplace_back(qubit); - // } - // } + llvm::SmallVector freeQubits; + for (auto qubit : allocatedQubits) { + if (!llvm::isa(qubit.getDefiningOp())) { + freeQubits.emplace_back(qubit); + } + } auto blockOrderComparator = [](Value a, Value b) { auto* opA = a.getDefiningOp(); @@ -522,16 +521,13 @@ OwningOpRef QCProgramBuilder::finalize() { // Automatically deallocate all still-allocated qubits // Sort qubits for deterministic output - llvm::SmallVector sortedQubits(allocatedQubits.begin(), - allocatedQubits.end()); + llvm::SmallVector sortedQubits(freeQubits.begin(), freeQubits.end()); llvm::sort(sortedQubits, blockOrderComparator); for (auto qubit : sortedQubits) { DeallocOp::create(*this, qubit); } - allocatedQubits.clear(); - // Automatically deallocate all still-allocated memrefs // Sort memrefs for deterministic output llvm::SmallVector sortedMemrefs(allocatedMemrefs.begin(), @@ -542,6 +538,7 @@ OwningOpRef QCProgramBuilder::finalize() { memref::DeallocOp::create(*this, memref); } + allocatedQubits.clear(); allocatedMemrefs.clear(); // Create constant 0 for successful exit code diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 92a8f36114..729cca295f 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -78,7 +78,7 @@ Value QCOProgramBuilder::allocQubit() { const auto qubit = allocOp.getResult(); // Track the allocated qubit as valid - validQubits.insert(qubit); + validQubits.insert({qubit, {}}); return qubit; } @@ -95,7 +95,7 @@ Value QCOProgramBuilder::staticQubit(const int64_t index) { const auto qubit = staticOp.getQubit(); // Track the static qubit as valid - validQubits.insert(qubit); + validQubits.insert({qubit, {}}); return qubit; } @@ -155,11 +155,14 @@ void QCOProgramBuilder::updateQubitTracking(Value inputQubit, // Validate the input qubit validateQubitValue(inputQubit); + auto it = validQubits.find(inputQubit); + auto info = it->second; + // Remove the input (consumed) value from tracking - validQubits.erase(inputQubit); + validQubits.erase(it); // Add the output (new) value to tracking - validQubits.insert(outputQubit); + validQubits.insert({outputQubit, info}); } void QCOProgramBuilder::validateTensorValue(Value tensor) const { @@ -185,11 +188,14 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, // Validate the input tensor validateTensorValue(inputTensor); + auto it = validTensors.find(inputTensor); + auto info = it->second; + // Remove the input (consumed) value from tracking - validTensors.erase(inputTensor); + validTensors.erase(it); // Add the output (new) value to tracking - validTensors.insert(outputTensor); + validTensors.insert({outputTensor, info}); } //===----------------------------------------------------------------------===// @@ -204,7 +210,7 @@ Value QCOProgramBuilder::qtensorAlloc( auto allocOp = qtensor::AllocOp::create(*this, sizeValue); auto result = allocOp.getResult(); - validTensors.insert(result); + validTensors.insert({result, {tensorCounter++}}); return result; } @@ -226,21 +232,21 @@ Value QCOProgramBuilder::qtensorFromElements(ValueRange elements) { auto fromElementsOp = qtensor::FromElementsOp::create(*this, elements); auto result = fromElementsOp.getResult(); - validTensors.insert(result); + validTensors.insert({result, {tensorCounter++}}); return result; } -std::pair -QCOProgramBuilder::qtensorExtract(Value tensor, - const std::variant& index) { +std::pair QCOProgramBuilder::qtensorExtract(Value tensor, + const int64_t index) { checkFinalized(); - auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto indexValue = + arith::ConstantOp::create(*this, getIndexAttr(index)).getResult(); auto extractOp = qtensor::ExtractOp::create(*this, tensor, indexValue); auto qubit = extractOp.getResult(); auto outTensor = extractOp.getOutTensor(); - validQubits.insert(qubit); + validQubits.insert({qubit, {validTensors[tensor].regId, index}}); updateTensorTracking(tensor, outTensor); return {outTensor, qubit}; @@ -258,7 +264,7 @@ std::pair QCOProgramBuilder::qtensorExtractSlice( auto slicedTensor = extractSliceOp.getResult(); auto outTensor = extractSliceOp.getOutTensor(); - validTensors.insert(slicedTensor); + validTensors.insert({slicedTensor, {tensorCounter++}}); updateTensorTracking(tensor, outTensor); return {outTensor, slicedTensor}; @@ -848,8 +854,8 @@ ValueRange QCOProgramBuilder::qcoIf( for (auto qubitType : qubits.getTypes()) { const auto thenArg = thenBlock.addArgument(qubitType, getLoc()); const auto elseArg = elseBlock.addArgument(qubitType, getLoc()); - validQubits.insert(thenArg); - validQubits.insert(elseArg); + validQubits.insert({thenArg, {}}); + validQubits.insert({elseArg, {}}); } // Construct the bodies of the regions @@ -924,8 +930,6 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - // TODO: Determine "free" qubits? - auto blockOrderComparator = [](Value a, Value b) { auto* opA = a.getDefiningOp(); auto* opB = b.getDefiningOp(); @@ -935,27 +939,76 @@ OwningOpRef QCOProgramBuilder::finalize() { return opA->isBeforeInBlock(opB); }; + auto blockOrderComparator1 = [](const std::pair& a, + const std::pair& b) { + auto* opA = a.first.getDefiningOp(); + auto* opB = b.first.getDefiningOp(); + if (!opA || !opB || opA->getBlock() != opB->getBlock()) { + return a.first.getAsOpaquePointer() < b.first.getAsOpaquePointer(); + } + return opA->isBeforeInBlock(opB); + }; + + auto blockOrderComparator2 = [](const std::pair& a, + const std::pair& b) { + auto* opA = a.first.getDefiningOp(); + auto* opB = b.first.getDefiningOp(); + if (!opA || !opB || opA->getBlock() != opB->getBlock()) { + return a.first.getAsOpaquePointer() < b.first.getAsOpaquePointer(); + } + if (opA != opB) { + return opA->isBeforeInBlock(opB); + } + return llvm::cast(a.first).getResultNumber() < + llvm::cast(b.first).getResultNumber(); + }; + + llvm::SmallVector freeQubits; + llvm::DenseMap registerQubits; + for (auto [qubit, info] : validQubits) { + if (info.regId == -1) { + freeQubits.push_back(qubit); + } else { + registerQubits.insert({qubit, info}); + } + } + // Automatically deallocate all still-allocated qubits // Sort qubits for deterministic output - llvm::SmallVector sortedQubits(validQubits.begin(), validQubits.end()); - llvm::sort(sortedQubits, blockOrderComparator); + llvm::SmallVector sortedFreeQubits(freeQubits.begin(), + freeQubits.end()); + llvm::sort(sortedFreeQubits, blockOrderComparator); - for (auto qubit : sortedQubits) { + for (auto qubit : sortedFreeQubits) { DeallocOp::create(*this, qubit); } - validQubits.clear(); - // Automatically deallocate all still-allocated tensors - // Sort tensors for deterministic output - llvm::SmallVector sortedTensors(validTensors.begin(), - validTensors.end()); - llvm::sort(sortedTensors, blockOrderComparator); - - for (auto tensor : sortedTensors) { - qtensor::DeallocOp::create(*this, tensor); + if (!validTensors.empty()) { + // Sort tensors for deterministic output + llvm::SmallVector> sortedTensors( + validTensors.begin(), validTensors.end()); + llvm::sort(sortedTensors, blockOrderComparator1); + for (auto& [tensor, tensorInfo] : sortedTensors) { + // Filter out qubits belonging to this tensor + SmallVector> toInsert; + for (auto& [qubit, qubitInfo] : registerQubits) { + if (qubitInfo.regId != tensorInfo.regId) { + continue; + } + toInsert.push_back({qubit, qubitInfo.regIndex}); + } + // Sort qubits for deterministic output + llvm::sort(toInsert, blockOrderComparator2); + // Insert qubits + for (auto& [qubit, index] : toInsert) { + tensor = qtensorInsert(qubit, tensor, index); + } + qtensor::DeallocOp::create(*this, tensor); + } } + validQubits.clear(); validTensors.clear(); // Create constant 0 for successful exit code diff --git a/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt b/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt index 3fa768bdef..27adc79530 100644 --- a/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt +++ b/mlir/lib/Dialect/QIR/Builder/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_library( LINK_LIBS PUBLIC MLIRLLVMDialect + MLIRMemRefDialect MLIRQIRUtils MLIRIR MLIRSupport diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 5ca6cfd050..570be420d6 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -67,15 +68,9 @@ void QIRProgramBuilder::initialize() { measurementsBlock = mainFuncOp.addBlock(); outputBlock = mainFuncOp.addBlock(); - // Create exit code constant in entry block (where constants belong) and add - // QIR initialization call in entry block (after exit code constant) + // Create exit code constant in entry block setInsertionPointToStart(entryBlock); - auto zeroOp = LLVM::ZeroOp::create(*this, ptrType); exitCode = intConstant(0); - const auto initType = LLVM::LLVMFunctionType::get(voidType, ptrType); - auto initFunc = - getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initType); - LLVM::CallOp::create(*this, initFunc, zeroOp.getResult()); // Add unconditional branches between blocks setInsertionPointToEnd(entryBlock); @@ -106,28 +101,8 @@ Value QIRProgramBuilder::doubleConstant(double value) { } Value QIRProgramBuilder::staticQubit(const int64_t index) { - checkFinalized(); - - if (index < 0) { - llvm::reportFatalUsageError("Index must be non-negative"); - } - - // Check cache - Value val{}; - if (const auto it = ptrCache.find(index); it != ptrCache.end()) { - val = it->second; - } else { - val = createPointerFromIndex(*this, getLoc(), index); - // Cache for reuse - ptrCache[index] = val; - } - - // Update qubit count - if (std::cmp_greater_equal(index, metadata_.numQubits)) { - metadata_.numQubits = static_cast(index) + 1; - } - - return val; + // TODO: Figure this out + llvm::reportFatalInternalError("Currently not implemented"); } SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { @@ -137,11 +112,34 @@ SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { llvm::reportFatalUsageError("Size must be positive"); } + // Save current insertion point + const InsertionGuard guard(*this); + + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); + SmallVector qubits; qubits.reserve(size); + auto allocFnSignature = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(getContext()), {getI64Type(), ptrType, ptrType}); + auto allocFnDecl = getOrCreateFunctionDeclaration( + *this, module, QIR_QUBIT_ARRAY_ALLOC, allocFnSignature); + + auto array = + LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + auto alloc = LLVM::CallOp::create( + *this, allocFnDecl, + ValueRange{intConstant(size), array.getResult(), zero.getResult()}); + + qubitArrays.insert(array.getResult()); + for (int64_t i = 0; i < size; ++i) { - qubits.push_back(staticQubit(static_cast(metadata_.numQubits))); + auto gepOp = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto loadOp = LLVM::LoadOp::create(*this, ptrType, gepOp.getResult()); + qubits.push_back(loadOp.getResult()); } return qubits; @@ -159,22 +157,29 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) - setInsertionPoint(measurementsBlock->getTerminator()); + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); + + auto allocFnSignature = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(getContext()), {getI64Type(), ptrType, ptrType}); + auto allocFnDecl = getOrCreateFunctionDeclaration( + *this, module, QIR_RESULT_ARRAY_ALLOC, allocFnSignature); + + auto array = + LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + auto alloc = LLVM::CallOp::create( + *this, allocFnDecl, + ValueRange{intConstant(size), array.getResult(), zero.getResult()}); + + resultArrays.try_emplace(name, array.getResult()); - const auto numResults = static_cast(metadata_.numResults); for (int64_t i = 0; i < size; ++i) { - Value val{}; - if (const auto it = ptrCache.find(numResults + i); it != ptrCache.end()) { - val = it->second; - } else { - val = createPointerFromIndex(*this, getLoc(), numResults + i); - // Cache for reuse - ptrCache[numResults + i] = val; - } - registerResultMap.insert({{stringSaver.save(name), i}, val}); + auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); } - metadata_.numResults += size; + return {.name = name, .size = size}; } @@ -185,50 +190,33 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { llvm::reportFatalUsageError("Result index must be non-negative"); } - // Choose a safe default register name - static constexpr llvm::StringLiteral DEFAULT_REG_NAME = "c"; - StringRef regName{DEFAULT_REG_NAME}; - if (llvm::any_of(registerResultMap, [](const auto& entry) { - return entry.first.first == DEFAULT_REG_NAME; - })) { - static constexpr llvm::StringLiteral FALLBACK_REG_NAME = "__unnamed__"; - regName = FALLBACK_REG_NAME; - } - // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) - setInsertionPoint(measurementsBlock->getTerminator()); + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); - const auto key = std::make_pair(regName, resultIndex); - if (const auto it = registerResultMap.find(key); - it != registerResultMap.end()) { - return it->second; - } + // Create result pointer + auto fnSig = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(getContext()), {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_ALLOC, fnSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + auto result = + LLVM::CallOp::create(*this, fnDec, zero.getResult()).getResult(); - Value resultValue{}; - if (const auto it = ptrCache.find(resultIndex); it != ptrCache.end()) { - resultValue = it->second; - } else { - resultValue = createPointerFromIndex(*this, getLoc(), resultIndex); - ptrCache[resultIndex] = resultValue; - registerResultMap.try_emplace(key, resultValue); - } + resultPtrs.try_emplace(resultIndex, result); - // Update result count - if (std::cmp_greater_equal(resultIndex, metadata_.numResults)) { - metadata_.numResults = static_cast(resultIndex) + 1; - } + // Switch to measurements block + setInsertionPoint(measurementsBlock->getTerminator()); - // Create mz call - const auto mzSignature = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - auto mzDecl = - getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSignature); - LLVM::CallOp::create(*this, mzDecl, ValueRange{qubit, resultValue}); + // Create measure call + const auto mzSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto mzDec = + getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSig); + LLVM::CallOp::create(*this, mzDec, ValueRange{qubit, result}); - return resultValue; + return result; } QIRProgramBuilder& QIRProgramBuilder::measure(Value qubit, const Bit& bit) { @@ -237,24 +225,32 @@ QIRProgramBuilder& QIRProgramBuilder::measure(Value qubit, const Bit& bit) { // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) - setInsertionPoint(measurementsBlock->getTerminator()); + // Insert allocations and constants in entry block + setInsertionPoint(entryBlock->getTerminator()); + + auto index = intConstant(bit.registerIndex); - // Check if we already have a result pointer for this register slot + // Get array pointer const auto& registerName = bit.registerName; - const auto registerIndex = bit.registerIndex; - const auto key = std::make_pair(registerName, registerIndex); - if (!registerResultMap.contains(key)) { + if (!resultArrays.contains(registerName)) { llvm::reportFatalInternalError("Result pointer not found"); } - const auto resultValue = registerResultMap.at(key); + auto array = resultArrays.at(registerName); - // Create mz call - const auto mzSignature = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - auto mzDecl = - getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, mzSignature); - LLVM::CallOp::create(*this, mzDecl, ValueRange{qubit, resultValue}); + // Get result pointer + auto gep = + LLVM::GEPOp::create(*this, ptrType, ptrType, array, ValueRange{index}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); + auto result = load.getResult(); + + // Switch to measurements block + setInsertionPoint(measurementsBlock->getTerminator()); + + // Create measure call + const auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_MEASURE, fnSig); + LLVM::CallOp::create(*this, fnDec, ValueRange{qubit, result}); return *this; } @@ -265,7 +261,7 @@ QIRProgramBuilder& QIRProgramBuilder::reset(Value qubit) { // Save current insertion point const InsertionGuard guard(*this); - // Insert in measurements block (before branch) + // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); // Create reset call @@ -578,7 +574,7 @@ void QIRProgramBuilder::checkFinalized() const { } void QIRProgramBuilder::generateOutputRecording() { - if (registerResultMap.empty()) { + if (resultArrays.empty() && resultPtrs.empty()) { return; // No measurements to record } @@ -588,55 +584,48 @@ void QIRProgramBuilder::generateOutputRecording() { // Insert in output block (before return) setInsertionPoint(outputBlock->getTerminator()); - // Group measurements by register - llvm::StringMap>> registerGroups; - for (const auto& [key, resultPtr] : registerResultMap) { - const auto& [regName, regIdx] = key; - registerGroups[regName].emplace_back(regIdx, resultPtr); + if (!resultPtrs.empty()) { + // Sort result pointers for deterministic output + llvm::SmallVector> sortedPtrs; + for (const auto& [index, resultPtr] : resultPtrs) { + sortedPtrs.emplace_back(index, resultPtr); + } + llvm::sort(sortedPtrs, + [](const auto& a, const auto& b) { return a.first < b.first; }); + + // Create output recording for each result pointer + auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_RECORD_OUTPUT, fnSig); + + for (const auto& [index, ptr] : sortedPtrs) { + auto label = createResultLabel(*this, module, + "__unnamed__" + std::to_string(index)) + .getResult(); + LLVM::CallOp::create(*this, fnDec, ValueRange{ptr, label}); + } } - // Sort registers by name for deterministic output - SmallVector>>> - sortedRegisters; - for (auto& [name, measurements] : registerGroups) { - sortedRegisters.emplace_back(name, std::move(measurements)); - } - sort(sortedRegisters, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - // Create array_record_output call - const auto arrayRecordSig = - LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); - const auto arrayRecordDecl = getOrCreateFunctionDeclaration( - *this, module, QIR_ARRAY_RECORD_OUTPUT, arrayRecordSig); - - // Create result_record_output calls for each measurement - const auto resultRecordSig = - LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); - const auto resultRecordDecl = getOrCreateFunctionDeclaration( - *this, module, QIR_RECORD_OUTPUT, resultRecordSig); - - // Generate output recording for each register - for (auto& [registerName, measurements] : sortedRegisters) { - // Sort measurements by register index - sort(measurements, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - const auto arraySize = measurements.size(); - auto arrayLabelOp = createResultLabel(*this, module, registerName); - auto arraySizeConst = intConstant(static_cast(arraySize)); - - LLVM::CallOp::create(*this, arrayRecordDecl, - ValueRange{arraySizeConst, arrayLabelOp.getResult()}); - - for (const auto& [regIdx, resultPtr] : measurements) { - // Create label for result: "{registerName}{regIdx}r" - const std::string resultLabel = - registerName + std::to_string(regIdx) + "r"; - auto resultLabelOp = createResultLabel(*this, module, resultLabel); - - LLVM::CallOp::create(*this, resultRecordDecl, - ValueRange{resultPtr, resultLabelOp.getResult()}); + if (!resultArrays.empty()) { + // Sort registers by name for deterministic output + SmallVector> sortedArrays; + for (auto& [name, results] : resultArrays) { + sortedArrays.emplace_back(name, std::move(results)); + } + llvm::sort(sortedArrays, + [](const auto& a, const auto& b) { return a.first < b.first; }); + + auto fnSig = + LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType, ptrType}); + auto fnDec = getOrCreateFunctionDeclaration(*this, module, + QIR_ARRAY_RECORD_OUTPUT, fnSig); + + // Create output recording for each register + for (auto& [name, results] : sortedArrays) { + auto size = results.getDefiningOp().getArraySize(); + auto label = createResultLabel(*this, module, name).getResult(); + + LLVM::CallOp::create(*this, fnDec, ValueRange{size, results, label}); } } } @@ -644,6 +633,28 @@ void QIRProgramBuilder::generateOutputRecording() { OwningOpRef QIRProgramBuilder::finalize() { checkFinalized(); + const InsertionGuard guard(*this); + + // Insert initialization at end of entry block + setInsertionPoint(entryBlock->getTerminator()); + + auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); + auto initDec = + getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + LLVM::CallOp::create(*this, initDec, zero.getResult()); + + // Insert in output block (before return) + setInsertionPoint(measurementsBlock->getTerminator()); + + for (auto array : qubitArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); + auto decl = getOrCreateFunctionDeclaration(*this, module, + QIR_QUBIT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(*this, decl, ValueRange{size, array}); + } + // Generate output recording in the output block generateOutputRecording(); diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index a81f6a7c01..d6b40060dd 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/QC/Translation/TranslateQuantumComputationToQC.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QIR/Builder/QIRProgramBuilder.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qc_programs.h" @@ -27,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -85,8 +87,9 @@ class CompilerPipelineTest void SetUp() override { mlir::DialectRegistry registry; registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 82ab251e41..696aadf565 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -61,7 +62,7 @@ class QCToQIRTest : public testing::TestWithParam { void SetUp() override { DialectRegistry registry; registry.insert(); + func::FuncDialect, memref::MemRefDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); From 98f4055ba5a12c103e5f18c48867bfd4de1b9c0e Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:45:30 +0100 Subject: [PATCH 03/71] Update conversion between QCO and Jeff --- mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp | 70 ++++++++++++- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 97 ++++++++++++++++++- .../Dialect/QTensor/IR/Operations/AllocOp.cpp | 23 ++--- 3 files changed, 173 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp index 5599753f96..ec1d7541a2 100644 --- a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp +++ b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include @@ -381,6 +383,64 @@ static LogicalResult cleanUp(Operation* op) { namespace { +struct ConvertJeffQuregAllocOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto size = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getNumQubits()); + rewriter.replaceOpWithNewOp(op, size.getResult()); + return success(); + } +}; + +struct ConvertJeffQuregExtractIndexOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregExtractIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto index = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getIndex()); + rewriter.replaceOpWithNewOp(op, adaptor.getInQreg(), + index.getResult()); + return success(); + } +}; + +struct ConvertJeffQuregInsertIndexOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregInsertIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto index = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getIndex()); + rewriter.replaceOpWithNewOp(op, adaptor.getInQubit(), + adaptor.getInQreg(), + + index.getResult()); + return success(); + } +}; + +struct ConvertJeffQuregFreeZeroOpToQCO final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(jeff::QuregFreeZeroOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getQreg()); + return success(); + } +}; + /** * @brief Converts jeff.qubit_alloc to qco.alloc * @@ -904,6 +964,11 @@ class JeffToQCOTypeConverter final : public TypeConverter { addConversion([ctx](jeff::QubitType /*type*/) -> Type { return qco::QubitType::get(ctx); }); + + addConversion([ctx](jeff::QuregType /*type*/) -> Type { + return RankedTensorType::get({ShapedType::kDynamic}, + qco::QubitType::get(ctx)); + }); } }; @@ -924,7 +989,8 @@ struct JeffToQCO final : impl::JeffToQCOBase { // Configure conversion target target.addIllegalDialect(); - target.addLegalDialect(); target.addDynamicallyLegalOp([&](func::FuncOp op) { @@ -936,6 +1002,8 @@ struct JeffToQCO final : impl::JeffToQCOBase { // Register operation conversion patterns jeff::populateJeffToNativeConversionPatterns(patterns); patterns.add< + ConvertJeffQuregAllocOpToQCO, ConvertJeffQuregExtractIndexOpToQCO, + ConvertJeffQuregInsertIndexOpToQCO, ConvertJeffQuregFreeZeroOpToQCO, ConvertJeffQubitAllocOpToQCO, ConvertJeffQubitFreeOpToQCO, ConvertJeffQubitFreeZeroOpToQCO, ConvertJeffQubitMeasureOpToQCO, ConvertJeffQubitMeasureNDOpToQCO, ConvertJeffQubitResetOpToQCO, diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 124c095060..566b201e44 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include @@ -23,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -244,6 +247,80 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) { namespace { +struct ConvertQTensorAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto sizeValue = getConstantIntValue(adaptor.getSize()); + Value size; + if (sizeValue.has_value()) { + size = jeff::IntConst32Op::create(rewriter, op.getLoc(), *sizeValue); + } else { + size = adaptor.getSize(); + } + rewriter.replaceOpWithNewOp(op, size); + return success(); + } +}; + +struct ConvertQTensorExtractOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto indexValue = getConstantIntValue(adaptor.getIndex()); + Value index; + if (indexValue.has_value()) { + index = jeff::IntConst32Op::create(rewriter, op.getLoc(), *indexValue); + } else { + index = adaptor.getIndex(); + } + rewriter.replaceOpWithNewOp( + op, adaptor.getTensor(), index); + return success(); + } +}; + +struct ConvertQTensorInsertOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // TODO: Why is this not happening in native conversion? + auto indexValue = getConstantIntValue(adaptor.getIndex()); + Value index; + if (indexValue.has_value()) { + index = jeff::IntConst32Op::create(rewriter, op.getLoc(), *indexValue); + } else { + index = adaptor.getIndex(); + } + rewriter.replaceOpWithNewOp( + op, adaptor.getDest(), index, adaptor.getScalar()); + return success(); + } +}; + +struct ConvertQTensorDeallocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qtensor::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor()); + return success(); + } +}; + /** * @brief Converts qco.alloc to jeff.qubit_alloc * @@ -1321,7 +1398,8 @@ struct ConvertQCOMainToJeff final : StatefulOpConversionPattern { * @brief Type converter for QCO-to-Jeff conversion * * @details - * Converts `!qco.qubit` to `!jeff.qubit`. + * Converts `!qco.qubit` to `!jeff.qubit` and tensor to + * tensor. */ class QCOToJeffTypeConverter final : public TypeConverter { public: @@ -1332,6 +1410,13 @@ class QCOToJeffTypeConverter final : public TypeConverter { addConversion([ctx](qco::QubitType /*type*/) -> Type { return jeff::QubitType::get(ctx); }); + + addConversion([ctx](RankedTensorType type) -> Type { + if (llvm::isa(type.getElementType())) { + return jeff::QuregType::get(ctx); + } + return type; + }); } }; @@ -1353,7 +1438,8 @@ struct QCOToJeff final : impl::QCOToJeffBase { LoweringState state; // Configure conversion target - target.addIllegalDialect(); target.addLegalDialect(); @@ -1364,9 +1450,10 @@ struct QCOToJeff final : impl::QCOToJeffBase { // Register operation conversion patterns jeff::populateNativeToJeffConversionPatterns(patterns); patterns.add< - ConvertQCOAllocOpToJeff, ConvertQCODeallocOpToJeff, - ConvertQCOMeasureOpToJeff, ConvertQCOResetOpToJeff, - ConvertQCOGPhaseOpToJeff, + ConvertQTensorAllocOp, ConvertQTensorExtractOp, ConvertQTensorInsertOp, + ConvertQTensorDeallocOp, ConvertQCOAllocOpToJeff, + ConvertQCODeallocOpToJeff, ConvertQCOMeasureOpToJeff, + ConvertQCOResetOpToJeff, ConvertQCOGPhaseOpToJeff, ConvertQCOOneTargetZeroParameterToJeff, ConvertQCOOneTargetZeroParameterToJeff, ConvertQCOOneTargetZeroParameterToJeff, diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp index 898b8b6412..d6eb6fc693 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp @@ -45,17 +45,18 @@ LogicalResult AllocOp::verify() { if (sizeValue && *sizeValue <= 0) { return emitOpError("Constant size operand must be positive"); } - if (sizeValue.has_value() == resultType.isDynamicDim(0)) { - return emitOpError("Size operand and result type must both be static or " - "both be dynamic, but got ") - << (sizeValue ? "static size with dynamic result" - : "dynamic size with static result"); - } - if (sizeValue && resultSize != *sizeValue) { - return emitOpError("Constant size operand (") - << *sizeValue << ") does not match static result size (" - << resultSize << ")"; - } + // TODO: Deal with this + // if (sizeValue.has_value() == resultType.isDynamicDim(0)) { + // return emitOpError("Size operand and result type must both be static or " + // "both be dynamic, but got ") + // << (sizeValue ? "static size with dynamic result" + // : "dynamic size with static result"); + // } + // if (sizeValue && resultSize != *sizeValue) { + // return emitOpError("Constant size operand (") + // << *sizeValue << ") does not match static result size (" + // << resultSize << ")"; + // } return success(); } From 6458e73921387e48b988e9f13ec4145418ba32e8 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:16:39 +0100 Subject: [PATCH 04/71] Fix QC IR tests --- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 38 ++++++++++----------- mlir/unittests/programs/qc_programs.cpp | 38 ++++++++++++++++++--- mlir/unittests/programs/qc_programs.h | 18 ++++++++++ 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 3153062411..f3ff5738de 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -177,20 +177,19 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCBarrierOpTest, QCTest, - testing::Values(QCTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), - MQT_NAMED_BUILDER(barrier)}, - QCTestCase{"BarrierTwoQubits", - MQT_NAMED_BUILDER(barrierTwoQubits), - MQT_NAMED_BUILDER(barrierTwoQubits)}, - QCTestCase{"BarrierMultipleQubits", - MQT_NAMED_BUILDER(barrierMultipleQubits), - MQT_NAMED_BUILDER(barrierMultipleQubits)}, - QCTestCase{"SingleControlledBarrier", - MQT_NAMED_BUILDER(singleControlledBarrier), - MQT_NAMED_BUILDER(barrier)}, - QCTestCase{"InverseBarrier", - MQT_NAMED_BUILDER(inverseBarrier), - MQT_NAMED_BUILDER(barrier)})); + testing::Values( + QCTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), + MQT_NAMED_BUILDER(barrier)}, + QCTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(barrierTwoQubits), + MQT_NAMED_BUILDER(barrierTwoQubits)}, + QCTestCase{"BarrierMultipleQubits", + MQT_NAMED_BUILDER(barrierMultipleQubits), + MQT_NAMED_BUILDER(barrierMultipleQubits)}, + QCTestCase{"SingleControlledBarrier", + MQT_NAMED_BUILDER(singleControlledBarrier), + MQT_NAMED_BUILDER(singleControlledBarrierCanonicalized)}, + QCTestCase{"InverseBarrier", MQT_NAMED_BUILDER(inverseBarrier), + MQT_NAMED_BUILDER(barrier)})); /// @} /// \name QC/Operations/StandardGates/DcxOp.cpp @@ -258,7 +257,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledP)}, QCTestCase{"NestedControlledGlobalPhase", MQT_NAMED_BUILDER(nestedControlledGlobalPhase), - MQT_NAMED_BUILDER(singleControlledP)}, + MQT_NAMED_BUILDER(nestedControlledGlobalPhaseCanonicalized)}, QCTestCase{"TrivialControlledGlobalPhase", MQT_NAMED_BUILDER(trivialControlledGlobalPhase), MQT_NAMED_BUILDER(globalPhase)}, @@ -300,13 +299,13 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(identity)}, QCTestCase{"SingleControlledIdentity", MQT_NAMED_BUILDER(singleControlledIdentity), - MQT_NAMED_BUILDER(identity)}, + MQT_NAMED_BUILDER(singleControlledIdentityCanonicalized)}, QCTestCase{"MultipleControlledIdentity", MQT_NAMED_BUILDER(multipleControlledIdentity), - MQT_NAMED_BUILDER(identity)}, + MQT_NAMED_BUILDER(multipleControlledIdentityCanonicalized)}, QCTestCase{"NestedControlledIdentity", MQT_NAMED_BUILDER(nestedControlledIdentity), - MQT_NAMED_BUILDER(identity)}, + MQT_NAMED_BUILDER(nestedControlledIdentityCanonicalized)}, QCTestCase{"TrivialControlledIdentity", MQT_NAMED_BUILDER(trivialControlledIdentity), MQT_NAMED_BUILDER(identity)}, @@ -314,7 +313,8 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(identity)}, QCTestCase{"InverseMultipleControlledIdentity", MQT_NAMED_BUILDER(inverseMultipleControlledIdentity), - MQT_NAMED_BUILDER(identity)})); + MQT_NAMED_BUILDER( + inverseMultipleControlledIdentityCanonicalized)})); /// @} /// \name QC/Operations/StandardGates/IswapOp.cpp diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 2134b1f368..a2ddd3d704 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -123,8 +123,13 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b) { } void nestedControlledGlobalPhase(QCProgramBuilder& b) { - auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cgphase(0.123, reg[1]); }); + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); +} + +void nestedControlledGlobalPhaseCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.cp(0.123, q[0], q[1]); } void trivialControlledGlobalPhase(QCProgramBuilder& b) { @@ -151,14 +156,29 @@ void singleControlledIdentity(QCProgramBuilder& b) { b.cid(q[1], q[0]); } +void singleControlledIdentityCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.id(q[0]); +} + void multipleControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); b.mcid({q[2], q[1]}, q[0]); } +void multipleControlledIdentityCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.id(q[0]); +} + void nestedControlledIdentity(QCProgramBuilder& b) { - auto reg = b.allocQubitRegister(3); - b.ctrl(reg[2], [&] { b.cid(reg[1], reg[0]); }); + auto q = b.allocQubitRegister(3); + b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); +} + +void nestedControlledIdentityCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.id(q[0]); } void trivialControlledIdentity(QCProgramBuilder& b) { @@ -176,6 +196,11 @@ void inverseMultipleControlledIdentity(QCProgramBuilder& b) { b.inv([&]() { b.mcid({q[2], q[1]}, q[0]); }); } +void inverseMultipleControlledIdentityCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.id(q[0]); +} + void x(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.x(q[0]); @@ -1152,6 +1177,11 @@ void singleControlledBarrier(QCProgramBuilder& b) { b.ctrl(q[1], [&] { b.barrier(q[0]); }); } +void singleControlledBarrierCanonicalized(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.barrier(q[0]); +} + void inverseBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.inv([&]() { b.barrier(q[0]); }); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 21225c5b44..95e253c777 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -84,6 +84,9 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b); /// Creates a circuit with a nested controlled global phase gate. void nestedControlledGlobalPhase(QCProgramBuilder& b); +/// Canonicalized version of `nestedControlledGlobalPhase`. +void nestedControlledGlobalPhaseCanonicalized(QCProgramBuilder& b); + /// Creates a circuit with a trivial controlled global phase gate. void trivialControlledGlobalPhase(QCProgramBuilder& b); @@ -102,12 +105,21 @@ void identity(QCProgramBuilder& b); /// Creates a controlled identity gate with a single control qubit. void singleControlledIdentity(QCProgramBuilder& b); +/// Canonicalized version of `singleControlledIdentity`. +void singleControlledIdentityCanonicalized(QCProgramBuilder& b); + /// Creates a multi-controlled identity gate with multiple control qubits. void multipleControlledIdentity(QCProgramBuilder& b); +/// Canonicalized version of `multipleControlledIdentity`. +void multipleControlledIdentityCanonicalized(QCProgramBuilder& b); + /// Creates a circuit with a nested controlled identity gate. void nestedControlledIdentity(QCProgramBuilder& b); +/// Canonicalized version of `nestedControlledIdentity`. +void nestedControlledIdentityCanonicalized(QCProgramBuilder& b); + /// Creates a circuit with a trivial controlled identity gate. void trivialControlledIdentity(QCProgramBuilder& b); @@ -118,6 +130,9 @@ void inverseIdentity(QCProgramBuilder& b); /// gate. void inverseMultipleControlledIdentity(QCProgramBuilder& b); +/// Canonicalized version of `inverseMultipleControlledIdentity`. +void inverseMultipleControlledIdentityCanonicalized(QCProgramBuilder& b); + // --- XOp ------------------------------------------------------------------ // /// Creates a circuit with just an X gate. @@ -764,6 +779,9 @@ void barrierMultipleQubits(QCProgramBuilder& b); /// Creates a circuit with a single controlled barrier. void singleControlledBarrier(QCProgramBuilder& b); +/// Canonicalized version of `singleControlledBarrier`. +void singleControlledBarrierCanonicalized(QCProgramBuilder& b); + /// Creates a circuit with an inverse modifier applied to a barrier. void inverseBarrier(QCProgramBuilder& b); From 685551bdd8036a6de4b9528077037c17091d2cb5 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:15:55 +0100 Subject: [PATCH 05/71] Fix QCO IR tests --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 2 +- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 27 +++++- .../QTensor/IR/Operations/InsertOp.cpp | 82 ++++++++++++------- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 31 ++++--- mlir/unittests/programs/qco_programs.cpp | 5 ++ mlir/unittests/programs/qco_programs.h | 3 + 6 files changed, 101 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index f421ee2dc6..41ddedcfcf 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -207,7 +207,7 @@ def InsertOp $scalar `into` $dest `[` $index `]` attr-dict `:` type($dest) }]; - let hasFolder = 1; + let hasCanonicalizer = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 7ae1a733da..219206cdcd 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include @@ -21,7 +22,8 @@ using namespace mlir::qco; namespace { /** - * @brief Remove reset operations that immediately follow an allocation. + * @brief Remove reset operations that immediately follow a `qco.alloc` + * operation. */ struct RemoveResetAfterAlloc final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -39,9 +41,30 @@ struct RemoveResetAfterAlloc final : OpRewritePattern { } }; +/** + * @brief Remove reset operations that immediately follow a `qtensor.extract` + * operation. + */ +struct RemoveResetAfterExtract final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ResetOp op, + PatternRewriter& rewriter) const override { + // Check if the predecessor is an ExtractOp + if (auto extractOp = op.getQubitIn().getDefiningOp(); + !extractOp) { + return failure(); + } + + // Remove the ResetOp + rewriter.replaceOp(op, op.getQubitIn()); + return success(); + } +}; + } // namespace void ResetOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 982d4a6335..5d7588f1fc 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -8,17 +8,66 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include #include +#include #include +#include #include #include using namespace mlir; using namespace mlir::qtensor; +static ExtractOp findExtractOp(InsertOp op) { + + auto definingOp = op.getDest().getDefiningOp(); + if (llvm::isa(definingOp)) { + return llvm::cast(definingOp); + } else if (llvm::isa(definingOp)) { + auto nestedInsertOp = llvm::cast(definingOp); + return findExtractOp(nestedInsertOp); + } else { + return nullptr; + } +} + +namespace { + +struct RemoveExtractInsertPair final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter& rewriter) const override { + auto extractOp = findExtractOp(op); + if (!extractOp) { + return failure(); + } + + if (op.getScalar() != extractOp.getResult()) { + return failure(); + } + + if (op.getIndex() != extractOp.getIndex()) { + return failure(); + } + + // TODO: Improve this + auto qubit = qco::AllocOp::create(rewriter, op.getLoc()); + rewriter.replaceOp(extractOp, {extractOp.getTensor(), qubit.getResult()}); + qco::DeallocOp::create(rewriter, op.getLoc(), qubit.getResult()); + + rewriter.replaceOp(op, op.getDest()); + + return success(); + } +}; + +} // namespace + LogicalResult InsertOp::verify() { auto dstDim = getDest().getType().getDimSize(0); auto index = getConstantIntValue(getIndex()); @@ -35,34 +84,7 @@ LogicalResult InsertOp::verify() { return success(); } -/** - * @brief If an InsertOp consumes an ExtractOp with the same index, - * return the tensor from the extractOp directly. - */ -static Value foldInsertAfterExtract(InsertOp insertOp) { - auto extractOp = insertOp.getScalar().getDefiningOp(); - - if (!extractOp) { - return nullptr; - } - if (insertOp.getDest() != extractOp.getOutTensor()) { - return nullptr; - } - - auto insertIndex = insertOp.getIndex(); - auto extractIndex = extractOp.getIndex(); - - if (getAsOpFoldResult(insertIndex) != getAsOpFoldResult(extractIndex)) { - return nullptr; - } - - return extractOp.getTensor(); -} - -OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { - if (auto result = foldInsertAfterExtract(*this)) { - return result; - } - - return {}; +void InsertOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); } diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index ad4981d1b7..c6f81b4dda 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -185,22 +185,21 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOBarrierOpTest, QCOTest, - testing::Values(QCOTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), - MQT_NAMED_BUILDER(barrier)}, - QCOTestCase{"BarrierTwoQubits", - MQT_NAMED_BUILDER(barrierTwoQubits), - MQT_NAMED_BUILDER(barrierTwoQubits)}, - QCOTestCase{"BarrierMultipleQubits", - MQT_NAMED_BUILDER(barrierMultipleQubits), - MQT_NAMED_BUILDER(barrierMultipleQubits)}, - QCOTestCase{"SingleControlledBarrier", - MQT_NAMED_BUILDER(singleControlledBarrier), - MQT_NAMED_BUILDER(barrier)}, - QCOTestCase{"InverseBarrier", - MQT_NAMED_BUILDER(inverseBarrier), - MQT_NAMED_BUILDER(barrier)}, - QCOTestCase{"TwoBarrier", MQT_NAMED_BUILDER(twoBarrier), - MQT_NAMED_BUILDER(barrierTwoQubits)})); + testing::Values( + QCOTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), + MQT_NAMED_BUILDER(barrier)}, + QCOTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(barrierTwoQubits), + MQT_NAMED_BUILDER(barrierTwoQubits)}, + QCOTestCase{"BarrierMultipleQubits", + MQT_NAMED_BUILDER(barrierMultipleQubits), + MQT_NAMED_BUILDER(barrierMultipleQubits)}, + QCOTestCase{"SingleControlledBarrier", + MQT_NAMED_BUILDER(singleControlledBarrier), + MQT_NAMED_BUILDER(singleControlledBarrierCanonicalized)}, + QCOTestCase{"InverseBarrier", MQT_NAMED_BUILDER(inverseBarrier), + MQT_NAMED_BUILDER(barrier)}, + QCOTestCase{"TwoBarrier", MQT_NAMED_BUILDER(twoBarrier), + MQT_NAMED_BUILDER(barrierTwoQubits)})); /// @} /// \name QCO/Operations/StandardGates/DcxOp.cpp diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 1ef16df6d0..657d94450b 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -1893,6 +1893,11 @@ void singleControlledBarrier(QCOProgramBuilder& b) { }); } +void singleControlledBarrierCanonicalized(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.barrier(q[0]); +} + void inverseBarrier(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.inv({q[0]}, [&](mlir::ValueRange qubits) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index ba26e42969..4647d0d2d2 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -914,6 +914,9 @@ void barrierMultipleQubits(QCOProgramBuilder& b); /// Creates a circuit with a single controlled barrier. void singleControlledBarrier(QCOProgramBuilder& b); +/// Canonicalized version of `singleControlledBarrier`. +void singleControlledBarrierCanonicalized(QCOProgramBuilder& b); + /// Creates a circuit with an inverse modifier applied to a barrier. void inverseBarrier(QCOProgramBuilder& b); From 3e8d254a5b76f43527a77db6be943f3ace770e11 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:53:47 +0100 Subject: [PATCH 06/71] Fix QC-to-QIR tests --- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 27 ++++++++++-------- mlir/unittests/programs/qc_programs.cpp | 28 +++++++++---------- mlir/unittests/programs/qco_programs.cpp | 16 +++++------ mlir/unittests/programs/qir_programs.cpp | 24 ++++++++++++++++ mlir/unittests/programs/qir_programs.h | 20 +++++++++++++ 5 files changed, 81 insertions(+), 34 deletions(-) diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 696aadf565..e962e44a52 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -119,16 +119,17 @@ INSTANTIATE_TEST_SUITE_P( QCToQIRBarrierOpTest, QCToQIRTest, testing::Values( QCToQIRTestCase{"Barrier", MQT_NAMED_BUILDER(qc::barrier), - MQT_NAMED_BUILDER(qir::emptyQIR)}, + MQT_NAMED_BUILDER(qir::barrierConverted)}, QCToQIRTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(qc::barrierTwoQubits), - MQT_NAMED_BUILDER(qir::emptyQIR)}, + MQT_NAMED_BUILDER(qir::barrierTwoQubitsConverted)}, QCToQIRTestCase{"BarrierMultipleQubits", MQT_NAMED_BUILDER(qc::barrierMultipleQubits), - MQT_NAMED_BUILDER(qir::emptyQIR)}, - QCToQIRTestCase{"SingleControlledBarrier", - MQT_NAMED_BUILDER(qc::singleControlledBarrier), - MQT_NAMED_BUILDER(qir::emptyQIR)})); + MQT_NAMED_BUILDER(qir::barrierMultipleQubitsConverted)}, + QCToQIRTestCase{ + "SingleControlledBarrier", + MQT_NAMED_BUILDER(qc::singleControlledBarrier), + MQT_NAMED_BUILDER(qir::singleControlledBarrierConverted)})); /// @} /// \name QCToQIR/Operations/StandardGates/DcxOp.cpp @@ -191,12 +192,14 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( QCToQIRTestCase{"Identity", MQT_NAMED_BUILDER(qc::identity), MQT_NAMED_BUILDER(qir::identity)}, - QCToQIRTestCase{"SingleControlledIdentity", - MQT_NAMED_BUILDER(qc::singleControlledIdentity), - MQT_NAMED_BUILDER(qir::identity)}, - QCToQIRTestCase{"MultipleControlledIdentity", - MQT_NAMED_BUILDER(qc::multipleControlledIdentity), - MQT_NAMED_BUILDER(qir::identity)})); + QCToQIRTestCase{ + "SingleControlledIdentity", + MQT_NAMED_BUILDER(qc::singleControlledIdentity), + MQT_NAMED_BUILDER(qir::singleControlledIdentityConverted)}, + QCToQIRTestCase{ + "MultipleControlledIdentity", + MQT_NAMED_BUILDER(qc::multipleControlledIdentity), + MQT_NAMED_BUILDER(qir::multipleControlledIdentityConverted)})); /// @} /// \name QCToQIR/Operations/StandardGates/IswapOp.cpp diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index a2ddd3d704..0baec89da8 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -71,8 +71,8 @@ void multipleClassicalRegistersAndMeasurements(QCProgramBuilder& b) { } void resetQubitWithoutOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.reset(q[0]); } void resetMultipleQubitsWithoutOp(QCProgramBuilder& b) { @@ -82,16 +82,16 @@ void resetMultipleQubitsWithoutOp(QCProgramBuilder& b) { } void repeatedResetWithoutOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.reset(q); - b.reset(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.reset(q[0]); + b.reset(q[0]); + b.reset(q[0]); } void resetQubitAfterSingleOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.h(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.h(q[0]); + b.reset(q[0]); } void resetMultipleQubitsAfterSingleOp(QCProgramBuilder& b) { @@ -103,11 +103,11 @@ void resetMultipleQubitsAfterSingleOp(QCProgramBuilder& b) { } void repeatedResetAfterSingleOp(QCProgramBuilder& b) { - auto q = b.allocQubit(); - b.h(q); - b.reset(q); - b.reset(q); - b.reset(q); + auto q = b.allocQubitRegister(1); + b.h(q[0]); + b.reset(q[0]); + b.reset(q[0]); + b.reset(q[0]); } void globalPhase(QCProgramBuilder& b) { b.gphase(0.123); } diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 657d94450b..fd681c8d11 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -93,9 +93,9 @@ void repeatedResetWithoutOp(QCOProgramBuilder& b) { } void resetQubitAfterSingleOp(QCOProgramBuilder& b) { - auto q = b.allocQubit(); - q = b.h(q); - q = b.reset(q); + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.reset(q[0]); } void resetMultipleQubitsAfterSingleOp(QCOProgramBuilder& b) { @@ -107,11 +107,11 @@ void resetMultipleQubitsAfterSingleOp(QCOProgramBuilder& b) { } void repeatedResetAfterSingleOp(QCOProgramBuilder& b) { - auto q = b.allocQubit(); - q = b.h(q); - q = b.reset(q); - q = b.reset(q); - q = b.reset(q); + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.reset(q[0]); + q[0] = b.reset(q[0]); + q[0] = b.reset(q[0]); } void globalPhase(QCOProgramBuilder& b) { b.gphase(0.123); } diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index 883cf59af8..11059a98d0 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -115,11 +115,21 @@ void singleControlledIdentity(QIRProgramBuilder& b) { b.cid(q[0], q[1]); } +void singleControlledIdentityConverted(QIRProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.id(q[0]); +} + void multipleControlledIdentity(QIRProgramBuilder& b) { auto q = b.allocQubitRegister(3); b.mcid({q[0], q[1]}, q[2]); } +void multipleControlledIdentityConverted(QIRProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.id(q[0]); +} + void x(QIRProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.x(q[0]); @@ -530,4 +540,18 @@ void multipleControlledXxMinusYY(QIRProgramBuilder& b) { b.mcxx_minus_yy(0.123, 0.456, {q[0], q[1]}, q[2], q[3]); } +void barrierConverted(QIRProgramBuilder& b) { b.allocQubitRegister(1); } + +void barrierTwoQubitsConverted(QIRProgramBuilder& b) { + b.allocQubitRegister(2); +} + +void barrierMultipleQubitsConverted(QIRProgramBuilder& b) { + b.allocQubitRegister(3); +} + +void singleControlledBarrierConverted(QIRProgramBuilder& b) { + b.allocQubitRegister(2); +} + } // namespace mlir::qir diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index f379f19785..47aa385b99 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -80,9 +80,15 @@ void identity(QIRProgramBuilder& b); /// Creates a controlled identity gate with a single control qubit. void singleControlledIdentity(QIRProgramBuilder& b); +/// Converted version of `qc::singleControlledIdentity`. +void singleControlledIdentityConverted(QIRProgramBuilder& b); + /// Creates a multi-controlled identity gate with multiple control qubits. void multipleControlledIdentity(QIRProgramBuilder& b); +/// Converted version of `qc::multipleControlledIdentity`. +void multipleControlledIdentityConverted(QIRProgramBuilder& b); + // --- XOp ------------------------------------------------------------------ // /// Creates a circuit with just an X gate. @@ -383,4 +389,18 @@ void singleControlledXxMinusYY(QIRProgramBuilder& b); /// Creates a circuit with a multi-controlled XXMinusYY gate. void multipleControlledXxMinusYY(QIRProgramBuilder& b); +// --- BarrierOp ------------------------------------------------------------ // + +/// Converted version of `qc::barrier`. +void barrierConverted(QIRProgramBuilder& b); + +/// Converted version of `qc::barrierTwoQubits`. +void barrierTwoQubitsConverted(QIRProgramBuilder& b); + +/// Converted version of `qc::barrierMultipleQubits`. +void barrierMultipleQubitsConverted(QIRProgramBuilder& b); + +/// Converted version of `qc::singleControlledBarrier`. +void singleControlledBarrierConverted(QIRProgramBuilder& b); + } // namespace mlir::qir From 5253e141323d324827e9b26883747785be176da2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 27 Mar 2026 08:35:50 +0530 Subject: [PATCH 07/71] Fix linter errors --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 4 ++++ mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 10 ++++++++-- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 15 ++++++++------- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 1 + .../lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 7 ++++--- .../lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp | 12 +++--------- .../Dialect/QTensor/IR/Operations/InsertOp.cpp | 9 +++++---- 7 files changed, 33 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 0bfeeaea8e..6514b2255f 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -17,13 +17,17 @@ #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include #include +#include #include +#include #include #include +#include #include #include #include diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 0f04d621d4..bedddc8648 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -19,12 +19,17 @@ #include #include +#include #include #include #include +#include +#include #include +#include #include #include +#include #include #include #include @@ -199,7 +204,8 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, index); qubitMap.try_emplace(op.getResult(), extract.getResult()); - qubitInfos.try_emplace(op.getResult(), QubitInfo{memref, index}); + qubitInfos.try_emplace(op.getResult(), + QubitInfo{.reg = memref, .index = index}); qtensorMap[memref] = extract.getOutTensor(); rewriter.eraseOp(op); @@ -245,7 +251,7 @@ struct ConvertMemRefDeallocOp final using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult - matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + matchAndRewrite(memref::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap; auto& qubitInfos = getState().qubitInfos; diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index a9b3b12cbc..b73bec3485 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -88,8 +89,8 @@ struct LoweringState : QIRMetadata { DenseMap> controls; // Block information - Block* entryBlock; - Block* measurementsBlock; + Block* entryBlock{}; + Block* measurementsBlock{}; }; /** @@ -300,8 +301,8 @@ struct ConvertMemRefDeallocOp final Value size; if (shape[0] == ShapedType::kDynamic) { - llvm::errs() << "I do not know yet\n"; - return failure(); + size = + op.getMemref().getDefiningOp().getDynamicSizes()[0]; } else { size = LLVM::ConstantOp::create( rewriter, op.getLoc(), @@ -430,7 +431,7 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult - matchAndRewrite(StaticOp op, OpAdaptor /*adaptor*/, + matchAndRewrite(StaticOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // TODO: Figure this out return failure(); @@ -493,7 +494,7 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { static_cast(op.getRegisterIndex().value()); // Create result register if it does not exist yet - if (resultArrays.find(registerName) == resultArrays.end()) { + if (!resultArrays.contains(registerName)) { auto fnSig = LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(ctx), {rewriter.getI64Type(), ptrType, ptrType}); @@ -1181,7 +1182,7 @@ struct QCToQIR final : impl::QCToQIRBase { // Sort registers by name for deterministic output SmallVector> sortedRegisters; for (auto& [name, results] : resultArrays) { - sortedRegisters.emplace_back(name, std::move(results)); + sortedRegisters.emplace_back(name, results); } llvm::sort(sortedRegisters, [](const auto& a, const auto& b) { return a.first < b.first; diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 3dfad8e92a..ccf83b8a1e 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 729cca295f..5979ed16cb 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/Utils/Utils.h" +#include #include #include #include @@ -32,7 +33,6 @@ #include #include -#include #include #include #include @@ -246,7 +246,8 @@ std::pair QCOProgramBuilder::qtensorExtract(Value tensor, auto qubit = extractOp.getResult(); auto outTensor = extractOp.getOutTensor(); - validQubits.insert({qubit, {validTensors[tensor].regId, index}}); + validQubits.insert( + {qubit, {.regId = validTensors[tensor].regId, .regIndex = index}}); updateTensorTracking(tensor, outTensor); return {outTensor, qubit}; @@ -991,7 +992,7 @@ OwningOpRef QCOProgramBuilder::finalize() { llvm::sort(sortedTensors, blockOrderComparator1); for (auto& [tensor, tensorInfo] : sortedTensors) { // Filter out qubits belonging to this tensor - SmallVector> toInsert; + llvm::SmallVector> toInsert; for (auto& [qubit, qubitInfo] : registerQubits) { if (qubitInfo.regId != tensorInfo.regId) { continue; diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 570be420d6..e819ccc29b 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -129,7 +129,7 @@ SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { auto array = LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); auto zero = LLVM::ZeroOp::create(*this, ptrType); - auto alloc = LLVM::CallOp::create( + LLVM::CallOp::create( *this, allocFnDecl, ValueRange{intConstant(size), array.getResult(), zero.getResult()}); @@ -168,18 +168,12 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, auto array = LLVM::AllocaOp::create(*this, ptrType, ptrType, intConstant(size)); auto zero = LLVM::ZeroOp::create(*this, ptrType); - auto alloc = LLVM::CallOp::create( + LLVM::CallOp::create( *this, allocFnDecl, ValueRange{intConstant(size), array.getResult(), zero.getResult()}); resultArrays.try_emplace(name, array.getResult()); - for (int64_t i = 0; i < size; ++i) { - auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), - ValueRange{intConstant(i)}); - auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); - } - return {.name = name, .size = size}; } @@ -610,7 +604,7 @@ void QIRProgramBuilder::generateOutputRecording() { // Sort registers by name for deterministic output SmallVector> sortedArrays; for (auto& [name, results] : resultArrays) { - sortedArrays.emplace_back(name, std::move(results)); + sortedArrays.emplace_back(name, results); } llvm::sort(sortedArrays, [](const auto& a, const auto& b) { return a.first < b.first; }); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 5d7588f1fc..e312d73244 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include @@ -24,15 +25,15 @@ using namespace mlir::qtensor; static ExtractOp findExtractOp(InsertOp op) { - auto definingOp = op.getDest().getDefiningOp(); + auto* definingOp = op.getDest().getDefiningOp(); if (llvm::isa(definingOp)) { return llvm::cast(definingOp); - } else if (llvm::isa(definingOp)) { + } + if (llvm::isa(definingOp)) { auto nestedInsertOp = llvm::cast(definingOp); return findExtractOp(nestedInsertOp); - } else { - return nullptr; } + return nullptr; } namespace { From 4b02d3f76d59d1687eb47bb2dd4a69550ace9f34 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:18:13 +0530 Subject: [PATCH 08/71] Fix some more tests --- .../Compiler/test_compiler_pipeline.cpp | 36 +++++++++---------- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 1 + mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 15 ++++++-- mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp | 26 +++++++------- 4 files changed, 44 insertions(+), 34 deletions(-) diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index d6b40060dd..7cfdc681a4 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -192,27 +192,28 @@ TEST_P(CompilerPipelineTest, EndToEndPipeline) { INSTANTIATE_TEST_SUITE_P( QuantumComputationPipelineProgramsTest, CompilerPipelineTest, testing::Values( + // FIXME: Test fails because static qubits are currently not supported CompilerPipelineTestCase{ "StaticQubits", nullptr, MQT_NAMED_BUILDER(mlir::qc::staticQubits), MQT_NAMED_BUILDER(mlir::qc::staticQubits), MQT_NAMED_BUILDER(mlir::qir::staticQubits), false}, CompilerPipelineTestCase{"AllocQubit", MQT_NAMED_BUILDER(qc::allocQubit), nullptr, - MQT_NAMED_BUILDER(mlir::qc::allocQubit), - MQT_NAMED_BUILDER(mlir::qir::allocQubit)}, - CompilerPipelineTestCase{ - "AllocQubitRegister", MQT_NAMED_BUILDER(qc::allocQubitRegister), - nullptr, MQT_NAMED_BUILDER(mlir::qc::allocQubitRegister), - MQT_NAMED_BUILDER(mlir::qir::allocQubitRegister)}, + MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, + CompilerPipelineTestCase{"AllocQubitRegister", + MQT_NAMED_BUILDER(qc::allocQubitRegister), + nullptr, MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, CompilerPipelineTestCase{ "AllocMultipleQubitRegisters", MQT_NAMED_BUILDER(qc::allocMultipleQubitRegisters), nullptr, - MQT_NAMED_BUILDER(mlir::qc::allocMultipleQubitRegisters), - MQT_NAMED_BUILDER(mlir::qir::allocMultipleQubitRegisters)}, - CompilerPipelineTestCase{ - "AllocLargeRegister", MQT_NAMED_BUILDER(qc::allocLargeRegister), - nullptr, MQT_NAMED_BUILDER(mlir::qc::allocLargeRegister), - MQT_NAMED_BUILDER(mlir::qir::allocLargeRegister)}, + MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, + CompilerPipelineTestCase{"AllocLargeRegister", + MQT_NAMED_BUILDER(qc::allocLargeRegister), + nullptr, MQT_NAMED_BUILDER(mlir::qc::emptyQC), + MQT_NAMED_BUILDER(mlir::qir::emptyQIR)}, CompilerPipelineTestCase{ "SingleMeasurementToSingleBit", MQT_NAMED_BUILDER(qc::singleMeasurementToSingleBit), nullptr, @@ -228,6 +229,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::repeatedMeasurementToDifferentBits), nullptr, MQT_NAMED_BUILDER(mlir::qc::repeatedMeasurementToDifferentBits), MQT_NAMED_BUILDER(mlir::qir::repeatedMeasurementToDifferentBits)}, + // FIXME: Test fails because of location of llvm.load CompilerPipelineTestCase{ "MultipleClassicalRegistersAndMeasurements", MQT_NAMED_BUILDER(qc::multipleClassicalRegistersAndMeasurements), @@ -429,9 +431,8 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qc::r), MQT_NAMED_BUILDER(mlir::qir::r)}, CompilerPipelineTestCase{ - "SingleControlledR", - MQT_NAMED_BUILDER(qc::singleControlledR), nullptr, - MQT_NAMED_BUILDER(mlir::qc::singleControlledR), + "SingleControlledR", MQT_NAMED_BUILDER(qc::singleControlledR), + nullptr, MQT_NAMED_BUILDER(mlir::qc::singleControlledR), MQT_NAMED_BUILDER(mlir::qir::singleControlledR)}, CompilerPipelineTestCase{ "MultipleControlledR", MQT_NAMED_BUILDER(qc::multipleControlledR), @@ -452,9 +453,8 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qc::u), MQT_NAMED_BUILDER(mlir::qir::u)}, CompilerPipelineTestCase{ - "SingleControlledU", - MQT_NAMED_BUILDER(qc::singleControlledU), nullptr, - MQT_NAMED_BUILDER(mlir::qc::singleControlledU), + "SingleControlledU", MQT_NAMED_BUILDER(qc::singleControlledU), + nullptr, MQT_NAMED_BUILDER(mlir::qc::singleControlledU), MQT_NAMED_BUILDER(mlir::qir::singleControlledU)}, CompilerPipelineTestCase{ "MultipleControlledU", MQT_NAMED_BUILDER(qc::multipleControlledU), diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index e962e44a52..1f34c81594 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -577,6 +577,7 @@ INSTANTIATE_TEST_SUITE_P( "RepeatedMeasurementToDifferentBits", MQT_NAMED_BUILDER(qc::repeatedMeasurementToDifferentBits), MQT_NAMED_BUILDER(qir::repeatedMeasurementToDifferentBits)}, + // FIXME: Test fails because of location of llvm.load QCToQIRTestCase{ "MultipleClassicalRegistersAndMeasurements", MQT_NAMED_BUILDER(qc::multipleClassicalRegistersAndMeasurements), diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index c6f81b4dda..7381022a68 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Support/IRVerification.h" #include "mlir/Support/Passes.h" #include "qco_programs.h" @@ -99,8 +100,11 @@ TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly qco::QCOProgramBuilder builder(context.get()); builder.initialize(); - auto q0 = AllocOp::create(builder); - auto q1 = HOp::create(builder, q0); + auto c0 = arith::ConstantOp::create(builder, builder.getIndexAttr(0)); + auto c1 = arith::ConstantOp::create(builder, builder.getIndexAttr(1)); + auto r0 = qtensor::AllocOp::create(builder, c1); + auto extractOp = qtensor::ExtractOp::create(builder, r0, c0); + auto q1 = HOp::create(builder, extractOp.getResult()); auto measureOp = MeasureOp::create(builder, q1); auto ifOp = IfOp::create(builder, measureOp.getResult(), measureOp.getQubitOut(), @@ -108,7 +112,9 @@ TEST_F(QCOTest, DirectIfBuilder) { auto innerQubit = XOp::create(builder, qubits[0]); return llvm::SmallVector{innerQubit}; }); - DeallocOp::create(builder, ifOp.getResult(0)); + auto r2 = qtensor::InsertOp::create(builder, ifOp.getResult(0), + extractOp.getOutTensor(), c0); + qtensor::DeallocOp::create(builder, r2); auto directBuilder = builder.finalize(); ASSERT_TRUE(directBuilder); @@ -455,6 +461,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRxx)}, QCOTestCase{"TwoRXX", MQT_NAMED_BUILDER(twoRxx), MQT_NAMED_BUILDER(rxx)}, + // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRXXSwappedTargets", MQT_NAMED_BUILDER(twoRxxSwappedTargets), MQT_NAMED_BUILDER(rxx)}, @@ -516,6 +523,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRyy)}, QCOTestCase{"TwoRYY", MQT_NAMED_BUILDER(twoRyy), MQT_NAMED_BUILDER(ryy)}, + // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRYYSwappedTargets", MQT_NAMED_BUILDER(twoRyySwappedTargets), MQT_NAMED_BUILDER(ryy)}, @@ -605,6 +613,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRzz)}, QCOTestCase{"TwoRZZ", MQT_NAMED_BUILDER(twoRzz), MQT_NAMED_BUILDER(rzz)}, + // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRZZSwappedTargets", MQT_NAMED_BUILDER(twoRzzSwappedTargets), MQT_NAMED_BUILDER(rzz)}, diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index e177c90329..b765529b69 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -522,17 +522,17 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QIRQubitManagementTest, QIRTest, - testing::Values(QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), - MQT_NAMED_BUILDER(allocQubit)}, - QIRTestCase{"AllocQubitRegister", - MQT_NAMED_BUILDER(allocQubitRegister), - MQT_NAMED_BUILDER(allocQubitRegister)}, - QIRTestCase{"AllocMultipleQubitRegisters", - MQT_NAMED_BUILDER(allocMultipleQubitRegisters), - MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, - QIRTestCase{"AllocLargeRegister", - MQT_NAMED_BUILDER(allocLargeRegister), - MQT_NAMED_BUILDER(allocLargeRegister)}, - QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), - MQT_NAMED_BUILDER(staticQubits)})); + testing::Values( + QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), + MQT_NAMED_BUILDER(allocQubit)}, + QIRTestCase{"AllocQubitRegister", MQT_NAMED_BUILDER(allocQubitRegister), + MQT_NAMED_BUILDER(allocQubitRegister)}, + QIRTestCase{"AllocMultipleQubitRegisters", + MQT_NAMED_BUILDER(allocMultipleQubitRegisters), + MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, + QIRTestCase{"AllocLargeRegister", MQT_NAMED_BUILDER(allocLargeRegister), + MQT_NAMED_BUILDER(allocLargeRegister)}, + // FIXME: Test fails because static qubits are currently not supported + QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), + MQT_NAMED_BUILDER(staticQubits)})); /// @} From 47f84a816e1decc712b3b6c8292f0a29f1489a79 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:57:57 +0200 Subject: [PATCH 09/71] Exclude Jeff tests from discovery for now --- mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt index 868d4e5656..b60b8f5c91 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt +++ b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt @@ -24,4 +24,4 @@ target_link_libraries( mqt_mlir_configure_unittest_target(${target_name}) -gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) +# gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) From 5eb6778fec5c3d6b0166533f6c6651e581326e10 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:29:14 +0200 Subject: [PATCH 10/71] Support static qubits in QIR again --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 31 +++++++++++++++-- .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 27 +++++++++++++-- mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp | 34 +++++++++++++------ .../Compiler/test_compiler_pipeline.cpp | 11 +++--- mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp | 26 +++++++------- 5 files changed, 97 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index b73bec3485..4285159f08 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -219,6 +219,9 @@ struct ConvertMemRefAllocOp final LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + state.useDynamicQubit = true; + auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -351,6 +354,9 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + state.useDynamicQubit = true; + auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -433,8 +439,27 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(StaticOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // TODO: Figure this out - return failure(); + const auto index = static_cast(op.getIndex()); + auto& state = getState(); + + // Get or create a pointer to the qubit + Value qubit; + if (const auto it = state.ptrMap.find(index); it != state.ptrMap.end()) { + // Reuse existing pointer + qubit = it->second; + } else { + // Create and cache for reuse + qubit = createPointerFromIndex(rewriter, op.getLoc(), index); + state.ptrMap.try_emplace(index, qubit); + } + rewriter.replaceOp(op, qubit); + + // Track maximum qubit index + if (std::cmp_greater_equal(index, state.numQubits)) { + state.numQubits = index + 1; + } + + return success(); } }; @@ -472,6 +497,8 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { matchAndRewrite(MeasureOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); + state.useDynamicResult = true; + auto& resultArrays = state.resultArrays; auto& resultPtrs = state.resultPtrs; diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index e819ccc29b..e7e4b0df56 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -101,8 +101,27 @@ Value QIRProgramBuilder::doubleConstant(double value) { } Value QIRProgramBuilder::staticQubit(const int64_t index) { - // TODO: Figure this out - llvm::reportFatalInternalError("Currently not implemented"); + checkFinalized(); + + if (index < 0) { + llvm::reportFatalUsageError("Index must be non-negative"); + } + + Value qubit; + if (const auto it = ptrCache.find(index); it != ptrCache.end()) { + qubit = it->second; + } else { + qubit = createPointerFromIndex(*this, getLoc(), index); + // Cache for reuse + ptrCache[index] = qubit; + } + + // Update qubit count + if (std::cmp_greater_equal(index, metadata_.numQubits)) { + metadata_.numQubits = static_cast(index) + 1; + } + + return qubit; } SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { @@ -112,6 +131,8 @@ SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { llvm::reportFatalUsageError("Size must be positive"); } + metadata_.useDynamicQubit = true; + // Save current insertion point const InsertionGuard guard(*this); @@ -154,6 +175,8 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } + metadata_.useDynamicResult = true; + // Save current insertion point const InsertionGuard guard(*this); diff --git a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp index 1093921606..b698c06dc7 100644 --- a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp +++ b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp @@ -57,6 +57,11 @@ LLVM::LLVMFuncOp getMainFunction(Operation* op) { } void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata) { + if (metadata.useDynamicQubit && metadata.numQubits != 0) { + llvm::reportFatalUsageError( + "Cannot use dynamic qubit allocation if static qubits are allocated"); + } + OpBuilder builder(main.getBody()); llvm::SmallVector attributes; @@ -73,17 +78,26 @@ void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata) { attributes.emplace_back(builder.getStrArrayAttr( {"required_num_results", std::to_string(metadata.numResults)})); - // QIR version (Base Profile spec requires version 2.0) - attributes.emplace_back(builder.getStrArrayAttr({"qir_major_version", "2"})); - attributes.emplace_back(builder.getStrArrayAttr({"qir_minor_version", "0"})); + // Management model or resource requirements + if (metadata.useDynamicQubit) { + attributes.emplace_back( + builder.getStrArrayAttr({"dynamic_qubit_management", "true"})); + } else { + attributes.emplace_back(builder.getStrArrayAttr( + {"required_num_qubits", std::to_string(metadata.numQubits)})); + } - // Management model - attributes.emplace_back( - builder.getStrArrayAttr({"dynamic_qubit_management", - metadata.useDynamicQubit ? "true" : "false"})); - attributes.emplace_back( - builder.getStrArrayAttr({"dynamic_result_management", - metadata.useDynamicResult ? "true" : "false"})); + if (metadata.useDynamicResult) { + attributes.emplace_back( + builder.getStrArrayAttr({"dynamic_result_management", "true"})); + } else { + attributes.emplace_back(builder.getStrArrayAttr( + {"required_num_results", std::to_string(metadata.numResults)})); + } + + // QIR version (Base Profile spec requires version 2.1) + attributes.emplace_back(builder.getStrArrayAttr({"qir_major_version", "2"})); + attributes.emplace_back(builder.getStrArrayAttr({"qir_minor_version", "1"})); main->setAttr("passthrough", builder.getArrayAttr(attributes)); } diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 7cfdc681a4..ca203826e3 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -192,7 +192,6 @@ TEST_P(CompilerPipelineTest, EndToEndPipeline) { INSTANTIATE_TEST_SUITE_P( QuantumComputationPipelineProgramsTest, CompilerPipelineTest, testing::Values( - // FIXME: Test fails because static qubits are currently not supported CompilerPipelineTestCase{ "StaticQubits", nullptr, MQT_NAMED_BUILDER(mlir::qc::staticQubits), MQT_NAMED_BUILDER(mlir::qc::staticQubits), @@ -431,8 +430,9 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qc::r), MQT_NAMED_BUILDER(mlir::qir::r)}, CompilerPipelineTestCase{ - "SingleControlledR", MQT_NAMED_BUILDER(qc::singleControlledR), - nullptr, MQT_NAMED_BUILDER(mlir::qc::singleControlledR), + "SingleControlledR", + MQT_NAMED_BUILDER(qc::singleControlledR), nullptr, + MQT_NAMED_BUILDER(mlir::qc::singleControlledR), MQT_NAMED_BUILDER(mlir::qir::singleControlledR)}, CompilerPipelineTestCase{ "MultipleControlledR", MQT_NAMED_BUILDER(qc::multipleControlledR), @@ -453,8 +453,9 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qc::u), MQT_NAMED_BUILDER(mlir::qir::u)}, CompilerPipelineTestCase{ - "SingleControlledU", MQT_NAMED_BUILDER(qc::singleControlledU), - nullptr, MQT_NAMED_BUILDER(mlir::qc::singleControlledU), + "SingleControlledU", + MQT_NAMED_BUILDER(qc::singleControlledU), nullptr, + MQT_NAMED_BUILDER(mlir::qc::singleControlledU), MQT_NAMED_BUILDER(mlir::qir::singleControlledU)}, CompilerPipelineTestCase{ "MultipleControlledU", MQT_NAMED_BUILDER(qc::multipleControlledU), diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index b765529b69..e177c90329 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -522,17 +522,17 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QIRQubitManagementTest, QIRTest, - testing::Values( - QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), - MQT_NAMED_BUILDER(allocQubit)}, - QIRTestCase{"AllocQubitRegister", MQT_NAMED_BUILDER(allocQubitRegister), - MQT_NAMED_BUILDER(allocQubitRegister)}, - QIRTestCase{"AllocMultipleQubitRegisters", - MQT_NAMED_BUILDER(allocMultipleQubitRegisters), - MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, - QIRTestCase{"AllocLargeRegister", MQT_NAMED_BUILDER(allocLargeRegister), - MQT_NAMED_BUILDER(allocLargeRegister)}, - // FIXME: Test fails because static qubits are currently not supported - QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), - MQT_NAMED_BUILDER(staticQubits)})); + testing::Values(QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), + MQT_NAMED_BUILDER(allocQubit)}, + QIRTestCase{"AllocQubitRegister", + MQT_NAMED_BUILDER(allocQubitRegister), + MQT_NAMED_BUILDER(allocQubitRegister)}, + QIRTestCase{"AllocMultipleQubitRegisters", + MQT_NAMED_BUILDER(allocMultipleQubitRegisters), + MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, + QIRTestCase{"AllocLargeRegister", + MQT_NAMED_BUILDER(allocLargeRegister), + MQT_NAMED_BUILDER(allocLargeRegister)}, + QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), + MQT_NAMED_BUILDER(staticQubits)})); /// @} From 8f80b9875a61c871591907080e946f41e175ee10 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:20:05 +0200 Subject: [PATCH 11/71] Load results once --- .../Dialect/QIR/Builder/QIRProgramBuilder.h | 7 +++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 27 +++++++------ .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 39 +++++++------------ .../Compiler/test_compiler_pipeline.cpp | 1 - .../Conversion/QCToQIR/test_qc_to_qir.cpp | 1 - 5 files changed, 36 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index b7809d872e..12ef6f64bd 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -887,7 +887,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { Value exitCode; /// Cache static pointers for reuse - llvm::DenseMap ptrCache; + llvm::DenseMap staticQubits; /// Set of qubit-array pointers llvm::DenseSet qubitArrays; @@ -895,7 +895,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { /// Map from register name to result-array pointer llvm::StringMap resultArrays; - /// Map from result index to result pointer + /// Map from (register name, index) to loaded result + llvm::DenseMap, Value> loadedResults; + + /// Map from result index to result pointer for non-register results llvm::DenseMap resultPtrs; /// Track qubit and result counts for QIR metadata diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 4285159f08..2665f2fae9 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -81,7 +81,10 @@ struct LoweringState : QIRMetadata { /// Map from register name to result-array pointer llvm::StringMap resultArrays; - /// Map from index to result pointer + /// Map from (register name, index) to loaded result + llvm::DenseMap, Value> loadedResults; + + /// Map from index to result pointer for non-register results DenseMap resultPtrs; /// Modifier information @@ -500,6 +503,7 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { state.useDynamicResult = true; auto& resultArrays = state.resultArrays; + auto& loadedResults = state.loadedResults; auto& resultPtrs = state.resultPtrs; auto* ctx = getContext(); @@ -539,18 +543,19 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { rewriter, op.getLoc(), fnDec, ValueRange{size, array.getResult(), zero.getResult()}); resultArrays.try_emplace(registerName, array.getResult()); + + for (int64_t i = 0; i < registerSize; ++i) { + auto gep = LLVM::GEPOp::create( + rewriter, op.getLoc(), ptrType, ptrType, array.getResult(), + ValueRange{LLVM::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(i))}); + auto load = LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, + gep.getResult()); + loadedResults.try_emplace({registerName, i}, load.getResult()); + } } - auto array = resultArrays[registerName]; - auto index = - LLVM::ConstantOp::create(rewriter, op.getLoc(), - rewriter.getI64IntegerAttr(registerIndex)) - .getResult(); - auto gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrType, ptrType, - array, index); - auto load = - LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, gep.getResult()); - result = load.getResult(); + result = loadedResults.at({registerName, registerIndex}); } else { auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index e7e4b0df56..0e0ee74358 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -108,12 +108,12 @@ Value QIRProgramBuilder::staticQubit(const int64_t index) { } Value qubit; - if (const auto it = ptrCache.find(index); it != ptrCache.end()) { + if (const auto it = staticQubits.find(index); it != staticQubits.end()) { qubit = it->second; } else { qubit = createPointerFromIndex(*this, getLoc(), index); // Cache for reuse - ptrCache[index] = qubit; + staticQubits[index] = qubit; } // Update qubit count @@ -157,10 +157,10 @@ SmallVector QIRProgramBuilder::allocQubitRegister(const int64_t size) { qubitArrays.insert(array.getResult()); for (int64_t i = 0; i < size; ++i) { - auto gepOp = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), - ValueRange{intConstant(i)}); - auto loadOp = LLVM::LoadOp::create(*this, ptrType, gepOp.getResult()); - qubits.push_back(loadOp.getResult()); + auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); + qubits.push_back(load.getResult()); } return qubits; @@ -197,6 +197,13 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, resultArrays.try_emplace(name, array.getResult()); + for (int64_t i = 0; i < size; ++i) { + auto gep = LLVM::GEPOp::create(*this, ptrType, ptrType, array.getResult(), + ValueRange{intConstant(i)}); + auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); + loadedResults.try_emplace({stringSaver.save(name), i}, load.getResult()); + } + return {.name = name, .size = size}; } @@ -242,27 +249,11 @@ QIRProgramBuilder& QIRProgramBuilder::measure(Value qubit, const Bit& bit) { // Save current insertion point const InsertionGuard guard(*this); - // Insert allocations and constants in entry block - setInsertionPoint(entryBlock->getTerminator()); - - auto index = intConstant(bit.registerIndex); - - // Get array pointer - const auto& registerName = bit.registerName; - if (!resultArrays.contains(registerName)) { - llvm::reportFatalInternalError("Result pointer not found"); - } - auto array = resultArrays.at(registerName); - - // Get result pointer - auto gep = - LLVM::GEPOp::create(*this, ptrType, ptrType, array, ValueRange{index}); - auto load = LLVM::LoadOp::create(*this, ptrType, gep.getResult()); - auto result = load.getResult(); - // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); + auto result = loadedResults.at({bit.registerName, bit.registerIndex}); + // Create measure call const auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); auto fnDec = diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index ca203826e3..36f1c73f18 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -228,7 +228,6 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::repeatedMeasurementToDifferentBits), nullptr, MQT_NAMED_BUILDER(mlir::qc::repeatedMeasurementToDifferentBits), MQT_NAMED_BUILDER(mlir::qir::repeatedMeasurementToDifferentBits)}, - // FIXME: Test fails because of location of llvm.load CompilerPipelineTestCase{ "MultipleClassicalRegistersAndMeasurements", MQT_NAMED_BUILDER(qc::multipleClassicalRegistersAndMeasurements), diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 1f34c81594..e962e44a52 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -577,7 +577,6 @@ INSTANTIATE_TEST_SUITE_P( "RepeatedMeasurementToDifferentBits", MQT_NAMED_BUILDER(qc::repeatedMeasurementToDifferentBits), MQT_NAMED_BUILDER(qir::repeatedMeasurementToDifferentBits)}, - // FIXME: Test fails because of location of llvm.load QCToQIRTestCase{ "MultipleClassicalRegistersAndMeasurements", MQT_NAMED_BUILDER(qc::multipleClassicalRegistersAndMeasurements), From 9887b3f4ee31205d8817c7d6617b884f03e6cfe8 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:44:22 +0200 Subject: [PATCH 12/71] Fix sorting of qtensor.insert statements --- .../lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 15 +++++++-------- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 3 --- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 5979ed16cb..4998d3804a 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -940,8 +940,8 @@ OwningOpRef QCOProgramBuilder::finalize() { return opA->isBeforeInBlock(opB); }; - auto blockOrderComparator1 = [](const std::pair& a, - const std::pair& b) { + auto blockOrderComparatorTensors = [](const std::pair& a, + const std::pair& b) { auto* opA = a.first.getDefiningOp(); auto* opB = b.first.getDefiningOp(); if (!opA || !opB || opA->getBlock() != opB->getBlock()) { @@ -950,8 +950,8 @@ OwningOpRef QCOProgramBuilder::finalize() { return opA->isBeforeInBlock(opB); }; - auto blockOrderComparator2 = [](const std::pair& a, - const std::pair& b) { + auto blockOrderComparatorToInsert = [](const std::pair& a, + const std::pair& b) { auto* opA = a.first.getDefiningOp(); auto* opB = b.first.getDefiningOp(); if (!opA || !opB || opA->getBlock() != opB->getBlock()) { @@ -960,8 +960,7 @@ OwningOpRef QCOProgramBuilder::finalize() { if (opA != opB) { return opA->isBeforeInBlock(opB); } - return llvm::cast(a.first).getResultNumber() < - llvm::cast(b.first).getResultNumber(); + return a.second < b.second; }; llvm::SmallVector freeQubits; @@ -989,7 +988,7 @@ OwningOpRef QCOProgramBuilder::finalize() { // Sort tensors for deterministic output llvm::SmallVector> sortedTensors( validTensors.begin(), validTensors.end()); - llvm::sort(sortedTensors, blockOrderComparator1); + llvm::sort(sortedTensors, blockOrderComparatorTensors); for (auto& [tensor, tensorInfo] : sortedTensors) { // Filter out qubits belonging to this tensor llvm::SmallVector> toInsert; @@ -1000,7 +999,7 @@ OwningOpRef QCOProgramBuilder::finalize() { toInsert.push_back({qubit, qubitInfo.regIndex}); } // Sort qubits for deterministic output - llvm::sort(toInsert, blockOrderComparator2); + llvm::sort(toInsert, blockOrderComparatorToInsert); // Insert qubits for (auto& [qubit, index] : toInsert) { tensor = qtensorInsert(qubit, tensor, index); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 7381022a68..0dcfbc79e8 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -461,7 +461,6 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRxx)}, QCOTestCase{"TwoRXX", MQT_NAMED_BUILDER(twoRxx), MQT_NAMED_BUILDER(rxx)}, - // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRXXSwappedTargets", MQT_NAMED_BUILDER(twoRxxSwappedTargets), MQT_NAMED_BUILDER(rxx)}, @@ -523,7 +522,6 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRyy)}, QCOTestCase{"TwoRYY", MQT_NAMED_BUILDER(twoRyy), MQT_NAMED_BUILDER(ryy)}, - // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRYYSwappedTargets", MQT_NAMED_BUILDER(twoRyySwappedTargets), MQT_NAMED_BUILDER(ryy)}, @@ -613,7 +611,6 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledRzz)}, QCOTestCase{"TwoRZZ", MQT_NAMED_BUILDER(twoRzz), MQT_NAMED_BUILDER(rzz)}, - // FIXME: Test fails because of qtensor.insert location QCOTestCase{"TwoRZZSwappedTargets", MQT_NAMED_BUILDER(twoRzzSwappedTargets), MQT_NAMED_BUILDER(rzz)}, From 12d944314391b5c2cb3d3ac3c49248abadd52926 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:14:35 +0200 Subject: [PATCH 13/71] Fix linter errors --- mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h | 6 ++---- mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 6 ++---- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 1 - mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 1 - mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 2 +- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 3 +-- .../QC/Translation/TranslateQuantumComputationToQC.cpp | 4 ++-- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 3 +-- mlir/unittests/programs/qc_programs.cpp | 4 ++-- mlir/unittests/programs/qco_programs.cpp | 4 ++-- 10 files changed, 13 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 867ef62515..29c0a7ae6a 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -126,12 +126,11 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Allocate a qubit register * @param size Number of qubits (must be positive) - * @param name Register name (default: "q") * @return Vector of qubit references * * @par Example: * ```c++ - * auto q = builder.allocQubitRegister(3, "q"); + * auto q = builder.allocQubitRegister(3); * ``` * ```mlir * %q0 = qc.alloc("q", 3, 0) : !qc.qubit @@ -139,8 +138,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * %q2 = qc.alloc("q", 3, 2) : !qc.qubit * ``` */ - llvm::SmallVector allocQubitRegister(int64_t size, - const std::string& name = "q"); + llvm::SmallVector allocQubitRegister(int64_t size); /** * @brief A small structure representing a single classical bit within a diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 553b07a584..7ecae44229 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -135,12 +135,11 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Allocate a qubit register * @param size Number of qubits (must be positive) - * @param name Register name (default: "q") * @return Vector of tracked, valid qubit SSA values * * @par Example: * ```c++ - * auto q = builder.allocQubitRegister(3, "q"); + * auto q = builder.allocQubitRegister(3); * ``` * ```mlir * %q0 = qco.alloc("q", 3, 0) : !qco.qubit @@ -148,8 +147,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * %q2 = qco.alloc("q", 3, 2) : !qco.qubit * ``` */ - llvm::SmallVector allocQubitRegister(int64_t size, - const std::string& name = "q"); + llvm::SmallVector allocQubitRegister(int64_t size); /** * @brief A small structure representing a single classical bit within a diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 6514b2255f..c3f9ffbd35 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index bedddc8648..dbee173ac8 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 2665f2fae9..4c878f598e 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -440,7 +440,7 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult - matchAndRewrite(StaticOp op, OpAdaptor adaptor, + matchAndRewrite(StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { const auto index = static_cast(op.getIndex()); auto& state = getState(); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index ccf83b8a1e..7f90ed392b 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -95,8 +95,7 @@ Value QCProgramBuilder::staticQubit(const int64_t index) { } llvm::SmallVector -QCProgramBuilder::allocQubitRegister(const int64_t size, - const std::string& name) { +QCProgramBuilder::allocQubitRegister(const int64_t size) { checkFinalized(); if (size <= 0) { diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index ee94512d5b..f250201cc4 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -97,8 +97,8 @@ allocateQregs(QCProgramBuilder& builder, // Allocate quantum registers using the builder SmallVector qregs; for (const auto* qregPtr : qregPtrs) { - auto qubits = builder.allocQubitRegister( - static_cast(qregPtr->getSize()), qregPtr->getName()); + auto qubits = + builder.allocQubitRegister(static_cast(qregPtr->getSize())); qregs.emplace_back(qregPtr, std::move(qubits)); } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 4998d3804a..bcfef08784 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -101,8 +101,7 @@ Value QCOProgramBuilder::staticQubit(const int64_t index) { } llvm::SmallVector -QCOProgramBuilder::allocQubitRegister(const int64_t size, - const std::string& name) { +QCOProgramBuilder::allocQubitRegister(const int64_t size) { checkFinalized(); if (size <= 0) { diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 0baec89da8..517c15ed58 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -23,8 +23,8 @@ void allocQubit(QCProgramBuilder& b) { b.allocQubit(); } void allocQubitRegister(QCProgramBuilder& b) { b.allocQubitRegister(2); } void allocMultipleQubitRegisters(QCProgramBuilder& b) { - b.allocQubitRegister(2, "reg0"); - b.allocQubitRegister(3, "reg1"); + b.allocQubitRegister(2); + b.allocQubitRegister(3); } void allocLargeRegister(QCProgramBuilder& b) { b.allocQubitRegister(100); } diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index fd681c8d11..cf119f62c3 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -27,8 +27,8 @@ void allocQubit(QCOProgramBuilder& b) { b.allocQubit(); } void allocQubitRegister(QCOProgramBuilder& b) { b.allocQubitRegister(2); } void allocMultipleQubitRegisters(QCOProgramBuilder& b) { - b.allocQubitRegister(2, "reg0"); - b.allocQubitRegister(3, "reg1"); + b.allocQubitRegister(2); + b.allocQubitRegister(3); } void allocLargeRegister(QCOProgramBuilder& b) { b.allocQubitRegister(100); } From 4622c9564fe6f98ed9af8dc42fcb15ed2ea7fa3f Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:06:39 +0200 Subject: [PATCH 14/71] Resolve some TODOs --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 62 +++++++++---------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 1 + .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 2 +- .../Dialect/QTensor/IR/Operations/AllocOp.cpp | 23 ++++--- .../QTensor/IR/Operations/InsertOp.cpp | 7 +-- 5 files changed, 42 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index dbee173ac8..34f44ea05f 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -158,6 +159,10 @@ struct ConvertMemRefAllocOp final LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (!llvm::isa(op.getType().getElementType())) { + return success(); + } + auto& qtensorMap = getState().qtensorMap; auto shape = op.getType().getShape(); @@ -188,6 +193,10 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (!llvm::isa(op.getMemref().getType().getElementType())) { + return success(); + } + auto& qubitMap = getState().qubitMap; auto& qubitInfos = getState().qubitInfos; auto& qtensorMap = getState().qtensorMap; @@ -213,38 +222,6 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { } }; -// struct ConvertMemRefStoreOp final -// : StatefulOpConversionPattern { -// using StatefulOpConversionPattern::StatefulOpConversionPattern; - -// LogicalResult -// matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, -// ConversionPatternRewriter& rewriter) const override { -// auto& qubitMap = getState().qubitMap; -// auto& qtensorMap = getState().qtensorMap; - -// // Look up latest QCO value for this QC qubit -// auto qcQubit = op.getValue(); -// assert(qubitMap.contains(qcQubit) && "QC qubit not found"); -// auto qcoQubit = qubitMap[qcQubit]; - -// // Look up latest QTensor value for this QC register -// auto memref = op.getMemref(); -// assert(qtensorMap.contains(memref) && "QC register not found"); -// auto qtensor = qtensorMap[memref]; - -// auto store = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, -// qtensor, adaptor.getIndices()[0]); - -// qubitMap.erase(qcQubit); -// qtensorMap[memref] = store.getResult(); - -// rewriter.eraseOp(op); - -// return success(); -// } -// }; - struct ConvertMemRefDeallocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -252,6 +229,10 @@ struct ConvertMemRefDeallocOp final LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + if (!llvm::isa(op.getMemref().getType().getElementType())) { + return success(); + } + auto& qubitMap = getState().qubitMap; auto& qubitInfos = getState().qubitInfos; auto& qtensorMap = getState().qtensorMap; @@ -1280,11 +1261,24 @@ struct QCToQCO final : impl::QCToQCOBase { QCToQCOTypeConverter typeConverter(context); // Configure conversion target - // TODO: Do not blanket-illegalize memref - target.addIllegalDialect(); + target.addIllegalDialect(); target.addLegalDialect(); + target.addDynamicallyLegalOp([&](memref::AllocOp op) { + return !llvm::isa(op.getType().getElementType()); + }); + + target.addDynamicallyLegalOp([&](memref::LoadOp op) { + return !llvm::isa( + op.getMemref().getType().getElementType()); + }); + + target.addDynamicallyLegalOp([&](memref::DeallocOp op) { + return !llvm::isa( + op.getMemref().getType().getElementType()); + }); + // Register operation conversion patterns with state tracking patterns.add< ConvertMemRefAllocOp, ConvertMemRefLoadOp, ConvertMemRefDeallocOp, diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 7f90ed392b..dfa0ea3ed6 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -119,6 +119,7 @@ QCProgramBuilder::allocQubitRegister(const int64_t size) { allocatedMemrefs.insert(memref); + // TODO: Return register return qubits; } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index bcfef08784..2316fc23d4 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -119,7 +119,7 @@ QCOProgramBuilder::allocQubitRegister(const int64_t size) { qubits.emplace_back(qubit); } - // TODO: Return qtensor + // TODO: Return register return qubits; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp index d6eb6fc693..898b8b6412 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp @@ -45,18 +45,17 @@ LogicalResult AllocOp::verify() { if (sizeValue && *sizeValue <= 0) { return emitOpError("Constant size operand must be positive"); } - // TODO: Deal with this - // if (sizeValue.has_value() == resultType.isDynamicDim(0)) { - // return emitOpError("Size operand and result type must both be static or " - // "both be dynamic, but got ") - // << (sizeValue ? "static size with dynamic result" - // : "dynamic size with static result"); - // } - // if (sizeValue && resultSize != *sizeValue) { - // return emitOpError("Constant size operand (") - // << *sizeValue << ") does not match static result size (" - // << resultSize << ")"; - // } + if (sizeValue.has_value() == resultType.isDynamicDim(0)) { + return emitOpError("Size operand and result type must both be static or " + "both be dynamic, but got ") + << (sizeValue ? "static size with dynamic result" + : "dynamic size with static result"); + } + if (sizeValue && resultSize != *sizeValue) { + return emitOpError("Constant size operand (") + << *sizeValue << ") does not match static result size (" + << resultSize << ")"; + } return success(); } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index e312d73244..a123f467a8 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -8,7 +8,6 @@ * Licensed under the MIT License */ -#include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include @@ -56,12 +55,8 @@ struct RemoveExtractInsertPair final : OpRewritePattern { return failure(); } - // TODO: Improve this - auto qubit = qco::AllocOp::create(rewriter, op.getLoc()); - rewriter.replaceOp(extractOp, {extractOp.getTensor(), qubit.getResult()}); - qco::DeallocOp::create(rewriter, op.getLoc(), qubit.getResult()); - rewriter.replaceOp(op, op.getDest()); + rewriter.replaceOp(extractOp, {extractOp.getTensor(), nullptr}); return success(); } From 56f57717b173835aefe5f3f4506d5da989f713ae Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:08:31 +0200 Subject: [PATCH 15/71] Improve documentation --- .../Dialect/QC/Builder/QCProgramBuilder.h | 8 +- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 17 +- .../Dialect/QIR/Builder/QIRProgramBuilder.h | 70 ++++---- .../include/mlir/Dialect/QIR/Utils/QIRUtils.h | 2 +- mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp | 52 +++++- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 53 +++++- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 54 ++++-- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 52 +++++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 170 +++++++++--------- 9 files changed, 336 insertions(+), 142 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 29c0a7ae6a..dd346f47be 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -133,9 +133,10 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * auto q = builder.allocQubitRegister(3); * ``` * ```mlir - * %q0 = qc.alloc("q", 3, 0) : !qc.qubit - * %q1 = qc.alloc("q", 3, 1) : !qc.qubit - * %q2 = qc.alloc("q", 3, 2) : !qc.qubit + * %memref = memref.alloc() : memref<3x!qc.qubit> + * %q0 = memref.load %alloc[%c0] : memref<3x!qc.qubit> + * %q1 = memref.load %alloc[%c1] : memref<3x!qc.qubit> + * %q2 = memref.load %alloc[%c2] : memref<3x!qc.qubit> * ``` */ llvm::SmallVector allocQubitRegister(int64_t size); @@ -940,6 +941,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /// Track allocated qubits for automatic deallocation llvm::DenseSet allocatedQubits; + /// Track allocated MemRefs for automatic deallocation llvm::DenseSet allocatedMemrefs; /// Check if the builder has been finalized diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 7ecae44229..0211f914f7 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -142,9 +142,10 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * auto q = builder.allocQubitRegister(3); * ``` * ```mlir - * %q0 = qco.alloc("q", 3, 0) : !qco.qubit - * %q1 = qco.alloc("q", 3, 1) : !qco.qubit - * %q2 = qco.alloc("q", 3, 2) : !qco.qubit + * %t0 = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * %t1, %q0 = qtensor.extract %t0[%c0]: tensor<3x!qco.qubit> + * %t2, %q1 = qtensor.extract %t1[%c1]: tensor<3x!qco.qubit> + * %t3, %q2 = qtensor.extract %t2[%c2]: tensor<3x!qco.qubit> * ``` */ llvm::SmallVector allocQubitRegister(int64_t size); @@ -1345,10 +1346,16 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateQubitTracking(Value inputQubit, Value outputQubit); + /// Count unique tensors int64_t tensorCounter = 0; + /** + * @brief Information about a qubit + */ struct QubitInfo { + /// ID of the register the qubit belongs to int64_t regId = -1; + /// Index of the qubit within its register int64_t regIndex = -1; }; @@ -1374,7 +1381,11 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void updateTensorTracking(Value inputTensor, Value outputTensor); + /** + * @brief Information about a tensor + */ struct TensorInfo { + /// ID of the register the tensor corresponds to int64_t regId = -1; }; diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index 12ef6f64bd..634ab64133 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -233,10 +233,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Measure a qubit and record the result (simple version) * * @details - * Performs a Z-basis measurement using __quantum__qis__mz__body. The - * result is tracked for deferred output recording in the output block. - * This version does NOT include register information, so output will - * not be grouped by register. + * Performs a Z-basis measurement using `__quantum__qis__mz__body`. + * + * The output is recorded via `__quantum__rt__result_record_output` during + * `finalize()`. * * @param qubit The qubit to measure * @param resultIndex The classical bit index for result pointer @@ -247,12 +247,17 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * auto result = builder.measure(q0, 0); * ``` * ```mlir + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %r = llvm.call @"@__quantum__rt__result_allocate"(%zero) : !llvm.ptr -> + * !llvm.ptr + * * // In measurements block: - * %c0 = llvm.mlir.constant(0 : i64) : i64 - * %r = llvm.inttoptr %c0 : i64 to !llvm.ptr * llvm.call @__quantum__qis__mz__body(%q0, %r) : (!llvm.ptr, !llvm.ptr) -> () * - * // Output recording deferred to output block + * // In output block: + * llvm.call @__quantum__rt__result_record_output(%r, %label) : (!llvm.ptr, + * !llvm.ptr) -> () * ``` */ Value measure(Value qubit, int64_t resultIndex); @@ -261,12 +266,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Measure a qubit into a classical register * * @details - * Performs a Z-basis measurement using __quantum__qis__mz__body and tracks - * the measurement with register information for array-based output recording. - * Output recording is deferred to the output block during finalize(), where - * measurements are grouped by register and recorded using: - * 1. __quantum__rt__array_record_output for each register - * 2. __quantum__rt__result_record_output for each measurement in the register + * Performs a Z-basis measurement using `__quantum__qis__mz__body`. + * + * The output is recorded via `__quantum__rt__result_array_record_output` + * during `finalize()`. * * @param qubit The qubit to measure * @param bit The classical bit to store the result @@ -276,21 +279,21 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * ```c++ * auto c = builder.allocClassicalBitRegister(2, "c"); * builder.measure(q0, c[0]); - * builder.measure(q1, c[1]); * ``` * ```mlir + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c2 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__result_array_allocate"(%c2, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * %r = llvm.load %alloca : !llvm.ptr -> !llvm.ptr + * * // In measurements block: - * llvm.call @__quantum__qis__mz__body(%q0, %r0) : (!llvm.ptr, !llvm.ptr) -> - * () llvm.call @__quantum__qis__mz__body(%q1, %r1) : (!llvm.ptr, !llvm.ptr) - * -> () + * llvm.call @__quantum__qis__mz__body(%q, %r) : (!llvm.ptr, !llvm.ptr) -> () * - * // In output block (generated during finalize): - * @0 = internal constant [3 x i8] c"c\00" - * @1 = internal constant [5 x i8] c"c0r\00" - * @2 = internal constant [5 x i8] c"c1r\00" - * llvm.call @__quantum__rt__array_record_output(i64 2, ptr @0) - * llvm.call @__quantum__rt__result_record_output(ptr %r0, ptr @1) - * llvm.call @__quantum__rt__result_record_output(ptr %r1, ptr @2) + * // In output block: + * llvm.call @__quantum__rt__result_array_record_output(%c2, %alloca, %label) + * : (i64, !llvm.ptr, !llvm.ptr) -> () * ``` */ QIRProgramBuilder& measure(Value qubit, const Bit& bit); @@ -309,7 +312,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * builder.reset(q); * ``` * ```mlir - * llvm.call @__quantum__qis__reset__body(%q) : (!llvm.ptr) -> () + * llvm.call @__quantum__qis__reset__body(%q) : !llvm.ptr -> () * ``` */ QIRProgramBuilder& reset(Value qubit); @@ -350,7 +353,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * builder.OP_NAME(q); \ * ``` \ * ```mlir \ - * llvm.call @__quantum__qis__##QIR_NAME##__body(%q) : (!llvm.ptr) -> () \ + * llvm.call @__quantum__qis__##QIR_NAME##__body(%q) : !llvm.ptr -> () \ * ``` \ */ \ QIRProgramBuilder& OP_NAME(Value qubit); \ @@ -840,11 +843,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Finalize the program and return the constructed module * * @details - * Automatically deallocates all remaining allocated qubits, generates - * array-based output recording in the output block (grouped by register), - * ensures proper QIR metadata attributes are set, and transfers ownership - * of the module to the caller. The builder should not be used after calling - * this method. + * Automatically deallocates all remaining allocated qubits, generates output + * recording in the output block, ensures proper QIR metadata attributes are + * set, and transfers ownership of the module to the caller. The builder + * should not be used after calling this method. * * @return OwningOpRef containing the constructed QIR program module */ @@ -927,10 +929,8 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Generate array-based output recording in the output block * * @details - * Called by finalize() to generate output recording calls for all tracked - * measurements. Groups measurements by register and generates: - * 1. array_record_output for each register - * 2. result_record_output for each measurement in the register + * Called by `finalize()` to generate output recording calls for all tracked + * measurements. */ void generateOutputRecording(); diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h index 0bc326fcbb..a3f0fbb0c8 100644 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h +++ b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h @@ -180,7 +180,7 @@ LLVM::LLVMFuncOp getMainFunction(Operation* op); * - `required_num_qubits`: Number of qubits used * - `required_num_results`: Number of measurement results * - `qir_major_version`: 2 - * - `qir_minor_version`: 0 + * - `qir_minor_version`: 1 * - `dynamic_qubit_management`: true/false * - `dynamic_result_management`: true/false * diff --git a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp index ec1d7541a2..6e2476b244 100644 --- a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp +++ b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp @@ -383,6 +383,18 @@ static LogicalResult cleanUp(Operation* op) { namespace { +/** + * @brief Converts jeff.qureg_alloc to qtensor.alloc + * + * @par Example: + * ```mlir + * %qureg = jeff.qureg_alloc(%c3) : !jeff.qureg + * ``` + * is converted to + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + */ struct ConvertJeffQuregAllocOpToQCO final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -397,6 +409,19 @@ struct ConvertJeffQuregAllocOpToQCO final } }; +/** + * @brief Converts jeff.qureg_extract_index to qtensor.extract + * + * @par Example: + * ```mlir + * %qureg_out, %q = jeff.qureg_extract_index(%c0) %qureg_in : !jeff.qureg, + * !jeff.qubit + * ``` + * is converted to + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + */ struct ConvertJeffQuregExtractIndexOpToQCO final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -412,6 +437,18 @@ struct ConvertJeffQuregExtractIndexOpToQCO final } }; +/** + * @brief Converts jeff.qureg_insert_index to qtensor.insert + * + * @par Example: + * ```mlir + * %qureg_out = jeff.qureg_insert_index(%c0) %qureg_in %q : !jeff.qureg + * ``` + * is converted to + * ```mlir + * %tensor_out = qtensor.insert %q into %tensor_in[%c0] : tensor<3x!qco.qubit> + * ``` + */ struct ConvertJeffQuregInsertIndexOpToQCO final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -429,6 +466,18 @@ struct ConvertJeffQuregInsertIndexOpToQCO final } }; +/** + * @brief Converts jeff.qureg_free_zero to qtensor.dealloc + * + * @par Example: + * ```mlir + * jeff.qureg_free_zero %qureg : !jeff.qureg + * ``` + * is converted to + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + */ struct ConvertJeffQuregFreeZeroOpToQCO final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -953,7 +1002,8 @@ struct ConvertJeffMainToQCO final : OpConversionPattern { * @brief Type converter for Jeff-to-QCO conversion * * @details - * Converts `!jeff.qubit` to `!qco.qubit`. + * Converts `!jeff.qubit` to `!qco.qubit` and `!jeff.qureg` to + * `!tensor`. */ class JeffToQCOTypeConverter final : public TypeConverter { public: diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 566b201e44..802a713e18 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -247,6 +247,18 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) { namespace { +/** + * @brief Converts qtensor.alloc to jeff.qureg_alloc + * + * @par Example: + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg = jeff.qureg_alloc(%c3) : !jeff.qureg + * ``` + */ struct ConvertQTensorAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -267,6 +279,19 @@ struct ConvertQTensorAllocOp final } }; +/** + * @brief Converts qtensor.extract to jeff.qureg_extract_index + * + * @par Example: + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg_out, %q = jeff.qureg_extract_index(%c0) %qureg_in : !jeff.qureg, + * !jeff.qubit + * ``` + */ struct ConvertQTensorExtractOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -288,6 +313,18 @@ struct ConvertQTensorExtractOp final } }; +/** + * @brief Converts qtensor.insert to jeff.qureg_insert_index + * + * @par Example: + * ```mlir + * %tensor_out = qtensor.insert %q into %tensor_in[%c0] : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %qureg_out = jeff.qureg_insert_index(%c0) %qureg_in %q : !jeff.qureg + * ``` + */ struct ConvertQTensorInsertOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -309,6 +346,18 @@ struct ConvertQTensorInsertOp final } }; +/** + * @brief Converts qtensor.dealloc to jeff.qureg_free_zero + * + * @par Example: + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * jeff.qureg_free_zero %qureg : !jeff.qureg + * ``` + */ struct ConvertQTensorDeallocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1398,8 +1447,8 @@ struct ConvertQCOMainToJeff final : StatefulOpConversionPattern { * @brief Type converter for QCO-to-Jeff conversion * * @details - * Converts `!qco.qubit` to `!jeff.qubit` and tensor to - * tensor. + * Converts `!qco.qubit` to `!jeff.qubit` and `tensor` to + * `!jeff.qureg`. */ class QCOToJeffTypeConverter final : public TypeConverter { public: diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index c3f9ffbd35..ac57742692 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -76,6 +76,18 @@ class QCOToQCTypeConverter final : public TypeConverter { } }; +/** + * @brief Converts qtensor.alloc to memref.alloc + * + * @par Example: + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %memref = memref.alloc(%c3) : memref<3x!qc.qubit> + * ``` + */ struct ConvertQTensorAllocOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -89,6 +101,18 @@ struct ConvertQTensorAllocOp final : OpConversionPattern { } }; +/** + * @brief Converts qtensor.extract to memref.load + * + * @par Example: + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + */ struct ConvertQTensorExtractOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -102,21 +126,9 @@ struct ConvertQTensorExtractOp final : OpConversionPattern { } }; -// struct ConvertQTensorInsertOp final : OpConversionPattern -// { -// using OpConversionPattern::OpConversionPattern; - -// LogicalResult -// matchAndRewrite(qtensor::InsertOp op, OpAdaptor adaptor, -// ConversionPatternRewriter& rewriter) const override { -// auto store = -// memref::StoreOp::create(rewriter, op.getLoc(), adaptor.getScalar(), -// adaptor.getDest(), adaptor.getIndex()); -// rewriter.replaceOp(op, adaptor.getDest()); -// return success(); -// } -// }; - +/** + * @brief Removes qtensor.insert operations + */ struct ConvertQTensorInsertOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -128,6 +140,18 @@ struct ConvertQTensorInsertOp final : OpConversionPattern { } }; +/** + * @brief Converts qtensor.dealloc to memref.dealloc + * + * @par Example: + * ```mlir + * qtensor.dealloc %tensor : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + */ struct ConvertQTensorDeallocOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 34f44ea05f..0ff3ad4e6c 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -51,8 +51,13 @@ using namespace qc; namespace { +/** + * @brief Information about a qubit + */ struct QubitInfo { + /// Register the qubit belongs to Value reg; + /// Index of the qubit within its register Value index; }; @@ -85,11 +90,13 @@ struct QubitInfo { * - %q2 after the X gate */ struct LoweringState { - /// Map from original QC qubit references to their latest QCO SSA values + /// Map from original QC qubit reference to its latest QCO SSA value llvm::DenseMap qubitMap; + /// Map from original MemRef to its latest QTensor SSA value llvm::DenseMap qtensorMap; + /// Map from original QC qubit reference to its register information llvm::DenseMap qubitInfos; /// Modifier information @@ -152,6 +159,18 @@ class QCToQCOTypeConverter final : public TypeConverter { } }; +/** + * @brief Converts memref.alloc to qtensor.alloc + * + * @par Example: + * ```mlir + * %memref = memref.alloc(%c3) : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %tensor = qtensor.alloc(%c3) : tensor<3x!qco.qubit> + * ``` + */ struct ConvertMemRefAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -187,6 +206,18 @@ struct ConvertMemRefAllocOp final } }; +/** + * @brief Converts memref.load to qtensor.extract + * + * @par Example: + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %tensor_out, %q = qtensor.extract %tensor_in[%c0]: tensor<3x!qco.qubit> + * ``` + */ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -222,6 +253,25 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { } }; +/** + * @brief Converts memref.dealloc to qtensor.dealloc + * + * @details + * Before deallocating the tensor, all qubits are inserted back into it at their + * original location. + * + * @par Example: + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %t1 = qtensor.insert %q0 into %t0[%c0] : tensor<3x!qco.qubit> + * %t2 = qtensor.insert %q1 into %t1[%c1] : tensor<3x!qco.qubit> + * %t3 = qtensor.insert %q2 into %t2[%c2] : tensor<3x!qco.qubit> + * qtensor.dealloc %t3 : tensor<3x!qco.qubit> + * ``` + */ struct ConvertMemRefDeallocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 4c878f598e..2f490c20cb 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -65,14 +65,6 @@ namespace { /** * @brief State object for tracking lowering information during QIR conversion - * - * @details - * This struct maintains state during the conversion of QC dialect - * operations to QIR (Quantum Intermediate Representation). It tracks: - * - Qubit and result counts for QIR metadata - * - Pointer value caching for reuse - * - Whether dynamic memory management is needed - * - Sequence of measurements for output recording */ struct LoweringState : QIRMetadata { /// Map from index to qubit pointer @@ -91,7 +83,7 @@ struct LoweringState : QIRMetadata { int64_t inCtrlOp = 0; DenseMap> controls; - // Block information + /// Block information Block* entryBlock{}; Block* measurementsBlock{}; }; @@ -203,6 +195,7 @@ namespace { * * Type conversions: * - `!qc.qubit` -> `!llvm.ptr` (opaque pointer to qubit in QIR) + * - `memref` -> `!llvm.ptr` (opaque pointer to array in QIR) */ struct QCToQIRTypeConverter final : LLVMTypeConverter { explicit QCToQIRTypeConverter(MLIRContext* ctx) : LLVMTypeConverter(ctx) { @@ -215,6 +208,21 @@ struct QCToQIRTypeConverter final : LLVMTypeConverter { } }; +/** + * @brief Converts memref.alloc to QIR qubit-array allocation + * + * @par Example: + * ```mlir + * %memref = memref.alloc() : memref<3x!qc.qubit> + * ``` + * becomes: + * ```mlir + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c3 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__qubit_array_allocate"(%c3, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * ``` + */ struct ConvertMemRefAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -261,6 +269,19 @@ struct ConvertMemRefAllocOp final } }; +/** + * @brief Converts memref.load to llvm.load + * + * @par Example: + * ```mlir + * %q = memref.load %memref[%c0] : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %ptr = llvm.getelementptr %alloca[1] : !llvm.ptr -> !llvm.ptr, !llvm.ptr + * %q = llvm.load %ptr : !llvm.ptr -> !llvm.ptr + * ``` + */ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -283,6 +304,19 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { } }; +/** + * @brief Converts memref.dealloc to QIR qubit-array release + * + * @par Example: + * ```mlir + * memref.dealloc %memref : memref<3x!qc.qubit> + * ``` + * becomes: + * ```mlir + * llvm.call @"@__quantum__rt__qubit_array_release"(%c3, %alloca) : (i64, + * !llvm.ptr) -> () + * ``` + */ struct ConvertMemRefDeallocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -331,15 +365,7 @@ struct ConvertMemRefDeallocOp final }; /** - * @brief Converts qc.alloc operation to static QIR qubit allocations - * - * @details - * QIR 2.0 does not support dynamic qubit allocation. Therefore, qc.alloc - * operations are converted to static qubit references using inttoptr with a - * constant index. - * - * Register metadata (register_name, register_size, register_index) is used to - * provide a reasonable guess for a static qubit index that is still free. + * @brief Converts qc.alloc to QIR qubit allocation * * @par Example: * ```mlir @@ -347,8 +373,9 @@ struct ConvertMemRefDeallocOp final * ``` * becomes: * ```mlir - * %c0 = llvm.mlir.constant(0 : i64) : i64 - * %q0 = llvm.inttoptr %c0 : i64 to !llvm.ptr + * %zero = llvm.mlir.zero : !llvm.ptr + * %q = llvm.call @"@__quantum__rt__qubit_allocate"(%zero) : !llvm.ptr -> + * !llvm.ptr * ``` */ struct ConvertQCAllocOp final : StatefulOpConversionPattern { @@ -375,12 +402,7 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { }; /** - * @brief Erases qc.dealloc operations - * - * @details - * Since QIR 2.0 does not support dynamic qubit allocation, dynamic - * allocations are converted to static allocations. Therefore, deallocation - * operations become no-ops and are simply removed from the IR. + * @brief Converts qc.dealloc to QIR qubit release * * @par Example: * ```mlir @@ -388,7 +410,7 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * // (removed) + * llvm.call @"@__quantum__rt__qubit_release"(%q) : !llvm.ptr -> () * ``` */ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { @@ -419,7 +441,7 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { }; /** - * @brief Converts qc.static operation to QIR inttoptr + * @brief Converts qc.static to llvm.inttoptr * * @details * Converts a static qubit reference to an LLVM pointer by creating a constant @@ -467,18 +489,14 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { }; /** - * @brief Converts qc.measure operation to QIR measurement + * @brief Converts qc.measure to QIR measurement * * @details - * Converts qubit measurement to a QIR call to `__quantum__qis__mz__body`. - * Unlike the previous implementation, this does NOT immediately record - * output. Instead, it tracks measurements in the lowering state for deferred - * output recording in a separate output block, as required by the QIR Base - * Profile. + * For measurements with register information, a result array is allocated and + * all result pointers are loaded. * - * For measurements with register information, the result pointer is mapped - * to (register_name, register_index) for later retrieval. For measurements - * without register information, a sequential result pointer is assigned. + * For measurements without register information, an individual result pointer + * is allocated. * * @par Example (with register): * ```mlir @@ -486,11 +504,15 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * %c0_i64 = llvm.mlir.constant(0 : i64) : i64 - * %result_ptr = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr - * llvm.call @__quantum__qis__mz__body(%q, %result_ptr) : (!llvm.ptr, - * !llvm.ptr) - * -> () + * // In entry block: + * %zero = llvm.mlir.zero : !llvm.ptr + * %alloca = llvm.alloca %c2 x !llvm.ptr : (i64) -> !llvm.ptr + * llvm.call @"@__quantum__rt__result_array_allocate"(%c2, %alloca, %zero) : + * (i64, !llvm.ptr, !llvm.ptr) -> () + * %r = llvm.load %alloca : !llvm.ptr -> !llvm.ptr + * + * // In measurements block: + * llvm.call @__quantum__qis__mz__body(%q, %r) : (!llvm.ptr, !llvm.ptr) -> () * ``` */ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { @@ -601,7 +623,7 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { * ``` * becomes: * ```mlir - * llvm.call @__quantum__qis__reset__body(%q) : (!llvm.ptr) -> () + * llvm.call @__quantum__qis__reset__body(%q) : !llvm.ptr -> () * ``` */ struct ConvertQCResetOp final : StatefulOpConversionPattern { @@ -673,7 +695,7 @@ struct ConvertQCGPhaseOp final : StatefulOpConversionPattern { * ``` \ * is converted to \ * ```mlir \ - * llvm.call @__quantum__qis__QIR_NAME__body(%q) : (!llvm.ptr) -> () \ + * llvm.call @__quantum__qis__QIR_NAME__body(%q) : !llvm.ptr -> () \ * ``` \ */ \ struct ConvertQC##OP_CLASS final : StatefulOpConversionPattern { \ @@ -1006,11 +1028,12 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { * * Conversion stages: * 1. Convert func dialect to LLVM - * 2. Ensure proper block structure for QIR base profile and add initialization + * 2. Ensure proper block structure for QIR base profile * 3. Convert QC operations to QIR calls - * 4. Set QIR metadata attributes - * 5. Convert arith and cf dialects to LLVM - * 6. Reconcile unrealized casts + * 4. Add QIR initialization call + * 5. Set QIR metadata attributes + * 6. Convert arith and cf dialects to LLVM + * 7. Reconcile unrealized casts * * @pre * The input entry function must consist of a single block. The pass will @@ -1139,29 +1162,12 @@ struct QCToQIR final : impl::QCToQIRBase { * measurements tracked during conversion. Follows the QIR Base Profile * specification for labeled output schema. * - * For each classical register, creates: - * 1. An array_record_output call with the register size and label - * 2. Individual result_record_output calls for each measurement in the - * register - * - * Labels follow the format: "{registerName}{resultIndex}r" - * - registerName: Name of the classical register (e.g., "c") - * - resultIndex: Index within the array - * - 'r' suffix: Indicates this is a result record + * Results that are part of registers are recorded via + * `__quantum__rt__result_array_record_output`. * - * Example output: - * ``` - * @0 = internal constant [3 x i8] c"c\00" - * @1 = internal constant [5 x i8] c"c0r\00" - * @2 = internal constant [5 x i8] c"c1r\00" - * call void @__quantum__rt__array_record_output(i64 2, ptr @0) - * call void @__quantum__rt__result_record_output(ptr %result0, ptr @1) - * call void @__quantum__rt__result_record_output(ptr %result1, ptr @2) - * ``` - * - * Any output recording calls that are not part of registers (i.e., - * measurements without register info) are grouped under a default label "c" - * and recorded similarly. + * Results that are not part of registers (i.e., measurements without register + * info) are grouped under a default `__unnamed__` label recorded via + * `__quantum__rt__result_record_output`. * * @param main The main LLVM function * @param ctx The MLIR context @@ -1241,30 +1247,32 @@ struct QCToQIR final : impl::QCToQIRBase { * @brief Executes the QC to QIR conversion pass * * @details - * Performs the conversion in six stages: + * Performs the conversion in seven stages: * * **Stage 1: Func to LLVM** * Convert func dialect operations (main function) to LLVM dialect * equivalents. * - * **Stage 2: Block structure and initialization** + * **Stage 2: Block structure** * Create proper 4-block structure for QIR base profile (entry, main, - * irreversible, output) and insert the `__quantum__rt__initialize` call in - * the entry block. + * irreversible, output). * * **Stage 3: QC to LLVM** * Convert QC dialect operations to QIR calls and add output recording to the * output block. * - * **Stage 4: QIR attributes** + * **Stage 4: Initialization** + * Insert the `__quantum__rt__initialize` call. + * + * **Stage 5: QIR attributes** * Add QIR base profile metadata to the main function, including qubit/result * counts and version information. * - * **Stage 5: Standard dialects to LLVM** + * **Stage 6: Standard dialects to LLVM** * Convert arith and control flow dialects to LLVM (for index arithmetic and * function control flow). * - * **Stage 6: Reconcile casts** + * **Stage 7: Reconcile casts** * Clean up any unrealized cast operations introduced during type conversion. */ void runOnOperation() override { @@ -1328,13 +1336,13 @@ struct QCToQIR final : impl::QCToQIRBase { addOutputRecording(main, ctx, &state); } - // Stage ?: Insert initialize call + // Stage 4: Insert initialize call addInitialize(main, ctx); - // Stage 4: Set QIR metadata attributes + // Stage 5: Set QIR metadata attributes setQIRAttributes(main, state); - // Stage 5: Convert standard dialects to LLVM + // Stage 6: Convert standard dialects to LLVM { RewritePatternSet stdPatterns(ctx); target.addIllegalDialect(); @@ -1351,7 +1359,7 @@ struct QCToQIR final : impl::QCToQIRBase { } } - // Stage 6: Reconcile unrealized casts + // Stage 7: Reconcile unrealized casts PassManager passManager(ctx); passManager.addPass(createReconcileUnrealizedCastsPass()); if (passManager.run(moduleOp).failed()) { From f8619ed9b5eda1f38ea1b52c5e7383d32f695135 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:56:31 +0200 Subject: [PATCH 16/71] Address the Rabbit's comments --- .../Dialect/QC/Builder/QCProgramBuilder.h | 6 +-- .../Dialect/QIR/Builder/QIRProgramBuilder.h | 2 + .../include/mlir/Dialect/QIR/Utils/QIRUtils.h | 16 +++--- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 1 + mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 3 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 50 +++++++++++-------- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 25 +++++++++- .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 3 +- .../QTensor/IR/Operations/InsertOp.cpp | 21 ++++++-- .../Conversion/JeffRoundTrip/CMakeLists.txt | 3 ++ 10 files changed, 87 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index dd346f47be..92329f1674 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -134,9 +134,9 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * ``` * ```mlir * %memref = memref.alloc() : memref<3x!qc.qubit> - * %q0 = memref.load %alloc[%c0] : memref<3x!qc.qubit> - * %q1 = memref.load %alloc[%c1] : memref<3x!qc.qubit> - * %q2 = memref.load %alloc[%c2] : memref<3x!qc.qubit> + * %q0 = memref.load %memref[%c0] : memref<3x!qc.qubit> + * %q1 = memref.load %memref[%c1] : memref<3x!qc.qubit> + * %q2 = memref.load %memref[%c2] : memref<3x!qc.qubit> * ``` */ llvm::SmallVector allocQubitRegister(int64_t size); diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index 634ab64133..9da23be8d7 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -13,8 +13,10 @@ #include "mlir/Dialect/QIR/Utils/QIRMetadata.h" #include +#include #include #include +#include #include #include #include diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h index a3f0fbb0c8..6a28995c83 100644 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h +++ b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h @@ -35,20 +35,20 @@ namespace mlir::qir { // QIR function names inline constexpr auto QIR_QUBIT_ARRAY_ALLOC = - "@__quantum__rt__qubit_array_allocate"; + "__quantum__rt__qubit_array_allocate"; inline constexpr auto QIR_QUBIT_ARRAY_RELEASE = - "@__quantum__rt__qubit_array_release"; + "__quantum__rt__qubit_array_release"; -inline constexpr auto QIR_QUBIT_ALLOC = "@__quantum__rt__qubit_allocate"; -inline constexpr auto QIR_QUBIT_RELEASE = "@__quantum__rt__qubit_release"; +inline constexpr auto QIR_QUBIT_ALLOC = "__quantum__rt__qubit_allocate"; +inline constexpr auto QIR_QUBIT_RELEASE = "__quantum__rt__qubit_release"; inline constexpr auto QIR_RESULT_ARRAY_ALLOC = - "@__quantum__rt__result_array_allocate"; + "__quantum__rt__result_array_allocate"; inline constexpr auto QIR_RESULT_ARRAY_RELEASE = - "@__quantum__rt__result_array_release"; + "__quantum__rt__result_array_release"; -inline constexpr auto QIR_RESULT_ALLOC = "@__quantum__rt__result_allocate"; -inline constexpr auto QIR_RESULT_RELEASE = "@__quantum__rt__result_release"; +inline constexpr auto QIR_RESULT_ALLOC = "__quantum__rt__result_allocate"; +inline constexpr auto QIR_RESULT_RELEASE = "__quantum__rt__result_release"; inline constexpr auto QIR_INITIALIZE = "__quantum__rt__initialize"; inline constexpr auto QIR_MEASURE = "__quantum__qis__mz__body"; diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 0ff3ad4e6c..bb496f4f6a 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -320,6 +320,7 @@ struct ConvertMemRefDeallocOp final auto insert = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, qtensor, index); qtensor = insert.getResult(); + qubitMap.erase(qcQubit); qubitInfos.erase(qcQubit); } diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 2f490c20cb..f71357a4ca 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -579,8 +579,7 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { result = loadedResults.at({registerName, registerIndex}); } else { - auto fnSig = - LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, QIR_RESULT_ALLOC, fnSig); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 2316fc23d4..a101c3c58b 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -75,7 +75,7 @@ Value QCOProgramBuilder::allocQubit() { checkFinalized(); auto allocOp = AllocOp::create(*this); - const auto qubit = allocOp.getResult(); + auto qubit = allocOp.getResult(); // Track the allocated qubit as valid validQubits.insert({qubit, {}}); @@ -92,7 +92,7 @@ Value QCOProgramBuilder::staticQubit(const int64_t index) { auto indexAttr = getI64IntegerAttr(index); auto staticOp = StaticOp::create(*this, indexAttr); - const auto qubit = staticOp.getQubit(); + auto qubit = staticOp.getQubit(); // Track the static qubit as valid validQubits.insert({qubit, {}}); @@ -245,8 +245,10 @@ std::pair QCOProgramBuilder::qtensorExtract(Value tensor, auto qubit = extractOp.getResult(); auto outTensor = extractOp.getOutTensor(); - validQubits.insert( - {qubit, {.regId = validTensors[tensor].regId, .regIndex = index}}); + validateTensorValue(tensor); + const auto regId = validTensors[tensor].regId; + + validQubits.insert({qubit, {.regId = regId, .regIndex = index}}); updateTensorTracking(tensor, outTensor); return {outTensor, qubit}; @@ -282,6 +284,7 @@ Value QCOProgramBuilder::qtensorInsert( validateQubitValue(scalar); validQubits.erase(scalar); updateTensorTracking(tensor, outTensor); + return outTensor; } @@ -340,7 +343,7 @@ Value QCOProgramBuilder::measure(Value qubit, const Bit& bit) { auto indexAttr = getI64IntegerAttr(bit.registerIndex); auto measureOp = MeasureOp::create(*this, qubit, nameAttr, sizeAttr, indexAttr); - const auto qubitOut = measureOp.getQubitOut(); + auto qubitOut = measureOp.getQubitOut(); // Update tracking updateQubitTracking(qubit, qubitOut); @@ -352,7 +355,7 @@ Value QCOProgramBuilder::reset(Value qubit) { checkFinalized(); auto resetOp = ResetOp::create(*this, qubit); - const auto qubitOut = resetOp.getQubitOut(); + auto qubitOut = resetOp.getQubitOut(); // Update tracking updateQubitTracking(qubit, qubitOut); @@ -408,7 +411,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) Value QCOProgramBuilder::OP_NAME(Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -453,7 +456,7 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -496,7 +499,7 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM1, PARAM2); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -543,7 +546,7 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) Value qubit) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit, PARAM1, PARAM2, PARAM3); \ - const auto& qubitOut = op.getQubitOut(); \ + auto qubitOut = op.getQubitOut(); \ updateQubitTracking(qubit, qubitOut); \ return qubitOut; \ } \ @@ -590,8 +593,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -634,8 +637,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) const std::variant&(PARAM), Value qubit0, Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1, PARAM); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -684,8 +687,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) Value qubit1) { \ checkFinalized(); \ auto op = OP_CLASS::create(*this, qubit0, qubit1, PARAM1, PARAM2); \ - const auto& qubit0Out = op.getQubit0Out(); \ - const auto& qubit1Out = op.getQubit1Out(); \ + auto qubit0Out = op.getQubit0Out(); \ + auto qubit1Out = op.getQubit1Out(); \ updateQubitTracking(qubit0, qubit0Out); \ updateQubitTracking(qubit1, qubit1Out); \ return {qubit0Out, qubit1Out}; \ @@ -735,7 +738,7 @@ ValueRange QCOProgramBuilder::barrier(ValueRange qubits) { checkFinalized(); auto op = BarrierOp::create(*this, qubits); - const auto& qubitsOut = op.getQubitsOut(); + auto qubitsOut = op.getQubitsOut(); for (const auto& [inputQubit, outputQubit] : llvm::zip(qubits, qubitsOut)) { updateQubitTracking(inputQubit, outputQubit); } @@ -753,7 +756,7 @@ std::pair QCOProgramBuilder::ctrl( auto ctrlOp = CtrlOp::create(*this, controls, targets); auto& block = ctrlOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (const auto target : targets) { const auto arg = block.addArgument(qubitType, getLoc()); updateQubitTracking(target, arg); @@ -791,8 +794,8 @@ ValueRange QCOProgramBuilder::inv( // Add block arguments for all qubits auto& block = invOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); - for (const auto qubit : qubits) { + auto qubitType = QubitType::get(getContext()); + for (auto qubit : qubits) { const auto arg = block.addArgument(qubitType, getLoc()); updateQubitTracking(qubit, arg); } @@ -962,10 +965,15 @@ OwningOpRef QCOProgramBuilder::finalize() { return a.second < b.second; }; + llvm::DenseSet validTensorIds; + for (const auto& [tensor, info] : validTensors) { + validTensorIds.insert(info.regId); + } + llvm::SmallVector freeQubits; llvm::DenseMap registerQubits; for (auto [qubit, info] : validQubits) { - if (info.regId == -1) { + if (info.regId == -1 || !validTensorIds.contains(info.regId)) { freeQubits.push_back(qubit); } else { registerQubits.insert({qubit, info}); diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 219206cdcd..60d0c4aecb 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include @@ -41,6 +42,21 @@ struct RemoveResetAfterAlloc final : OpRewritePattern { } }; +/** + * @brief Check if a `qtensor.extract` operation ultimately originates from a + * `qtensor.alloc` operation. + */ +static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { + auto* definingOp = extractOp.getTensor().getDefiningOp(); + if (llvm::isa(definingOp)) { + return true; + } + if (llvm::isa(definingOp)) { + return originatesFromAlloc(llvm::cast(definingOp)); + } + return false; +} + /** * @brief Remove reset operations that immediately follow a `qtensor.extract` * operation. @@ -51,8 +67,13 @@ struct RemoveResetAfterExtract final : OpRewritePattern { LogicalResult matchAndRewrite(ResetOp op, PatternRewriter& rewriter) const override { // Check if the predecessor is an ExtractOp - if (auto extractOp = op.getQubitIn().getDefiningOp(); - !extractOp) { + auto extractOp = op.getQubitIn().getDefiningOp(); + if (!extractOp) { + return failure(); + } + + // Check if the tensor originates from an AllocOp + if (!originatesFromAlloc(extractOp)) { return failure(); } diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 0e0ee74358..b2640f548f 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -221,8 +221,7 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { setInsertionPoint(entryBlock->getTerminator()); // Create result pointer - auto fnSig = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(getContext()), {ptrType}); + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); auto fnDec = getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_ALLOC, fnSig); auto zero = LLVM::ZeroOp::create(*this, ptrType); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index a123f467a8..23a9861823 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -22,27 +22,38 @@ using namespace mlir; using namespace mlir::qtensor; -static ExtractOp findExtractOp(InsertOp op) { - +/** + * @brief Find the `qtensor.extract` operation for a given `qtensor.insert` + * operation. + */ +static ExtractOp findExtractOp(InsertOp op, Value index) { auto* definingOp = op.getDest().getDefiningOp(); if (llvm::isa(definingOp)) { return llvm::cast(definingOp); } if (llvm::isa(definingOp)) { auto nestedInsertOp = llvm::cast(definingOp); - return findExtractOp(nestedInsertOp); + if (nestedInsertOp.getIndex() == index) { + return nullptr; + } + return findExtractOp(nestedInsertOp, index); } return nullptr; } namespace { +/** + * @brief Remove matching `qtensor.insert` and `qtensor.extract` pairs. + */ struct RemoveExtractInsertPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter& rewriter) const override { - auto extractOp = findExtractOp(op); + auto index = op.getIndex(); + + auto extractOp = findExtractOp(op, index); if (!extractOp) { return failure(); } @@ -51,7 +62,7 @@ struct RemoveExtractInsertPair final : OpRewritePattern { return failure(); } - if (op.getIndex() != extractOp.getIndex()) { + if (index != extractOp.getIndex()) { return failure(); } diff --git a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt index b60b8f5c91..43c4143b6d 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt +++ b/mlir/unittests/Conversion/JeffRoundTrip/CMakeLists.txt @@ -24,4 +24,7 @@ target_link_libraries( mqt_mlir_configure_unittest_target(${target_name}) +# TODO(https://github.com/unitaryfoundation/jeff/issues/46): Enable this again once static +# information is preserved + # gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) From f52206b963e1268672bf97e367e7c947b8dc6a10 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:17:15 +0200 Subject: [PATCH 17/71] Fix linter errors --- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 1 + .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 31 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index a101c3c58b..4e5343940d 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 60d0c4aecb..5a49cdfe1c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -15,11 +15,27 @@ #include #include #include +#include #include using namespace mlir; using namespace mlir::qco; +/** + * @brief Check if a `qtensor.extract` operation ultimately originates from a + * `qtensor.alloc` operation. + */ +static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { + auto* definingOp = extractOp.getTensor().getDefiningOp(); + if (llvm::isa(definingOp)) { + return true; + } + if (llvm::isa(definingOp)) { + return originatesFromAlloc(llvm::cast(definingOp)); + } + return false; +} + namespace { /** @@ -42,21 +58,6 @@ struct RemoveResetAfterAlloc final : OpRewritePattern { } }; -/** - * @brief Check if a `qtensor.extract` operation ultimately originates from a - * `qtensor.alloc` operation. - */ -static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { - auto* definingOp = extractOp.getTensor().getDefiningOp(); - if (llvm::isa(definingOp)) { - return true; - } - if (llvm::isa(definingOp)) { - return originatesFromAlloc(llvm::cast(definingOp)); - } - return false; -} - /** * @brief Remove reset operations that immediately follow a `qtensor.extract` * operation. From 068de197c932f4d78413d94be7013dd3d1454c08 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:30:06 +0200 Subject: [PATCH 18/71] Address the Rabbit's comments --- .../Dialect/QIR/Builder/QIRProgramBuilder.h | 10 ++-- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 19 +++--- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 58 ++++++++++++++----- .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 42 ++++++++++---- 4 files changed, 85 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index 9da23be8d7..d31d4f0bdb 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -845,10 +845,10 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { * @brief Finalize the program and return the constructed module * * @details - * Automatically deallocates all remaining allocated qubits, generates output - * recording in the output block, ensures proper QIR metadata attributes are - * set, and transfers ownership of the module to the caller. The builder - * should not be used after calling this method. + * Automatically deallocates all remaining allocated qubits and result + * pointers, generates output recording in the output block, ensures proper + * QIR metadata attributes are set, and transfers ownership of the module to + * the caller. The builder should not be used after calling this method. * * @return OwningOpRef containing the constructed QIR program module */ @@ -890,7 +890,7 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { /// Exit code constant (created in entry block, used in output block) Value exitCode; - /// Cache static pointers for reuse + /// Cache static qubit pointers for reuse llvm::DenseMap staticQubits; /// Set of qubit-array pointers diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index bb496f4f6a..ca13514866 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1316,18 +1316,13 @@ struct QCToQCO final : impl::QCToQCOBase { target.addLegalDialect(); - target.addDynamicallyLegalOp([&](memref::AllocOp op) { - return !llvm::isa(op.getType().getElementType()); - }); - - target.addDynamicallyLegalOp([&](memref::LoadOp op) { - return !llvm::isa( - op.getMemref().getType().getElementType()); - }); - - target.addDynamicallyLegalOp([&](memref::DeallocOp op) { - return !llvm::isa( - op.getMemref().getType().getElementType()); + target.addDynamicallyLegalDialect([](Operation* op) { + auto isQubitMemref = [](Type t) { + auto mt = llvm::dyn_cast(t); + return mt && llvm::isa(mt.getElementType()); + }; + return llvm::none_of(op->getOperandTypes(), isQubitMemref) && + llvm::none_of(op->getResultTypes(), isQubitMemref); }); // Register operation conversion patterns with state tracking diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index f71357a4ca..136f73d9dc 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -67,8 +67,11 @@ namespace { * @brief State object for tracking lowering information during QIR conversion */ struct LoweringState : QIRMetadata { - /// Map from index to qubit pointer - DenseMap ptrMap; + /// Cache static qubit pointers for reuse + DenseMap staticQubits; + + /// Cache MemRef sizes for reuse + DenseMap memrefSizes; /// Map from register name to result-array pointer llvm::StringMap resultArrays; @@ -256,6 +259,7 @@ struct ConvertMemRefAllocOp final rewriter.getI64IntegerAttr(static_cast(shape[0]))) .getResult(); } + state.memrefSizes.try_emplace(op.getMemref(), size); auto array = LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrType, ptrType, size); @@ -339,22 +343,13 @@ struct ConvertMemRefDeallocOp final // Switch to measurements block rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); - Value size; - if (shape[0] == ShapedType::kDynamic) { - size = - op.getMemref().getDefiningOp().getDynamicSizes()[0]; - } else { - size = LLVM::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getI64IntegerAttr(static_cast(shape[0]))) - .getResult(); - } - auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {i64Type, ptrType}); auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_ARRAY_RELEASE, fnSig); + auto size = getState().memrefSizes.lookup(op.getMemref()); + // Create the release call LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, ValueRange{size, adaptor.getMemref()}); @@ -469,13 +464,14 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { // Get or create a pointer to the qubit Value qubit; - if (const auto it = state.ptrMap.find(index); it != state.ptrMap.end()) { + if (const auto it = state.staticQubits.find(index); + it != state.staticQubits.end()) { // Reuse existing pointer qubit = it->second; } else { // Create and cache for reuse qubit = createPointerFromIndex(rewriter, op.getLoc(), index); - state.ptrMap.try_emplace(index, qubit); + state.staticQubits.try_emplace(index, qubit); } rewriter.replaceOp(op, qubit); @@ -1241,6 +1237,36 @@ struct QCToQIR final : impl::QCToQIRBase { } } + /** + * @brief Iterates through all result pointers and releases them + */ + static void releaseResults(LLVM::LLVMFuncOp& main, MLIRContext* ctx, + LoweringState* state) { + OpBuilder builder(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); + + // Switch to measurements block + builder.setInsertionPoint(state->measurementsBlock->getTerminator()); + + for (auto& [_, ptr] : state->resultPtrs) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); + auto dec = getOrCreateFunctionDeclaration(builder, main, + QIR_RESULT_RELEASE, sig); + LLVM::CallOp::create(builder, main->getLoc(), dec, ptr); + } + + for (auto& [_, array] : state->resultArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, + {builder.getI64Type(), ptrType}); + auto dec = getOrCreateFunctionDeclaration(builder, main, + QIR_RESULT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(builder, main->getLoc(), dec, + ValueRange{size, array}); + } + } + protected: /** * @brief Executes the QC to QIR conversion pass @@ -1333,6 +1359,8 @@ struct QCToQIR final : impl::QCToQIRBase { } addOutputRecording(main, ctx, &state); + + releaseResults(main, ctx, &state); } // Stage 4: Insert initialize call diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index b2640f548f..1582667414 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -220,15 +220,18 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { // Insert allocations and constants in entry block setInsertionPoint(entryBlock->getTerminator()); - // Create result pointer - auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); - auto fnDec = - getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_ALLOC, fnSig); - auto zero = LLVM::ZeroOp::create(*this, ptrType); - auto result = - LLVM::CallOp::create(*this, fnDec, zero.getResult()).getResult(); - - resultPtrs.try_emplace(resultIndex, result); + // Get or create result pointer + Value result; + if (const auto it = resultPtrs.find(resultIndex); it != resultPtrs.end()) { + result = it->second; + } else { + auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); + auto fnDec = + getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_ALLOC, fnSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + result = LLVM::CallOp::create(*this, fnDec, zero.getResult()).getResult(); + resultPtrs.try_emplace(resultIndex, result); + } // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); @@ -656,10 +659,25 @@ OwningOpRef QIRProgramBuilder::finalize() { for (auto array : qubitArrays) { auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); - auto decl = getOrCreateFunctionDeclaration(*this, module, - QIR_QUBIT_ARRAY_RELEASE, sig); + auto dec = getOrCreateFunctionDeclaration(*this, module, + QIR_QUBIT_ARRAY_RELEASE, sig); + auto size = array.getDefiningOp().getArraySize(); + LLVM::CallOp::create(*this, dec, ValueRange{size, array}); + } + + for (auto& [_, ptr] : resultPtrs) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); + auto dec = + getOrCreateFunctionDeclaration(*this, module, QIR_RESULT_RELEASE, sig); + LLVM::CallOp::create(*this, dec, ptr); + } + + for (auto& [_, array] : resultArrays) { + auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); + auto dec = getOrCreateFunctionDeclaration(*this, module, + QIR_RESULT_ARRAY_RELEASE, sig); auto size = array.getDefiningOp().getArraySize(); - LLVM::CallOp::create(*this, decl, ValueRange{size, array}); + LLVM::CallOp::create(*this, dec, ValueRange{size, array}); } // Generate output recording in the output block From d1fd730cd90c1eec4bd1b24bc54397d1ceb4349b Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:02:52 +0200 Subject: [PATCH 19/71] Fix linter error --- mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 23a9861823..cc51d0d20b 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include From a86274e24b623b13a24ff13143f2fc403684c043 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:14:59 +0200 Subject: [PATCH 20/71] Address the Rabbit's comments --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 7 ++++--- mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp | 12 ++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index ca13514866..934676535e 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -182,13 +182,14 @@ struct ConvertMemRefAllocOp final return success(); } - auto& qtensorMap = getState().qtensorMap; - auto shape = op.getType().getShape(); if (shape.size() != 1) { return failure(); } + auto& qtensorMap = getState().qtensorMap; + auto memref = op.getResult(); + Value qtensor; if (shape[0] == ShapedType::kDynamic) { qtensor = rewriter.replaceOpWithNewOp( @@ -200,7 +201,7 @@ struct ConvertMemRefAllocOp final rewriter.replaceOpWithNewOp(op, size.getResult()); } - qtensorMap.try_emplace(op.getResult(), qtensor); + qtensorMap.try_emplace(memref, qtensor); return success(); } diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 1582667414..05c7db6e11 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -175,6 +175,10 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } + if (resultArrays.contains(name)) { + llvm::reportFatalUsageError("Classical register already exists"); + } + metadata_.useDynamicResult = true; // Save current insertion point @@ -654,7 +658,10 @@ OwningOpRef QIRProgramBuilder::finalize() { auto zero = LLVM::ZeroOp::create(*this, ptrType); LLVM::CallOp::create(*this, initDec, zero.getResult()); - // Insert in output block (before return) + // Generate output recording in output block + generateOutputRecording(); + + // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); for (auto array : qubitArrays) { @@ -680,9 +687,6 @@ OwningOpRef QIRProgramBuilder::finalize() { LLVM::CallOp::create(*this, dec, ValueRange{size, array}); } - // Generate output recording in the output block - generateOutputRecording(); - auto mainFuncOp = llvm::cast(mainFunc); setQIRAttributes(mainFuncOp, metadata_); From 58abb28ca609663c0fe7f67e8c19219d5a5440df Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 31 Mar 2026 15:38:05 +0200 Subject: [PATCH 21/71] Address the Rabbit's comments --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 25 +++++++++++-------- .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 12 +++++---- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 136f73d9dc..2165471c9f 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -89,6 +89,7 @@ struct LoweringState : QIRMetadata { /// Block information Block* entryBlock{}; Block* measurementsBlock{}; + Block* outputBlock{}; }; /** @@ -328,6 +329,7 @@ struct ConvertMemRefDeallocOp final LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); auto* ctx = getContext(); auto i64Type = rewriter.getI64Type(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -340,15 +342,15 @@ struct ConvertMemRefDeallocOp final // Save current insertion point const OpBuilder::InsertionGuard guard(rewriter); - // Switch to measurements block - rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + // Release resources in output block + rewriter.setInsertionPoint(state.outputBlock->getTerminator()); auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {i64Type, ptrType}); auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_ARRAY_RELEASE, fnSig); - auto size = getState().memrefSizes.lookup(op.getMemref()); + auto size = state.memrefSizes.lookup(op.getMemref()); // Create the release call LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, @@ -414,14 +416,15 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); // Save current insertion point const OpBuilder::InsertionGuard guard(rewriter); - // Switch to measurements block - rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + // Release resources in output block + rewriter.setInsertionPoint(state.outputBlock->getTerminator()); auto fnSig = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {ptrType}); @@ -627,13 +630,14 @@ struct ConvertQCResetOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(ResetOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); auto* ctx = getContext(); // Save current insertion point const OpBuilder::InsertionGuard guard(rewriter); // Switch to measurements block - rewriter.setInsertionPoint(getState().measurementsBlock->getTerminator()); + rewriter.setInsertionPoint(state.measurementsBlock->getTerminator()); // Declare QIR function const auto fnSignature = LLVM::LLVMFunctionType::get( @@ -1045,8 +1049,8 @@ struct QCToQIR final : impl::QCToQIRBase { * The QIR base profile requires a specific 4-block structure: * 1. **Entry block**: Contains constant operations and initialization * 2. **Body block**: Contains reversible quantum operations (gates) - * 3. **Measurements block**: Contains irreversible operations (measure, - * reset, dealloc) + * 3. **Measurements block**: Contains irreversible operations (measure and + * reset) * 4. **Output block**: Contains output recording calls * * Blocks are connected with unconditional jumps (entry, body, measurements, @@ -1076,6 +1080,7 @@ struct QCToQIR final : impl::QCToQIRBase { state.entryBlock = entryBlock; state.measurementsBlock = measurementsBlock; + state.outputBlock = outputBlock; auto& bodyBlockOps = bodyBlock->getOperations(); auto& outputBlockOps = outputBlock->getOperations(); @@ -1246,8 +1251,8 @@ struct QCToQIR final : impl::QCToQIRBase { auto ptrType = LLVM::LLVMPointerType::get(ctx); auto voidType = LLVM::LLVMVoidType::get(ctx); - // Switch to measurements block - builder.setInsertionPoint(state->measurementsBlock->getTerminator()); + // Release resources in output block + builder.setInsertionPoint(state->outputBlock->getTerminator()); for (auto& [_, ptr] : state->resultPtrs) { auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 05c7db6e11..0f8a209f2a 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -218,6 +218,8 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { llvm::reportFatalUsageError("Result index must be non-negative"); } + metadata_.useDynamicResult = true; + // Save current insertion point const InsertionGuard guard(*this); @@ -658,11 +660,8 @@ OwningOpRef QIRProgramBuilder::finalize() { auto zero = LLVM::ZeroOp::create(*this, ptrType); LLVM::CallOp::create(*this, initDec, zero.getResult()); - // Generate output recording in output block - generateOutputRecording(); - - // Switch to measurements block - setInsertionPoint(measurementsBlock->getTerminator()); + // Release resources in output block + setInsertionPoint(outputBlock->getTerminator()); for (auto array : qubitArrays) { auto sig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType}); @@ -672,6 +671,9 @@ OwningOpRef QIRProgramBuilder::finalize() { LLVM::CallOp::create(*this, dec, ValueRange{size, array}); } + // Generate output recording in output block + generateOutputRecording(); + for (auto& [_, ptr] : resultPtrs) { auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); auto dec = From ad7f26508b8e9159013fd54f8717998f36deb813 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:58:31 +0200 Subject: [PATCH 22/71] Initialize at top of entry block --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 63 +++++++------------ .../Dialect/QIR/Builder/QIRProgramBuilder.cpp | 16 +++-- 2 files changed, 30 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 2165471c9f..53a90411ee 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -1028,8 +1028,8 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { * Conversion stages: * 1. Convert func dialect to LLVM * 2. Ensure proper block structure for QIR base profile - * 3. Convert QC operations to QIR calls - * 4. Add QIR initialization call + * 3. Add QIR initialization call + * 4. Convert QC operations to QIR calls * 5. Set QIR metadata attributes * 6. Convert arith and cf dialects to LLVM * 7. Reconcile unrealized casts @@ -1116,42 +1116,25 @@ struct QCToQIR final : impl::QCToQIRBase { * @brief Adds QIR initialization call to the entry block * * @details - * Inserts a call to `__quantum__rt__initialize` at the end of the entry block - * (before the jump to main block). This QIR runtime function initializes the - * quantum execution environment and takes a null pointer as argument. + * This QIR runtime function initializes the quantum execution environment. * * @param main The main LLVM function * @param ctx The MLIR context + * @param state The lowering state */ - static void addInitialize(LLVM::LLVMFuncOp& main, MLIRContext* ctx) { - auto moduleOp = main->getParentOfType(); - auto& firstBlock = *(main.getBlocks().begin()); - OpBuilder builder(main.getBody()); + static void addInitialize(LLVM::LLVMFuncOp& main, MLIRContext* ctx, + LoweringState& state) { + OpBuilder builder(ctx); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); - // Create a zero (null) pointer for the initialize call - builder.setInsertionPointToStart(&firstBlock); - auto zeroOp = LLVM::ZeroOp::create(builder, main->getLoc(), - LLVM::LLVMPointerType::get(ctx)); - - // Insert the initialize call before the jump to main block - const auto insertPoint = std::prev(firstBlock.getOperations().end(), 1); - builder.setInsertionPoint(&*insertPoint); - - // Get or create the initialize function declaration - auto* fnDecl = SymbolTable::lookupNearestSymbolFrom( - main, builder.getStringAttr(QIR_INITIALIZE)); - if (fnDecl == nullptr) { - const PatternRewriter::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(moduleOp.getBody()); - auto fnSignature = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(ctx), LLVM::LLVMPointerType::get(ctx)); - fnDecl = LLVM::LLVMFuncOp::create(builder, main->getLoc(), QIR_INITIALIZE, - fnSignature); - } + builder.setInsertionPointToStart(state.entryBlock); - // Create the initialization call - LLVM::CallOp::create(builder, main->getLoc(), - cast(fnDecl), zeroOp->getResult(0)); + auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); + auto initDec = + getOrCreateFunctionDeclaration(builder, main, QIR_INITIALIZE, initSig); + auto zero = LLVM::ZeroOp::create(builder, main->getLoc(), ptrType); + LLVM::CallOp::create(builder, main->getLoc(), initDec, zero.getResult()); } /** @@ -1287,13 +1270,13 @@ struct QCToQIR final : impl::QCToQIRBase { * Create proper 4-block structure for QIR base profile (entry, main, * irreversible, output). * - * **Stage 3: QC to LLVM** + * **Stage 3: Initialization** + * Insert the `__quantum__rt__initialize` call. + * + * **Stage 4: QC to LLVM** * Convert QC dialect operations to QIR calls and add output recording to the * output block. * - * **Stage 4: Initialization** - * Insert the `__quantum__rt__initialize` call. - * * **Stage 5: QIR attributes** * Add QIR base profile metadata to the main function, including qubit/result * counts and version information. @@ -1338,7 +1321,10 @@ struct QCToQIR final : impl::QCToQIRBase { // Stage 2: Create block structure ensureBlocks(main, state); - // Stage 3: Convert QC dialect to LLVM (QIR calls) + // Stage 3: Insert initialize call + addInitialize(main, ctx, state); + + // Stage 4: Convert QC dialect to LLVM (QIR calls) { RewritePatternSet patterns(ctx); target.addIllegalDialect(); @@ -1368,9 +1354,6 @@ struct QCToQIR final : impl::QCToQIRBase { releaseResults(main, ctx, &state); } - // Stage 4: Insert initialize call - addInitialize(main, ctx); - // Stage 5: Set QIR metadata attributes setQIRAttributes(main, state); diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 0f8a209f2a..e59acb4727 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -72,6 +72,12 @@ void QIRProgramBuilder::initialize() { setInsertionPointToStart(entryBlock); exitCode = intConstant(0); + auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); + auto initDec = + getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initSig); + auto zero = LLVM::ZeroOp::create(*this, ptrType); + LLVM::CallOp::create(*this, initDec, zero.getResult()); + // Add unconditional branches between blocks setInsertionPointToEnd(entryBlock); LLVM::BrOp::create(*this, bodyBlock); @@ -649,17 +655,9 @@ void QIRProgramBuilder::generateOutputRecording() { OwningOpRef QIRProgramBuilder::finalize() { checkFinalized(); + // Save current insertion point const InsertionGuard guard(*this); - // Insert initialization at end of entry block - setInsertionPoint(entryBlock->getTerminator()); - - auto initSig = LLVM::LLVMFunctionType::get(voidType, ptrType); - auto initDec = - getOrCreateFunctionDeclaration(*this, module, QIR_INITIALIZE, initSig); - auto zero = LLVM::ZeroOp::create(*this, ptrType); - LLVM::CallOp::create(*this, initDec, zero.getResult()); - // Release resources in output block setInsertionPoint(outputBlock->getTerminator()); From 3e7bbf0ee630bea73bacfae2025b0a6b633b29a2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 11:46:49 +0200 Subject: [PATCH 23/71] Address the Rabbit's comments --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 9 ++++++++- mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp | 13 +++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 53a90411ee..626e943916 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -18,7 +18,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -86,6 +88,10 @@ struct LoweringState : QIRMetadata { int64_t inCtrlOp = 0; DenseMap> controls; + /// Allocator and StringSaver for stable StringRefs + llvm::BumpPtrAllocator allocator; + llvm::StringSaver stringSaver{allocator}; + /// Block information Block* entryBlock{}; Block* measurementsBlock{}; @@ -572,7 +578,8 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { rewriter, op.getLoc(), rewriter.getI64IntegerAttr(i))}); auto load = LLVM::LoadOp::create(rewriter, op.getLoc(), ptrType, gep.getResult()); - loadedResults.try_emplace({registerName, i}, load.getResult()); + loadedResults.try_emplace({state.stringSaver.save(registerName), i}, + load.getResult()); } } diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index e59acb4727..28eb571b16 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -181,6 +181,10 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, llvm::reportFatalUsageError("Size must be positive"); } + if (name.starts_with("__unnamed__")) { + llvm::reportFatalUsageError( + "Classical register names starting with '__unnamed__' are reserved"); + } if (resultArrays.contains(name)) { llvm::reportFatalUsageError("Classical register already exists"); } @@ -260,14 +264,19 @@ Value QIRProgramBuilder::measure(Value qubit, const int64_t resultIndex) { QIRProgramBuilder& QIRProgramBuilder::measure(Value qubit, const Bit& bit) { checkFinalized(); + auto it = loadedResults.find({bit.registerName, bit.registerIndex}); + if (it == loadedResults.end()) { + llvm::reportFatalUsageError( + "Bit does not belong to an allocated classical register"); + } + auto result = it->second; + // Save current insertion point const InsertionGuard guard(*this); // Switch to measurements block setInsertionPoint(measurementsBlock->getTerminator()); - auto result = loadedResults.at({bit.registerName, bit.registerIndex}); - // Create measure call const auto fnSig = LLVM::LLVMFunctionType::get(voidType, {ptrType, ptrType}); auto fnDec = From ffd6426872472495e78d7ffbec73989e2ebc3155 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 11:47:48 +0200 Subject: [PATCH 24/71] Fix linter errors --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 626e943916..ad27f76a34 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -34,13 +34,11 @@ #include #include #include -#include #include #include #include #include #include -#include #include #include #include @@ -51,7 +49,6 @@ #include #include -#include #include #include From f0b1b4d50f2c65d0b0b47502cdfc3f1bbf730602 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:38:22 +0200 Subject: [PATCH 25/71] Make everything work again --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 96 ++++++++++--------- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 10 +- 2 files changed, 55 insertions(+), 51 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index d9b59c9121..843259b6c7 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -100,19 +100,19 @@ struct LoweringState { llvm::DenseMap currentQubits; }; + /// Per-region map from QC qubit references to latest QCO SSA values. + /// + /// @details Keys are `Operation::getParentRegion()` for ops being converted + /// (typically a `func.func` body or a modifier region). + llvm::DenseMap> qubitMap; + /// Map from original MemRef to its latest QTensor SSA value llvm::DenseMap qtensorMap; /// Map from original QC qubit reference to its register information llvm::DenseMap qubitInfos; - /// Per-region map from QC qubit references to latest QCO SSA values. - /// - /// @details Keys are `Operation::getParentRegion()` for ops being converted - /// (typically a `func.func` body, or a `qc.ctrl` / `qc.inv` region). - llvm::DenseMap> qubitMap; - - /// Stack of active modifier regions (`qc.ctrl` / `qc.inv`). + /// Stack of active modifier regions SmallVector modifierFrames; }; @@ -317,7 +317,7 @@ struct ConvertFuncReturnOp final : StatefulOpConversionPattern { } // Deallocate dead qubit values - for (Value qcoQubit : llvm::make_second_range(map)) { + for (auto qcoQubit : llvm::make_second_range(map)) { if (!liveQubits.contains(qcoQubit)) { SinkOp::create(rewriter, op.getLoc(), qcoQubit); } @@ -373,7 +373,7 @@ struct ConvertMemRefAllocOp final matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { if (!llvm::isa(op.getType().getElementType())) { - return success(); + return failure(); } auto shape = op.getType().getShape(); @@ -420,12 +420,13 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { if (!llvm::isa(op.getMemref().getType().getElementType())) { - return success(); + return failure(); } - auto& qubitMap = getState().qubitMap; - auto& qubitInfos = getState().qubitInfos; - auto& qtensorMap = getState().qtensorMap; + auto& state = getState(); + auto* operation = op.getOperation(); + auto& qubitInfos = state.qubitInfos; + auto& qtensorMap = state.qtensorMap; // Look up latest QTensor value for this QC register auto memref = op.getMemref(); @@ -437,7 +438,7 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { auto extract = qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, index); - qubitMap.try_emplace(op.getResult(), extract.getResult()); + assignMappedQubit(state, operation, op.getResult(), extract.getResult()); qubitInfos.try_emplace(op.getResult(), QubitInfo{.reg = memref, .index = index}); qtensorMap[memref] = extract.getOutTensor(); @@ -475,12 +476,14 @@ struct ConvertMemRefDeallocOp final matchAndRewrite(memref::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { if (!llvm::isa(op.getMemref().getType().getElementType())) { - return success(); + return failure(); } - auto& qubitMap = getState().qubitMap; - auto& qubitInfos = getState().qubitInfos; - auto& qtensorMap = getState().qtensorMap; + auto& state = getState(); + auto* operation = op.getOperation(); + auto& qubitMap = state.qubitMap[operation->getParentRegion()]; + auto& qubitInfos = state.qubitInfos; + auto& qtensorMap = state.qtensorMap; // Look up latest QTensor value for this QC register auto memref = op.getMemref(); @@ -593,10 +596,9 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - auto* region = operation->getParentRegion(); - auto& qubitMap = state.qubitMap[region]; - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto& qubitMap = state.qubitMap[operation->getParentRegion()]; + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the sink operation rewriter.replaceOpWithNewOp(op, qcoQubit); @@ -631,7 +633,7 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); + auto qcQubit = op.getQubit(); auto qcoOp = rewriter.replaceOpWithNewOp(op, op.getIndex()); assignMappedQubit(state, operation, qcQubit, qcoOp.getQubit()); @@ -671,8 +673,8 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.measure (returns both output qubit and bit result) auto qcoOp = qco::MeasureOp::create( @@ -717,8 +719,8 @@ struct ConvertQCResetOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubit(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubit(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.reset (consumes input, produces output) auto qcoOp = qco::ResetOp::create(rewriter, op.getLoc(), qcoQubit); @@ -789,8 +791,8 @@ struct ConvertQCOneTargetZeroParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit); @@ -828,8 +830,8 @@ struct ConvertQCOneTargetOneParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = @@ -868,8 +870,8 @@ struct ConvertQCOneTargetTwoParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit, @@ -908,8 +910,8 @@ struct ConvertQCOneTargetThreeParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit = op.getQubitIn(); - Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); + auto qcQubit = op.getQubitIn(); + auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = @@ -950,10 +952,10 @@ struct ConvertQCTwoTargetZeroParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1); @@ -993,10 +995,10 @@ struct ConvertQCTwoTargetOneParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, @@ -1037,10 +1039,10 @@ struct ConvertQCTwoTargetTwoParameterToQCO final ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); auto* operation = op.getOperation(); - Value qcQubit0 = op.getQubit0In(); - Value qcQubit1 = op.getQubit1In(); - Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); - Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); + auto qcQubit0 = op.getQubit0In(); + auto qcQubit1 = op.getQubit1In(); + auto qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + auto qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 2105ba9f9b..10c54b62c8 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -930,8 +930,8 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - auto blockOrderComparatorToInsert = [](const std::pair& a, - const std::pair& b) { + auto blockOrderComparator = [](const std::pair& a, + const std::pair& b) { auto* opA = a.first.getDefiningOp(); auto* opB = b.first.getDefiningOp(); if (!opA || !opB || opA->getBlock() != opB->getBlock()) { @@ -965,7 +965,8 @@ OwningOpRef QCOProgramBuilder::finalize() { // Automatically deallocate all still-allocated tensors if (!validTensors.empty()) { - for (auto& [tensor, tensorInfo] : validTensors) { + for (auto& [tensor, tensorInfo] : llvm::to_vector(validTensors)) { + llvm::errs() << "Deallocating tensor\n"; // Filter out qubits belonging to this tensor llvm::SmallVector> toInsert; for (auto& [qubit, qubitInfo] : registerQubits) { @@ -975,11 +976,12 @@ OwningOpRef QCOProgramBuilder::finalize() { toInsert.push_back({qubit, qubitInfo.regIndex}); } // Sort qubits for deterministic output - llvm::sort(toInsert, blockOrderComparatorToInsert); + llvm::sort(toInsert, blockOrderComparator); // Insert qubits for (auto& [qubit, index] : toInsert) { tensor = qtensorInsert(qubit, tensor, index); } + // Deallocate tensor qtensor::DeallocOp::create(*this, tensor); } } From 9f2dc2a60bf6e14786c7131cdd825582ad6a54f6 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:39:58 +0200 Subject: [PATCH 26/71] Save one materialization --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 843259b6c7..108d7d21d1 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1078,7 +1078,7 @@ struct ConvertQCBarrierOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto qcQubits = llvm::to_vector(op.getQubits()); + auto qcQubits = op.getQubits(); auto qcoQubits = resolveMappedQubits(state, operation, qcQubits); // Create qco.barrier From 6e047dd7406522875bf53d8ea4f320bd6b012450 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:45:18 +0200 Subject: [PATCH 27/71] Make tensor map region-based as well --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 124 +++++++++++++----------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 108d7d21d1..4cd7fc0ced 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -100,17 +100,19 @@ struct LoweringState { llvm::DenseMap currentQubits; }; - /// Per-region map from QC qubit references to latest QCO SSA values. + /// Per-region map from original QC qubit reference to its latest QCO SSA + /// value. /// /// @details Keys are `Operation::getParentRegion()` for ops being converted /// (typically a `func.func` body or a modifier region). llvm::DenseMap> qubitMap; - /// Map from original MemRef to its latest QTensor SSA value - llvm::DenseMap qtensorMap; + /// Per-region map from original QC register to its latest QTensor SSA value + llvm::DenseMap> tensorMap; - /// Map from original QC qubit reference to its register information - llvm::DenseMap qubitInfos; + /// Per-region map from original QC qubit reference to its register + /// information + llvm::DenseMap> qubitInfoMap; /// Stack of active modifier regions SmallVector modifierFrames; @@ -147,23 +149,6 @@ class StatefulOpConversionPattern : public OpConversionPattern { }; } // namespace -/** - * @brief Helper function to look up the latest QCO qubit value for a given QC - * qubit reference - * - * @param qubitMap The mapping from QC qubits to QCO qubits for the current - * region - * @param qcQubit The QC qubit reference to look up - * @return The latest QCO qubit value corresponding to the given QC qubit - * reference - */ -[[nodiscard]] static Value -lookupMappedQubit(llvm::DenseMap& qubitMap, Value qcQubit) { - auto it = qubitMap.find(qcQubit); - assert(it != qubitMap.end() && "QC qubit not found"); - return it->second; -} - /** @brief Returns whether lowering currently processes a modifier body. */ [[nodiscard]] static bool isInsideModifier(const LoweringState& state) { return !state.modifierFrames.empty(); @@ -176,14 +161,15 @@ currentModifierFrame(LoweringState& state) { return state.modifierFrames.back(); } -/** @brief Finds the nearest region-local qubit map containing @p qcQubit. */ +/** @brief Finds the nearest region-local map containing @p reference. */ [[nodiscard]] static llvm::DenseMap* -findMappedQubitMap(LoweringState& state, Operation* anchor, Value qcQubit) { - for (Region* current = anchor->getParentRegion(); current != nullptr; +findRegionLocalMap(llvm::DenseMap>& map, + Operation* anchor, Value reference) { + for (auto* current = anchor->getParentRegion(); current != nullptr; current = current->getParentRegion()) { - auto mapIt = state.qubitMap.find(current); - if (mapIt != state.qubitMap.end() && mapIt->second.contains(qcQubit)) { - return &mapIt->second; + auto it = map.find(current); + if (it != map.end() && it->second.contains(reference)) { + return &it->second; } } return nullptr; @@ -200,9 +186,21 @@ findMappedQubitMap(LoweringState& state, Operation* anchor, Value qcQubit) { } } - auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit); + auto* qubitMap = findRegionLocalMap(state.qubitMap, anchor, qcQubit); assert(qubitMap != nullptr && "QC qubit not found"); - return lookupMappedQubit(*qubitMap, qcQubit); + auto it = qubitMap->find(qcQubit); + assert(it != qubitMap->end() && "QC qubit not found"); + return it->second; +} + +/** @brief Resolves the latest QTensor SSA value for a QC register. */ +[[nodiscard]] static Value lookupMappedTensor(LoweringState& state, + Operation* anchor, Value memref) { + auto* tensorMap = findRegionLocalMap(state.tensorMap, anchor, memref); + assert(tensorMap != nullptr && "QC register not found"); + auto it = tensorMap->find(memref); + assert(it != tensorMap->end() && "QC register not found"); + return it->second; } /** @brief Updates the latest QCO SSA value for a QC qubit reference. */ @@ -217,7 +215,7 @@ static void assignMappedQubit(LoweringState& state, Operation* anchor, } } - if (auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit)) { + if (auto* qubitMap = findRegionLocalMap(state.qubitMap, anchor, qcQubit)) { (*qubitMap)[qcQubit] = qcoQubit; return; } @@ -225,6 +223,17 @@ static void assignMappedQubit(LoweringState& state, Operation* anchor, state.qubitMap[anchor->getParentRegion()][qcQubit] = qcoQubit; } +/** @brief Updates the latest QTensor SSA value for a QC register. */ +static void assignMappedTensor(LoweringState& state, Operation* anchor, + Value memref, Value tensor) { + if (auto* tensorMap = findRegionLocalMap(state.tensorMap, anchor, memref)) { + (*tensorMap)[memref] = tensor; + return; + } + + state.tensorMap[anchor->getParentRegion()][memref] = tensor; +} + /** @brief Resolves a range of QC qubits to their latest QCO values. */ template [[nodiscard]] static SmallVector @@ -296,7 +305,7 @@ struct ConvertFuncReturnOp final : StatefulOpConversionPattern { matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - Region* funcRegion = op->getParentRegion(); + auto* funcRegion = op->getParentRegion(); auto& map = state.qubitMap[funcRegion]; // Build return values from qubitMap and collect live qubit information. @@ -381,7 +390,9 @@ struct ConvertMemRefAllocOp final return failure(); } - auto& qtensorMap = getState().qtensorMap; + auto& state = getState(); + auto* operation = op.getOperation(); + auto memref = op.getResult(); Value qtensor; @@ -395,7 +406,7 @@ struct ConvertMemRefAllocOp final rewriter.replaceOpWithNewOp(op, size.getResult()); } - qtensorMap.try_emplace(memref, qtensor); + assignMappedTensor(state, op.getOperation(), memref, qtensor); return success(); } @@ -424,24 +435,30 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { } auto& state = getState(); + auto& qubitInfoMap = state.qubitInfoMap; auto* operation = op.getOperation(); - auto& qubitInfos = state.qubitInfos; - auto& qtensorMap = state.qtensorMap; // Look up latest QTensor value for this QC register auto memref = op.getMemref(); - assert(qtensorMap.contains(memref) && "QC register not found"); - auto qtensor = qtensorMap[memref]; + auto qtensor = lookupMappedTensor(state, operation, memref); auto index = adaptor.getIndices()[0]; - auto extract = qtensor::ExtractOp::create(rewriter, op.getLoc(), qtensor, index); - assignMappedQubit(state, operation, op.getResult(), extract.getResult()); - qubitInfos.try_emplace(op.getResult(), - QubitInfo{.reg = memref, .index = index}); - qtensorMap[memref] = extract.getOutTensor(); + auto qcQubit = op.getResult(); + auto qcoQubit = extract.getResult(); + + assignMappedQubit(state, operation, qcQubit, qcoQubit); + assignMappedTensor(state, operation, memref, extract.getOutTensor()); + + QubitInfo info{.reg = memref, .index = index}; + if (auto it = qubitInfoMap.find(operation->getParentRegion()); + it != qubitInfoMap.end()) { + it->second[qcQubit] = info; + } else { + qubitInfoMap[operation->getParentRegion()][qcQubit] = info; + } rewriter.eraseOp(op); @@ -480,21 +497,19 @@ struct ConvertMemRefDeallocOp final } auto& state = getState(); - auto* operation = op.getOperation(); - auto& qubitMap = state.qubitMap[operation->getParentRegion()]; - auto& qubitInfos = state.qubitInfos; - auto& qtensorMap = state.qtensorMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; + auto& tensorMap = state.tensorMap[op->getParentRegion()]; + auto& qubitInfoMap = state.qubitInfoMap[op->getParentRegion()]; // Look up latest QTensor value for this QC register auto memref = op.getMemref(); - assert(qtensorMap.contains(memref) && "QC register not found"); - auto qtensor = qtensorMap[memref]; + auto qtensor = lookupMappedTensor(state, op.getOperation(), memref); // Filter out qubits belonging to this tensor llvm::SmallVector> toInsert; toInsert.reserve(qubitMap.size()); for (auto [qcQubit, qcoQubit] : qubitMap) { - auto& info = qubitInfos[qcQubit]; + auto& info = qubitInfoMap[qcQubit]; if (info.reg != memref) { continue; } @@ -513,18 +528,18 @@ struct ConvertMemRefDeallocOp final // Insert qubits for (auto [qcQubit, qcoQubit] : toInsert) { - auto& info = qubitInfos[qcQubit]; + auto& info = qubitInfoMap[qcQubit]; auto index = info.index; auto insert = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, qtensor, index); qtensor = insert.getResult(); qubitMap.erase(qcQubit); - qubitInfos.erase(qcQubit); + qubitInfoMap.erase(qcQubit); } rewriter.replaceOpWithNewOp(op, qtensor); - qtensorMap.erase(memref); + tensorMap.erase(memref); return success(); } @@ -595,8 +610,9 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { matchAndRewrite(qc::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); + auto& qubitMap = state.qubitMap[op->getParentRegion()]; auto* operation = op.getOperation(); - auto& qubitMap = state.qubitMap[operation->getParentRegion()]; + auto qcQubit = op.getQubit(); auto qcoQubit = lookupMappedQubit(state, operation, qcQubit); From 70389266e3d438a85bb1b5c2dc504701e6ffca98 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:51:48 +0200 Subject: [PATCH 28/71] Address the Rabbit's comments --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index ad27f76a34..316e7061d8 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -47,6 +47,7 @@ #include #include +#include #include #include #include @@ -210,8 +211,12 @@ struct QCToQIRTypeConverter final : LLVMTypeConverter { addConversion( [ctx](QubitType /*type*/) { return LLVM::LLVMPointerType::get(ctx); }); - addConversion( - [ctx](MemRefType /*type*/) { return LLVM::LLVMPointerType::get(ctx); }); + addConversion([ctx](MemRefType type) -> Type { + if (llvm::isa(type.getElementType())) { + return LLVM::LLVMPointerType::get(ctx); + } + return type; + }); } }; @@ -251,7 +256,8 @@ struct ConvertMemRefAllocOp final auto shape = op.getType().getShape(); if (shape.size() != 1) { - return failure(); + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); } Value size; @@ -339,7 +345,8 @@ struct ConvertMemRefDeallocOp final auto shape = op.getMemref().getType().getShape(); if (shape.size() != 1) { - return failure(); + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); } // Save current insertion point @@ -354,6 +361,7 @@ struct ConvertMemRefDeallocOp final QIR_QUBIT_ARRAY_RELEASE, fnSig); auto size = state.memrefSizes.lookup(op.getMemref()); + assert(size != nullptr && "Size not found"); // Create the release call LLVM::CallOp::create(rewriter, op.getLoc(), fnDec, From 5660b64114a00591284b45eda8a9ceb9471737ad Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:02:58 +0200 Subject: [PATCH 29/71] Fix linter errors --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 4cd7fc0ced..0cde5c22be 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -406,7 +406,7 @@ struct ConvertMemRefAllocOp final rewriter.replaceOpWithNewOp(op, size.getResult()); } - assignMappedTensor(state, op.getOperation(), memref, qtensor); + assignMappedTensor(state, operation, memref, qtensor); return success(); } diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 316e7061d8..b9fcf85d0b 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include From f81c35aff16000d10e10d03570d1eb4409641fb5 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:25:59 +0200 Subject: [PATCH 30/71] Address the Rabbit's comments --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 1 - mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 34 +++++++++++-------- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 1 - 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index fcda50e97c..421d7942b2 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -116,7 +116,6 @@ class QCOToQCTypeConverter final : public TypeConverter { addConversion([ctx](RankedTensorType type) -> Type { if (llvm::isa(type.getElementType())) { - // TODO: Can we make it work with type.getShape()? return MemRefType::get({ShapedType::kDynamic}, qc::QubitType::get(ctx)); } return type; diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index b9fcf85d0b..03caf30f02 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -243,6 +243,12 @@ struct ConvertMemRefAllocOp final LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto shape = op.getType().getShape(); + if (shape.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); + } + auto& state = getState(); state.useDynamicQubit = true; @@ -255,12 +261,6 @@ struct ConvertMemRefAllocOp final auto fnDec = getOrCreateFunctionDeclaration(rewriter, op, QIR_QUBIT_ARRAY_ALLOC, fnSig); - auto shape = op.getType().getShape(); - if (shape.size() != 1) { - return rewriter.notifyMatchFailure( - op, "Only one-dimensional registers are supported"); - } - Value size; if (shape[0] == ShapedType::kDynamic) { size = adaptor.getDynamicSizes()[0]; @@ -303,6 +303,12 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto shape = op.getMemref().getType().getShape(); + if (shape.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Only one-dimensional registers are supported"); + } + auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -339,17 +345,17 @@ struct ConvertMemRefDeallocOp final LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto& state = getState(); - auto* ctx = getContext(); - auto i64Type = rewriter.getI64Type(); - auto ptrType = LLVM::LLVMPointerType::get(ctx); - auto shape = op.getMemref().getType().getShape(); if (shape.size() != 1) { return rewriter.notifyMatchFailure( op, "Only one-dimensional registers are supported"); } + auto& state = getState(); + auto* ctx = getContext(); + auto i64Type = rewriter.getI64Type(); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + // Save current insertion point const OpBuilder::InsertionGuard guard(rewriter); @@ -1069,14 +1075,12 @@ struct QCToQIR final : impl::QCToQIRBase { * Blocks are connected with unconditional jumps (entry, body, measurements, * output). This structure ensures proper QIR Base Profile semantics. * - * If the function already has multiple blocks, this function does nothing. - * * @param main The main LLVM function to restructure */ static void ensureBlocks(LLVM::LLVMFuncOp& main, LoweringState& state) { - // Return if there are already multiple blocks if (main.getBlocks().size() > 1) { - return; + llvm::reportFatalInternalError( + "Modules with multiple blocks are not supported yet"); } // Get the existing block diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 10c54b62c8..a63ef6a593 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -966,7 +966,6 @@ OwningOpRef QCOProgramBuilder::finalize() { // Automatically deallocate all still-allocated tensors if (!validTensors.empty()) { for (auto& [tensor, tensorInfo] : llvm::to_vector(validTensors)) { - llvm::errs() << "Deallocating tensor\n"; // Filter out qubits belonging to this tensor llvm::SmallVector> toInsert; for (auto& [qubit, qubitInfo] : registerQubits) { From 2623d9b34723039dc6090339c17425f466bcef02 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 1 Apr 2026 21:35:15 +0200 Subject: [PATCH 31/71] =?UTF-8?q?=F0=9F=93=9D=20Add=20changelog=20entry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0c3ef4d7b..172865af1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between Jeff and QCO ([#1479], [#1548], [#1565]) ([**@denialhaag**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588]) ([**@MatthiasReumann**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1602]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -341,6 +341,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1588]: https://github.com/munich-quantum-toolkit/core/pull/1588 [#1583]: https://github.com/munich-quantum-toolkit/core/pull/1583 [#1581]: https://github.com/munich-quantum-toolkit/core/pull/1581 +[#1580]: https://github.com/munich-quantum-toolkit/core/pull/1580 [#1573]: https://github.com/munich-quantum-toolkit/core/pull/1573 [#1572]: https://github.com/munich-quantum-toolkit/core/pull/1572 [#1571]: https://github.com/munich-quantum-toolkit/core/pull/1571 From 35748192baf5ecf0dc4144d753caebbf7ca1fe5a Mon Sep 17 00:00:00 2001 From: burgholzer Date: Thu, 2 Apr 2026 12:26:20 +0200 Subject: [PATCH 32/71] =?UTF-8?q?=F0=9F=8E=A8=20Small=20tweaks=20to=20the?= =?UTF-8?q?=20QC=20program=20builder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 44 +++---------------- 1 file changed, 6 insertions(+), 38 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 9b8a102955..b6c503f7eb 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -96,24 +96,18 @@ QCProgramBuilder::allocQubitRegister(const int64_t size) { llvm::reportFatalUsageError("Size must be positive"); } - auto qubitType = QubitType::get(ctx); - auto memrefType = mlir::MemRefType::get({size}, qubitType); + auto memrefType = MemRefType::get({size}, QubitType::get(ctx)); auto memref = memref::AllocOp::create(*this, memrefType); + allocatedMemrefs.insert(memref); llvm::SmallVector qubits; qubits.reserve(size); - for (int64_t i = 0; i < size; ++i) { - auto index = arith::ConstantOp::create(*this, getIndexAttr(i)); + auto index = arith::ConstantIndexOp::create(*this, i); auto load = memref::LoadOp::create(*this, memref, index.getResult()); const auto& qubit = qubits.emplace_back(load.getResult()); - // Track the allocated qubit for automatic deallocation allocatedQubits.insert(qubit); } - - allocatedMemrefs.insert(memref); - - // TODO: Return register return qubits; } @@ -498,42 +492,16 @@ OwningOpRef QCProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - llvm::SmallVector freeQubits; for (auto qubit : allocatedQubits) { if (!llvm::isa(qubit.getDefiningOp())) { - freeQubits.emplace_back(qubit); - } - } - - auto blockOrderComparator = [](Value a, Value b) { - auto* opA = a.getDefiningOp(); - auto* opB = b.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + DeallocOp::create(*this, qubit); } - return opA->isBeforeInBlock(opB); - }; - - // Automatically deallocate all still-allocated qubits - // Sort qubits for deterministic output - llvm::SmallVector sortedQubits(freeQubits.begin(), freeQubits.end()); - llvm::sort(sortedQubits, blockOrderComparator); - - for (auto qubit : sortedQubits) { - DeallocOp::create(*this, qubit); } + allocatedQubits.clear(); - // Automatically deallocate all still-allocated memrefs - // Sort memrefs for deterministic output - llvm::SmallVector sortedMemrefs(allocatedMemrefs.begin(), - allocatedMemrefs.end()); - llvm::sort(sortedMemrefs, blockOrderComparator); - - for (auto memref : sortedMemrefs) { + for (auto memref : allocatedMemrefs) { memref::DeallocOp::create(*this, memref); } - - allocatedQubits.clear(); allocatedMemrefs.clear(); // Create constant 0 for successful exit code From 1fc2430535a59af62c83a47754a4f092a1f152b5 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 01:30:20 +0200 Subject: [PATCH 33/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Add=20back=20fold=20?= =?UTF-8?q?for=20`qtensor.insert`=20and=20enhance=20capabilities=20of=20th?= =?UTF-8?q?e=20canonicalization?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 1 + .../QTensor/IR/Operations/InsertOp.cpp | 89 +++++++++++++++---- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 41ddedcfcf..4c04c44fe9 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -208,6 +208,7 @@ def InsertOp }]; let hasCanonicalizer = 1; + let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index cc51d0d20b..31c07bca09 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -24,21 +24,77 @@ using namespace mlir; using namespace mlir::qtensor; /** - * @brief Find the `qtensor.extract` operation for a given `qtensor.insert` - * operation. + * @brief Checks whether two index values are equivalent for matching. */ -static ExtractOp findExtractOp(InsertOp op, Value index) { - auto* definingOp = op.getDest().getDefiningOp(); - if (llvm::isa(definingOp)) { - return llvm::cast(definingOp); +static bool areEquivalentIndices(Value lhs, Value rhs) { + return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); +} + +/** + * @brief Checks whether removing an extract-insert pair is linearity-safe. + */ +static bool isRemovableExtractInsertPair(InsertOp insertOp, + ExtractOp extractOp) { + return insertOp.getScalar() == extractOp.getResult() && + areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex()); +} + +/** + * @brief Fold the direct pattern + * `insert(extract(tensor, idx).qubit, extract(tensor, idx).out, idx)`. + */ +static Value foldInsertAfterExtract(InsertOp insertOp) { + auto extractOp = insertOp.getScalar().getDefiningOp(); + if (!extractOp) { + return nullptr; } - if (llvm::isa(definingOp)) { - auto nestedInsertOp = llvm::cast(definingOp); - if (nestedInsertOp.getIndex() == index) { - return nullptr; + + if (insertOp.getDest() != extractOp.getOutTensor()) { + return nullptr; + } + + if (!isRemovableExtractInsertPair(insertOp, extractOp)) { + return nullptr; + } + + return extractOp.getTensor(); +} + +OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { + if (auto result = foldInsertAfterExtract(*this)) { + return result; + } + + return {}; +} + +/** + * @brief Find a matching `qtensor.extract` for an insert index in a tensor + * chain by traversing nested `qtensor.insert` and `qtensor.extract` ops. + */ +static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { + Value current = tensor; + while (Operation* definingOp = current.getDefiningOp()) { + if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { + // A more recent write to the same index shadows all older extracts. + if (areEquivalentIndices(nestedInsertOp.getIndex(), index)) { + return nullptr; + } + current = nestedInsertOp.getDest(); + continue; + } + + if (auto extractOp = llvm::dyn_cast(definingOp)) { + if (areEquivalentIndices(extractOp.getIndex(), index)) { + return extractOp; + } + current = extractOp.getTensor(); + continue; } - return findExtractOp(nestedInsertOp, index); + + break; } + return nullptr; } @@ -52,18 +108,13 @@ struct RemoveExtractInsertPair final : OpRewritePattern { LogicalResult matchAndRewrite(InsertOp op, PatternRewriter& rewriter) const override { - auto index = op.getIndex(); - - auto extractOp = findExtractOp(op, index); + auto extractOp = + findMatchingExtractInTensorChain(op.getDest(), op.getIndex()); if (!extractOp) { return failure(); } - if (op.getScalar() != extractOp.getResult()) { - return failure(); - } - - if (index != extractOp.getIndex()) { + if (!isRemovableExtractInsertPair(op, extractOp)) { return failure(); } From a4aa18ae5f821d0268da8d6631c65e5ff1323d06 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 01:32:24 +0200 Subject: [PATCH 34/71] =?UTF-8?q?=F0=9F=9A=B8=20Generalize=20IRVerificatio?= =?UTF-8?q?n=20to=20handle=20commuting=20chains=20of=20`qtensor.insert`=20?= =?UTF-8?q?operations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/lib/Support/CMakeLists.txt | 3 +- mlir/lib/Support/IRVerification.cpp | 258 +++++++++++++++++++++++++++- 2 files changed, 254 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt index 0e93033d4a..f63a83c87e 100644 --- a/mlir/lib/Support/CMakeLists.txt +++ b/mlir/lib/Support/CMakeLists.txt @@ -23,7 +23,8 @@ add_mlir_library( MLIRTransformUtils MLIRLLVMDialect MLIRFuncDialect - MLIRArithDialect) + MLIRArithDialect + MLIRQTensorDialect) mqt_mlir_target_use_project_options(MLIRSupportMQT) diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 4c3d470289..8438e0c10f 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,6 @@ using namespace mlir; namespace { - /// Compute a structural hash for an operation (excluding SSA value identities). /// This hash is based on operation name, types, and attributes only. struct OperationStructuralHash { @@ -116,7 +116,238 @@ struct StructuralOperationKey { /// Map to track value equivalence between two modules. using ValueEquivalenceMap = llvm::DenseMap; + +using OperationSet = llvm::DenseSet; + +struct InsertWrite { + Value scalar; + Value index; +}; + +struct InsertChainSummary { + Value baseTensor; + Value finalTensor; + llvm::SmallVector writes; +}; } // namespace +static bool areValuesEquivalent(Value lhsValue, Value rhsValue, + ValueEquivalenceMap& valueMap) { + if (auto it = valueMap.find(lhsValue); it != valueMap.end()) { + return it->second == rhsValue; + } + valueMap[lhsValue] = rhsValue; + return true; +} + +static bool isQTensorInsertOp(Operation* op) { + return llvm::isa(op); +} + +static bool isCommutableQTensorInsertDependency(Operation* dependent, + Operation* dependency) { + auto dependentInsert = llvm::dyn_cast(dependent); + auto dependencyInsert = llvm::dyn_cast(dependency); + if (!dependentInsert || !dependencyInsert) { + return false; + } + return dependentInsert.getDest() == dependencyInsert.getResult(); +} + +static Value getInsertChainBaseTensor(Value tensor, const OperationSet& group) { + Value current = tensor; + while (auto insertOp = current.getDefiningOp()) { + if (!group.contains(insertOp.getOperation())) { + break; + } + current = insertOp.getDest(); + } + return current; +} + +static bool +summarizeInsertGroup(llvm::ArrayRef ops, + llvm::SmallVectorImpl& chains) { + OperationSet groupOps; + for (Operation* op : ops) { + groupOps.insert(op); + } + + llvm::DenseSet consumedInsertResults; + for (Operation* op : ops) { + auto insertOp = llvm::cast(op); + if (auto definingInsert = + insertOp.getDest().getDefiningOp()) { + if (groupOps.contains(definingInsert.getOperation())) { + consumedInsertResults.insert(insertOp.getDest()); + } + } + } + + llvm::DenseMap chainByBaseTensor; + for (Operation* op : ops) { + auto insertOp = llvm::cast(op); + const Value baseTensor = + getInsertChainBaseTensor(insertOp.getDest(), groupOps); + + size_t chainIdx = 0; + if (auto it = chainByBaseTensor.find(baseTensor); + it != chainByBaseTensor.end()) { + chainIdx = it->second; + } else { + chainIdx = chains.size(); + chainByBaseTensor[baseTensor] = chainIdx; + InsertChainSummary summary; + summary.baseTensor = baseTensor; + chains.emplace_back(std::move(summary)); + } + + auto& chain = chains[chainIdx]; + chain.writes.push_back( + InsertWrite{insertOp.getScalar(), insertOp.getIndex()}); + + if (!consumedInsertResults.contains(insertOp.getResult())) { + if (chain.finalTensor) { + return false; + } + chain.finalTensor = insertOp.getResult(); + } + } + + for (const auto& chain : chains) { + if (!chain.finalTensor) { + return false; + } + + // Reordering writes to the same index is not semantics-preserving. + llvm::DenseSet seenIndices; + for (const auto& write : chain.writes) { + if (!seenIndices.insert(write.index).second) { + return false; + } + } + } + + return true; +} + +static bool areInsertWritesEquivalentRec(const size_t lhsIdx, + llvm::ArrayRef lhsWrites, + llvm::ArrayRef rhsWrites, + llvm::SmallVectorImpl& rhsUsed, + ValueEquivalenceMap& valueMap) { + if (lhsIdx == lhsWrites.size()) { + return true; + } + + for (size_t rhsIdx = 0; rhsIdx < rhsWrites.size(); ++rhsIdx) { + if (rhsUsed[rhsIdx] != 0) { + continue; + } + + ValueEquivalenceMap tempMap = valueMap; + if (!areValuesEquivalent(lhsWrites[lhsIdx].scalar, rhsWrites[rhsIdx].scalar, + tempMap) || + !areValuesEquivalent(lhsWrites[lhsIdx].index, rhsWrites[rhsIdx].index, + tempMap)) { + continue; + } + + rhsUsed[rhsIdx] = 1; + if (areInsertWritesEquivalentRec(lhsIdx + 1, lhsWrites, rhsWrites, rhsUsed, + tempMap)) { + valueMap = std::move(tempMap); + return true; + } + rhsUsed[rhsIdx] = 0; + } + + return false; +} + +static bool areInsertWritesEquivalent(llvm::ArrayRef lhsWrites, + llvm::ArrayRef rhsWrites, + ValueEquivalenceMap& valueMap) { + if (lhsWrites.size() != rhsWrites.size()) { + return false; + } + llvm::SmallVector rhsUsed(rhsWrites.size(), 0); + return areInsertWritesEquivalentRec(0, lhsWrites, rhsWrites, rhsUsed, + valueMap); +} + +static bool areInsertChainsEquivalent(const InsertChainSummary& lhsChain, + const InsertChainSummary& rhsChain, + ValueEquivalenceMap& valueMap) { + ValueEquivalenceMap tempMap = valueMap; + if (!areValuesEquivalent(lhsChain.baseTensor, rhsChain.baseTensor, tempMap)) { + return false; + } + + if (!areInsertWritesEquivalent(lhsChain.writes, rhsChain.writes, tempMap)) { + return false; + } + + if (!areValuesEquivalent(lhsChain.finalTensor, rhsChain.finalTensor, + tempMap)) { + return false; + } + + valueMap = std::move(tempMap); + return true; +} + +static bool areInsertGroupsEquivalentRec( + const size_t lhsChainIdx, llvm::ArrayRef lhsChains, + llvm::ArrayRef rhsChains, + llvm::SmallVectorImpl& rhsChainUsed, ValueEquivalenceMap& valueMap) { + if (lhsChainIdx == lhsChains.size()) { + return true; + } + + for (size_t rhsChainIdx = 0; rhsChainIdx < rhsChains.size(); ++rhsChainIdx) { + if (rhsChainUsed[rhsChainIdx] != 0) { + continue; + } + + ValueEquivalenceMap tempMap = valueMap; + if (!areInsertChainsEquivalent(lhsChains[lhsChainIdx], + rhsChains[rhsChainIdx], tempMap)) { + continue; + } + + rhsChainUsed[rhsChainIdx] = 1; + if (areInsertGroupsEquivalentRec(lhsChainIdx + 1, lhsChains, rhsChains, + rhsChainUsed, tempMap)) { + valueMap = std::move(tempMap); + return true; + } + rhsChainUsed[rhsChainIdx] = 0; + } + + return false; +} + +static bool areInsertGroupsEquivalent(llvm::ArrayRef lhsOps, + llvm::ArrayRef rhsOps, + ValueEquivalenceMap& valueMap) { + if (lhsOps.size() != rhsOps.size()) { + return false; + } + + llvm::SmallVector lhsChains; + llvm::SmallVector rhsChains; + if (!summarizeInsertGroup(lhsOps, lhsChains) || + !summarizeInsertGroup(rhsOps, rhsChains)) { + return false; + } + if (lhsChains.size() != rhsChains.size()) { + return false; + } + + llvm::SmallVector rhsChainUsed(rhsChains.size(), 0); + return areInsertGroupsEquivalentRec(0, lhsChains, rhsChains, rhsChainUsed, + valueMap); +} /// DenseMapInfo specialization for StructuralOperationKey template <> struct llvm::DenseMapInfo { @@ -378,18 +609,21 @@ llvm::SmallVector processed; llvm::SmallVector currentGroup; for (auto* op : ops) { bool dependsOnCurrent = false; // Check if this operation depends on any operation in the current group - for (const auto* groupOp : currentGroup) { - if (dependsOn[op].contains(groupOp)) { - dependsOnCurrent = true; - break; + for (auto* groupOp : currentGroup) { + if (!dependsOn[op].contains(groupOp)) { + continue; } + if (isCommutableQTensorInsertDependency(op, groupOp)) { + continue; + } + dependsOnCurrent = true; + break; } // Check if this operation has ordering constraints @@ -502,6 +736,18 @@ static bool areBlocksEquivalent(Block& lhs, Block& rhs, auto& lhsGroup = lhsGroups[groupIdx]; auto& rhsGroup = rhsGroups[groupIdx]; + const bool lhsInsertGroup = llvm::all_of(lhsGroup, isQTensorInsertOp); + const bool rhsInsertGroup = llvm::all_of(rhsGroup, isQTensorInsertOp); + if (lhsInsertGroup || rhsInsertGroup) { + if (!lhsInsertGroup || !rhsInsertGroup) { + return false; + } + if (!areInsertGroupsEquivalent(lhsGroup, rhsGroup, valueMap)) { + return false; + } + continue; + } + if (!areIndependentGroupsEquivalent(lhsGroup, rhsGroup)) { return false; } From f10935e5cf9c9b807fd9521c184a1ef793e3243b Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 01:35:10 +0200 Subject: [PATCH 35/71] =?UTF-8?q?=F0=9F=8E=A8=20Small=20tweaks=20to=20the?= =?UTF-8?q?=20QCO=20program=20builder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 4 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 94 +++++++------------ mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 63 +++++++++++++ 3 files changed, 98 insertions(+), 63 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index eef37dac18..9579f9e98c 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1360,7 +1360,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { }; /// Track valid (unconsumed) qubit SSA values for linear type enforcement. - /// Only values present in this set are valid for use in operations. + /// Only values present in this map are valid for use in operations. /// When an operation consumes a qubit and produces a new one, the old value /// is removed and the new output is added. llvm::DenseMap validQubits; @@ -1390,7 +1390,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { }; /// Track valid (unconsumed) tensor SSA values for linear type enforcement. - /// Only values present in this set are valid for use in operations. + /// Only values present in this map are valid for use in operations. /// When an operation consumes a tensor and produces a new one, the old value /// is removed and the new output is added. llvm::DenseMap validTensors; diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index a63ef6a593..d6d40ad94b 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -79,7 +79,7 @@ Value QCOProgramBuilder::allocQubit() { auto qubit = allocOp.getResult(); // Track the allocated qubit as valid - validQubits.insert({qubit, {}}); + validQubits.try_emplace(qubit, QubitInfo{}); return qubit; } @@ -91,7 +91,7 @@ Value QCOProgramBuilder::staticQubit(const uint64_t index) { const auto qubit = staticOp.getQubit(); // Track the static qubit as valid - validQubits.insert({qubit, {}}); + validQubits.try_emplace(qubit, QubitInfo{}); return qubit; } @@ -108,14 +108,11 @@ QCOProgramBuilder::allocQubitRegister(const int64_t size) { llvm::SmallVector qubits; qubits.reserve(size); - for (int64_t i = 0; i < size; ++i) { auto [qtensorOut, qubit] = qtensorExtract(qtensor, i); qtensor = qtensorOut; qubits.emplace_back(qubit); } - - // TODO: Return register return qubits; } @@ -157,7 +154,7 @@ void QCOProgramBuilder::updateQubitTracking(Value inputQubit, validQubits.erase(it); // Add the output (new) value to tracking - validQubits.insert({outputQubit, info}); + validQubits.try_emplace(outputQubit, std::move(info)); } void QCOProgramBuilder::validateTensorValue(Value tensor) const { @@ -190,7 +187,7 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, validTensors.erase(it); // Add the output (new) value to tracking - validTensors.insert({outputTensor, info}); + validTensors.try_emplace(outputTensor, std::move(info)); } //===----------------------------------------------------------------------===// @@ -201,11 +198,11 @@ Value QCOProgramBuilder::qtensorAlloc( const std::variant& size) { checkFinalized(); - auto sizeValue = utils::variantToValue(*this, getLoc(), size); + auto sizeValue = variantToValue(*this, getLoc(), size); auto allocOp = qtensor::AllocOp::create(*this, sizeValue); auto result = allocOp.getResult(); - validTensors.insert({result, {tensorCounter++}}); + validTensors.try_emplace(result, TensorInfo{tensorCounter++}); return result; } @@ -227,7 +224,7 @@ Value QCOProgramBuilder::qtensorFromElements(ValueRange elements) { auto fromElementsOp = qtensor::FromElementsOp::create(*this, elements); auto result = fromElementsOp.getResult(); - validTensors.insert({result, {tensorCounter++}}); + validTensors.try_emplace(result, TensorInfo{tensorCounter++}); return result; } @@ -235,8 +232,7 @@ std::pair QCOProgramBuilder::qtensorExtract(Value tensor, const int64_t index) { checkFinalized(); - auto indexValue = - arith::ConstantOp::create(*this, getIndexAttr(index)).getResult(); + auto indexValue = arith::ConstantIndexOp::create(*this, index).getResult(); auto extractOp = qtensor::ExtractOp::create(*this, tensor, indexValue); auto qubit = extractOp.getResult(); auto outTensor = extractOp.getOutTensor(); @@ -244,7 +240,7 @@ std::pair QCOProgramBuilder::qtensorExtract(Value tensor, validateTensorValue(tensor); const auto regId = validTensors[tensor].regId; - validQubits.insert({qubit, {.regId = regId, .regIndex = index}}); + validQubits.try_emplace(qubit, QubitInfo{.regId = regId, .regIndex = index}); updateTensorTracking(tensor, outTensor); return {outTensor, qubit}; @@ -255,14 +251,14 @@ std::pair QCOProgramBuilder::qtensorExtractSlice( const std::variant& size) { checkFinalized(); - auto offsetValue = utils::variantToValue(*this, getLoc(), offset); - auto sizesValue = utils::variantToValue(*this, getLoc(), size); + auto offsetValue = variantToValue(*this, getLoc(), offset); + auto sizesValue = variantToValue(*this, getLoc(), size); auto extractSliceOp = qtensor::ExtractSliceOp::create(*this, tensor, offsetValue, sizesValue); auto slicedTensor = extractSliceOp.getResult(); auto outTensor = extractSliceOp.getOutTensor(); - validTensors.insert({slicedTensor, {tensorCounter++}}); + validTensors.try_emplace(slicedTensor, TensorInfo{tensorCounter++}); updateTensorTracking(tensor, outTensor); return {outTensor, slicedTensor}; @@ -272,7 +268,7 @@ Value QCOProgramBuilder::qtensorInsert( Value scalar, Value tensor, const std::variant& index) { checkFinalized(); - auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto indexValue = variantToValue(*this, getLoc(), index); auto insertOp = qtensor::InsertOp::create(*this, scalar, tensor, indexValue); auto outTensor = insertOp.getResult(); @@ -289,8 +285,8 @@ Value QCOProgramBuilder::qtensorInsertSlice( const std::variant& size) { checkFinalized(); - auto offsetValue = utils::variantToValue(*this, getLoc(), offset); - auto sizeValue = utils::variantToValue(*this, getLoc(), size); + auto offsetValue = variantToValue(*this, getLoc(), offset); + auto sizeValue = variantToValue(*this, getLoc(), size); auto insertSliceOp = qtensor::InsertSliceOp::create(*this, source, dest, offsetValue, sizeValue); @@ -843,7 +839,7 @@ ValueRange QCOProgramBuilder::qcoIf( llvm::function_ref(ValueRange)> elseBody) { checkFinalized(); - auto conditionValue = utils::variantToValue(*this, getLoc(), condition); + auto conditionValue = variantToValue(*this, getLoc(), condition); auto ifOp = IfOp::create(*this, conditionValue, qubits); // Create the then and else block @@ -854,8 +850,8 @@ ValueRange QCOProgramBuilder::qcoIf( for (auto qubitType : qubits.getTypes()) { const auto thenArg = thenBlock.addArgument(qubitType, getLoc()); const auto elseArg = elseBlock.addArgument(qubitType, getLoc()); - validQubits.insert({thenArg, {}}); - validQubits.insert({elseArg, {}}); + validQubits.try_emplace(thenArg, QubitInfo{}); + validQubits.try_emplace(elseArg, QubitInfo{}); } // Construct the bodies of the regions @@ -930,61 +926,37 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - auto blockOrderComparator = [](const std::pair& a, - const std::pair& b) { - auto* opA = a.first.getDefiningOp(); - auto* opB = b.first.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.first.getAsOpaquePointer() < b.first.getAsOpaquePointer(); - } - if (opA != opB) { - return opA->isBeforeInBlock(opB); - } - return a.second < b.second; - }; - llvm::DenseSet validTensorIds; for (const auto& [tensor, info] : validTensors) { validTensorIds.insert(info.regId); } - llvm::SmallVector freeQubits; llvm::DenseMap registerQubits; for (auto [qubit, info] : validQubits) { if (info.regId == -1 || !validTensorIds.contains(info.regId)) { - freeQubits.push_back(qubit); + // Automatically deallocate all still-allocated qubits + SinkOp::create(*this, qubit); } else { - registerQubits.insert({qubit, info}); + registerQubits.try_emplace(qubit, info); } } - // Automatically deallocate all still-allocated qubits - for (auto qubit : freeQubits) { - SinkOp::create(*this, qubit); - } - // Automatically deallocate all still-allocated tensors - if (!validTensors.empty()) { - for (auto& [tensor, tensorInfo] : llvm::to_vector(validTensors)) { - // Filter out qubits belonging to this tensor - llvm::SmallVector> toInsert; - for (auto& [qubit, qubitInfo] : registerQubits) { - if (qubitInfo.regId != tensorInfo.regId) { - continue; - } - toInsert.push_back({qubit, qubitInfo.regIndex}); + for (auto& [tensor, tensorInfo] : validTensors) { + Value currentTensor = tensor; + // Filter out qubits belonging to this tensor + for (auto& [qubit, qubitInfo] : registerQubits) { + if (qubitInfo.regId != tensorInfo.regId) { + continue; } - // Sort qubits for deterministic output - llvm::sort(toInsert, blockOrderComparator); - // Insert qubits - for (auto& [qubit, index] : toInsert) { - tensor = qtensorInsert(qubit, tensor, index); - } - // Deallocate tensor - qtensor::DeallocOp::create(*this, tensor); + auto indexValue = constantFromScalar(*this, getLoc(), qubitInfo.regIndex); + currentTensor = + qtensor::InsertOp::create(*this, qubit, currentTensor, indexValue) + .getResult(); } + // Deallocate tensor + qtensor::DeallocOp::create(*this, currentTensor); } - validQubits.clear(); validTensors.clear(); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index da5a3ad310..8f0cf08693 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -67,6 +68,34 @@ class QCOTest : public testing::TestWithParam { } }; +OwningOpRef +buildTwoQubitInsertChainProgram(MLIRContext* context, + const bool reverseInsertOrder, + const bool swapInsertTargets) { + qco::QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterFirstExtract, qubit0] = builder.qtensorExtract(tensor, 0); + auto [baseTensor, qubit1] = + builder.qtensorExtract(tensorAfterFirstExtract, 1); + + const int64_t qubit0Target = swapInsertTargets ? 1 : 0; + const int64_t qubit1Target = swapInsertTargets ? 0 : 1; + + Value currentTensor = baseTensor; + if (reverseInsertOrder) { + currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); + currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); + } else { + currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); + currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); + } + + builder.qtensorDealloc(currentTensor); + return builder.finalize(); +} + } // namespace TEST_P(QCOTest, ProgramEquivalence) { @@ -96,6 +125,40 @@ TEST_P(QCOTest, ProgramEquivalence) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +TEST_F(QCOTest, InsertChainPermutationEquivalence) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QCOTest, InsertChainDifferentAssignmentsNotEquivalent) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly qco::QCOProgramBuilder builder(context.get()); From 00b1c836567529217d4e2ccf677e8d9efd23fcb0 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 12:03:58 +0200 Subject: [PATCH 36/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Add=20canonicalizati?= =?UTF-8?q?on=20for=20interleaved=20extract-insert=20operations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 1 + .../QTensor/IR/Operations/ExtractOp.cpp | 122 +++++++++++++++++- mlir/lib/Support/IRVerification.cpp | 24 +++- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 78 +++++++++++ 4 files changed, 217 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 4c04c44fe9..29227d12a8 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -134,6 +134,7 @@ def ExtractOp let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, QubitType:$result); let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)"; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 27e8de6995..e17f6e6172 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -10,9 +10,12 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include +#include #include +#include #include #include @@ -34,6 +37,34 @@ LogicalResult ExtractOp::verify() { return success(); } +/** + * @brief Checks whether two index values are equivalent for matching. + */ +static bool areEquivalentIndices(Value lhs, Value rhs) { + return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); +} + +/** + * @brief Tensor-transforming ops in a chain that can commute past + * `qtensor.extract` at a different index. + */ +static bool isTensorChainOp(Operation* op) { + return llvm::isa(op); +} + +/** + * @brief Returns the tensor input of a tensor-transforming op. + */ +static Value getTensorChainInput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getDest(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getTensor(); + } + return nullptr; +} + /** * @brief If an ExtractOp consumes an InsertOp with the same index, * return the scalar and the destTensor from the InsertOp directly. @@ -44,10 +75,7 @@ static InsertOp foldExtractAfterInsert(ExtractOp extractOp) { return nullptr; } - Value insertIndex = insertOp.getIndex(); - Value extractIndex = extractOp.getIndex(); - - if (getAsOpFoldResult(insertIndex) != getAsOpFoldResult(extractIndex)) { + if (!areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex())) { return nullptr; } @@ -64,3 +92,89 @@ LogicalResult ExtractOp::fold(FoldAdaptor /*adaptor*/, return failure(); } + +namespace { + +/** + * @brief Remove matching insert-extract pairs through commuting tensor-chain + * operations on different indices. + */ +struct RemoveInsertExtractPair final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static Value getTensorChainOutput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getResult(); + } + if (auto nestedExtractOp = llvm::dyn_cast(op)) { + return nestedExtractOp.getOutTensor(); + } + return nullptr; + } + + static void setTensorChainInput(Operation* op, Value tensor) { + if (llvm::isa(op)) { + op->setOperand(1, tensor); + return; + } + if (llvm::isa(op)) { + op->setOperand(0, tensor); + } + } + + LogicalResult matchAndRewrite(ExtractOp extractOp, + PatternRewriter& rewriter) const override { + llvm::SmallVector traversedOps; + Value currentTensor = extractOp.getTensor(); + InsertOp matchedInsertOp = nullptr; + + while (auto* definingOp = currentTensor.getDefiningOp()) { + if (!isTensorChainOp(definingOp)) { + break; + } + + if (auto insertOp = llvm::dyn_cast(definingOp)) { + if (areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex())) { + matchedInsertOp = insertOp; + break; + } + } else { + auto nestedExtractOp = llvm::cast(definingOp); + if (areEquivalentIndices(nestedExtractOp.getIndex(), + extractOp.getIndex())) { + // Do not reorder reads from the same index. + return failure(); + } + } + + traversedOps.push_back(definingOp); + currentTensor = getTensorChainInput(definingOp); + } + + if (!matchedInsertOp) { + return failure(); + } + + Value outTensor = matchedInsertOp.getDest(); + if (!traversedOps.empty()) { + Operation* oldestCommutedOp = traversedOps.back(); + rewriter.modifyOpInPlace(oldestCommutedOp, [&]() { + setTensorChainInput(oldestCommutedOp, matchedInsertOp.getDest()); + }); + outTensor = getTensorChainOutput(traversedOps.front()); + if (!outTensor) { + return failure(); + } + } + + rewriter.replaceOp(extractOp, {outTensor, matchedInsertOp.getScalar()}); + return success(); + } +}; + +} // namespace + +void ExtractOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 8438e0c10f..3ade4cd5ce 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -139,6 +140,18 @@ static bool areValuesEquivalent(Value lhsValue, Value rhsValue, return true; } +static bool areEquivalentIndices(Value lhsValue, Value rhsValue) { + return getAsOpFoldResult(lhsValue) == getAsOpFoldResult(rhsValue); +} + +static bool areIndexValuesEquivalent(Value lhsValue, Value rhsValue, + ValueEquivalenceMap& valueMap) { + if (areEquivalentIndices(lhsValue, rhsValue)) { + return true; + } + return areValuesEquivalent(lhsValue, rhsValue, valueMap); +} + static bool isQTensorInsertOp(Operation* op) { return llvm::isa(op); } @@ -219,11 +232,14 @@ summarizeInsertGroup(llvm::ArrayRef ops, } // Reordering writes to the same index is not semantics-preserving. - llvm::DenseSet seenIndices; + llvm::SmallVector seenIndices; for (const auto& write : chain.writes) { - if (!seenIndices.insert(write.index).second) { + if (llvm::any_of(seenIndices, [&](Value seenIndex) { + return areEquivalentIndices(seenIndex, write.index); + })) { return false; } + seenIndices.push_back(write.index); } } @@ -247,8 +263,8 @@ static bool areInsertWritesEquivalentRec(const size_t lhsIdx, ValueEquivalenceMap tempMap = valueMap; if (!areValuesEquivalent(lhsWrites[lhsIdx].scalar, rhsWrites[rhsIdx].scalar, tempMap) || - !areValuesEquivalent(lhsWrites[lhsIdx].index, rhsWrites[rhsIdx].index, - tempMap)) { + !areIndexValuesEquivalent(lhsWrites[lhsIdx].index, + rhsWrites[rhsIdx].index, tempMap)) { continue; } diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 8f0cf08693..24ff467ec7 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -31,6 +31,7 @@ #include #include #include +#include using namespace mlir; using namespace mlir::qco; @@ -96,6 +97,49 @@ buildTwoQubitInsertChainProgram(MLIRContext* context, return builder.finalize(); } +OwningOpRef +buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, + const bool swapInsertTargets) { + qco::QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(3); + Value tensorAfterReads = tensor; + Value qubit0 = nullptr; + Value qubit1 = nullptr; + + if (reverseOrder) { + std::tie(tensorAfterReads, qubit1) = + builder.qtensorExtract(tensorAfterReads, 1); + std::tie(tensorAfterReads, qubit0) = + builder.qtensorExtract(tensorAfterReads, 0); + } else { + std::tie(tensorAfterReads, qubit0) = + builder.qtensorExtract(tensorAfterReads, 0); + std::tie(tensorAfterReads, qubit1) = + builder.qtensorExtract(tensorAfterReads, 1); + } + + const int64_t q0Target = 0; + const int64_t q1Target = swapInsertTargets ? 2 : 1; + + Value tensorAfterWrites = tensorAfterReads; + if (reverseOrder) { + tensorAfterWrites = + builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); + tensorAfterWrites = + builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); + } else { + tensorAfterWrites = + builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); + tensorAfterWrites = + builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); + } + + builder.qtensorDealloc(tensorAfterWrites); + return builder.finalize(); +} + } // namespace TEST_P(QCOTest, ProgramEquivalence) { @@ -159,6 +203,40 @@ TEST_F(QCOTest, InsertChainDifferentAssignmentsNotEquivalent) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +TEST_F(QCOTest, MixedExtractInsertPermutationEquivalence) { + auto program = buildMixedExtractInsertProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildMixedExtractInsertProgram(context.get(), true, false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QCOTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { + auto program = buildMixedExtractInsertProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildMixedExtractInsertProgram(context.get(), true, true); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly qco::QCOProgramBuilder builder(context.get()); From dbbf1542a900bd30bb8eec860dc8f21ed160c258 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 12:55:27 +0200 Subject: [PATCH 37/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Turn=20`qco.reset`?= =?UTF-8?q?=20canonicalization=20into=20fold=20and=20handle=20tensor=20ope?= =?UTF-8?q?ration=20chains=20in=20other=20canonicalization?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 1 + .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 67 ++++++++------- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 82 +++++++++++++++++++ 3 files changed, 121 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index efee9623e5..397a38886f 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -180,6 +180,7 @@ def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 5a49cdfe1c..3e29c68472 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -12,8 +12,8 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include +#include #include -#include #include #include #include @@ -22,41 +22,42 @@ using namespace mlir; using namespace mlir::qco; /** - * @brief Check if a `qtensor.extract` operation ultimately originates from a - * `qtensor.alloc` operation. + * @brief Check if a `qtensor.extract` operation is guaranteed to read from a + * `qtensor.alloc` chain. + * + * In QTensor's linear tensor model, reads/writes on different indices commute. + * We can therefore skip over `qtensor.insert` on other indices while tracing + * provenance. A write to the same index invalidates the proof. */ static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { - auto* definingOp = extractOp.getTensor().getDefiningOp(); - if (llvm::isa(definingOp)) { - return true; - } - if (llvm::isa(definingOp)) { - return originatesFromAlloc(llvm::cast(definingOp)); - } - return false; -} + Value currentTensor = extractOp.getTensor(); + const auto extractIndex = getAsOpFoldResult(extractOp.getIndex()); -namespace { + while (auto* definingOp = currentTensor.getDefiningOp()) { + if (llvm::isa(definingOp)) { + return true; + } -/** - * @brief Remove reset operations that immediately follow a `qco.alloc` - * operation. - */ -struct RemoveResetAfterAlloc final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { + currentTensor = nestedExtractOp.getTensor(); + continue; + } - LogicalResult matchAndRewrite(ResetOp op, - PatternRewriter& rewriter) const override { - // Check if the predecessor is an AllocOp - if (auto allocOp = op.getQubitIn().getDefiningOp(); !allocOp) { - return failure(); + if (auto insertOp = llvm::dyn_cast(definingOp)) { + if (getAsOpFoldResult(insertOp.getIndex()) == extractIndex) { + return false; + } + currentTensor = insertOp.getDest(); + continue; } - // Remove the ResetOp - rewriter.replaceOp(op, op.getQubitIn()); - return success(); + return false; } -}; + + return false; +} + +namespace { /** * @brief Remove reset operations that immediately follow a `qtensor.extract` @@ -86,7 +87,15 @@ struct RemoveResetAfterExtract final : OpRewritePattern { } // namespace +OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { + if (getQubitIn().getDefiningOp()) { + return getQubitIn(); + } + + return {}; +} + void ResetOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 24ff467ec7..5173489c14 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -140,6 +140,54 @@ buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, return builder.finalize(); } +OwningOpRef +buildResetWithCommutingInsertProgram(MLIRContext* context, + const bool withReset) { + qco::QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); + auto tensorAfterInsert0 = + builder.qtensorInsert(qubit0, tensorAfterExtract0, 0); + auto [tensorAfterExtract1, qubit1] = + builder.qtensorExtract(tensorAfterInsert0, 1); + if (withReset) { + qubit1 = builder.reset(qubit1); + } + auto tensorFinal = builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); + builder.qtensorDealloc(tensorFinal); + + return builder.finalize(); +} + +OwningOpRef +buildResetWithSameIndexInsertProgram(MLIRContext* context, + const bool withReset) { + qco::QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); + auto [tensorAfterExtract1, qubit1] = + builder.qtensorExtract(tensorAfterExtract0, 1); + qubit1 = builder.h(qubit1); + auto tensorAfterInsert1 = + builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); + auto [tensorAfterReadBack1, qubit1ReadBack] = + builder.qtensorExtract(tensorAfterInsert1, 1); + if (withReset) { + qubit1ReadBack = builder.reset(qubit1ReadBack); + } + auto tensorAfterInsert1ReadBack = + builder.qtensorInsert(qubit1ReadBack, tensorAfterReadBack1, 1); + auto tensorFinal = + builder.qtensorInsert(qubit0, tensorAfterInsert1ReadBack, 0); + builder.qtensorDealloc(tensorFinal); + + return builder.finalize(); +} + } // namespace TEST_P(QCOTest, ProgramEquivalence) { @@ -237,6 +285,40 @@ TEST_F(QCOTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +TEST_F(QCOTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { + auto program = buildResetWithCommutingInsertProgram(context.get(), true); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithCommutingInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QCOTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { + auto program = buildResetWithSameIndexInsertProgram(context.get(), true); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly qco::QCOProgramBuilder builder(context.get()); From badf76fc9ce7c4a3672f90d9ab241ea13b6ec55e Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 13:51:04 +0200 Subject: [PATCH 38/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Turn=20zero-angle=20?= =?UTF-8?q?rotation=20gate=20canonicalizations=20into=20folds?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 10 +++++ mlir/include/mlir/Dialect/QCO/QCOUtils.h | 44 ------------------- .../QCO/IR/Operations/StandardGates/POp.cpp | 23 +++++----- .../QCO/IR/Operations/StandardGates/RXOp.cpp | 22 ++++------ .../QCO/IR/Operations/StandardGates/RXXOp.cpp | 26 +++++------ .../QCO/IR/Operations/StandardGates/RYOp.cpp | 22 ++++------ .../QCO/IR/Operations/StandardGates/RYYOp.cpp | 26 +++++------ .../QCO/IR/Operations/StandardGates/RZOp.cpp | 22 ++++------ .../QCO/IR/Operations/StandardGates/RZXOp.cpp | 25 +++++------ .../QCO/IR/Operations/StandardGates/RZZOp.cpp | 26 +++++------ .../Operations/StandardGates/XXMinusYYOp.cpp | 25 +++++------ .../Operations/StandardGates/XXPlusYYOp.cpp | 25 +++++------ 12 files changed, 119 insertions(+), 177 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 397a38886f..20b1190d15 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -525,6 +525,7 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -553,6 +554,7 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -581,6 +583,7 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -609,6 +612,7 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -842,6 +846,7 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -873,6 +878,7 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -904,6 +910,7 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -935,6 +942,7 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let builders = [OpBuilder<(ins "Value":$qubit0_in, "Value":$qubit1_in, "const std::variant&":$theta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -969,6 +977,7 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", "const std::variant&":$theta, "const std::variant&":$beta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -1003,6 +1012,7 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", "const std::variant&":$theta, "const std::variant&":$beta)>]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 337b767fb9..f888509129 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -242,48 +242,4 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, return success(); } -/** - * @brief Remove a trivial one-target, one-parameter operation - * - * @tparam OpType The type of the operation to be checked. - * @param op The operation instance. - * @param rewriter The pattern rewriter. - * @return LogicalResult Success or failure of the removal. - */ -template -mlir::LogicalResult -removeTrivialOneTargetOneParameter(OpType op, PatternRewriter& rewriter) { - const auto param = utils::valueToDouble(op.getOperand(1)); - if (!param || std::abs(*param) > utils::TOLERANCE) { - return failure(); - } - - // Trivialize operation - rewriter.replaceOp(op, op.getInputQubit(0)); - - return success(); -} - -/** - * @brief Remove a trivial two-target, one-parameter operation - * - * @tparam OpType The type of the operation to be checked. - * @param op The operation instance. - * @param rewriter The pattern rewriter. - * @return LogicalResult Success or failure of the removal. - */ -template -mlir::LogicalResult -removeTrivialTwoTargetOneParameter(OpType op, PatternRewriter& rewriter) { - const auto param = utils::valueToDouble(op.getOperand(2)); - if (!param || std::abs(*param) > utils::TOLERANCE) { - return failure(); - } - - // Trivialize operation - rewriter.replaceOp(op, op.getInputQubits()); - - return success(); -} - } // namespace mlir::qco diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp index 709c7ca0bf..08ee6e068e 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -42,18 +43,6 @@ struct MergeSubsequentP final : OpRewritePattern { } }; -/** - * @brief Remove trivial P operations. - */ -struct RemoveTrivialP final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(POp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void POp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -63,9 +52,17 @@ void POp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult POp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void POp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional POp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp index a57ea85105..eef8bf100c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp @@ -43,18 +43,6 @@ struct MergeSubsequentRX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RX operations. - */ -struct RemoveTrivialRX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RXOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -64,9 +52,17 @@ void RXOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RXOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp index 667720a792..58b2cf1c7e 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp @@ -56,18 +56,6 @@ struct MergeSwappedTargetsRXX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RXX operations. - */ -struct RemoveTrivialRXX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RXXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RXXOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -78,10 +66,20 @@ void RXXOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RXXOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RXXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RXXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp index 634fd106a7..3b9255250a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp @@ -43,18 +43,6 @@ struct MergeSubsequentRY final : OpRewritePattern { } }; -/** - * @brief Remove trivial RY operations. - */ -struct RemoveTrivialRY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RYOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -64,9 +52,17 @@ void RYOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RYOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp index 7709c872e6..18d069263a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp @@ -56,18 +56,6 @@ struct MergeSwappedTargetsRYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial RYY operations. - */ -struct RemoveTrivialRYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -78,10 +66,20 @@ void RYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RYYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp index d888a36689..bd789e770f 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp @@ -42,18 +42,6 @@ struct MergeSubsequentRZ final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZ operations. - */ -struct RemoveTrivialRZ final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZOp op, - PatternRewriter& rewriter) const override { - return removeTrivialOneTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, @@ -63,9 +51,17 @@ void RZOp::build(OpBuilder& odsBuilder, OperationState& odsState, Value qubitIn, build(odsBuilder, odsState, qubitIn, thetaOperand); } +OpFoldResult RZOp::fold(FoldAdaptor /*adaptor*/) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + return getInputQubit(0); + } + return {}; +} + void RZOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RZOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp index a9b9770fff..70bbb999fb 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp @@ -43,18 +43,6 @@ struct MergeSubsequentRZX final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZX operations. - */ -struct RemoveTrivialRZX final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZXOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZXOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -65,9 +53,20 @@ void RZXOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RZXOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RZXOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional RZXOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp index 85ea44a7da..2132bfdd1f 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp @@ -55,18 +55,6 @@ struct MergeSwappedTargetsRZZ final : OpRewritePattern { } }; -/** - * @brief Remove trivial RZZ operations. - */ -struct RemoveTrivialRZZ final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RZZOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void RZZOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -77,10 +65,20 @@ void RZZOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand); } +LogicalResult RZZOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void RZZOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add( - context); + results.add(context); } std::optional RZZOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp index 221f8d6d3c..6e40c801c3 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -77,18 +77,6 @@ struct MergeSubsequentXXMinusYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial XXMinusYY operations. - */ -struct RemoveTrivialXXMinusYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(XXMinusYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void XXMinusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -101,9 +89,20 @@ void XXMinusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand, betaOperand); } +LogicalResult XXMinusYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void XXMinusYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional XXMinusYYOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp index d920e4777f..010913f98a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -76,18 +76,6 @@ struct MergeSubsequentXXPlusYY final : OpRewritePattern { } }; -/** - * @brief Remove trivial XXPlusYY operations. - */ -struct RemoveTrivialXXPlusYY final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(XXPlusYYOp op, - PatternRewriter& rewriter) const override { - return removeTrivialTwoTargetOneParameter(op, rewriter); - } -}; - } // namespace void XXPlusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -100,9 +88,20 @@ void XXPlusYYOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubit0In, qubit1In, thetaOperand, betaOperand); } +LogicalResult XXPlusYYOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (const auto theta = valueToDouble(getTheta()); + theta && std::abs(*theta) <= TOLERANCE) { + results.emplace_back(getInputQubit(0)); + results.emplace_back(getInputQubit(1)); + return success(); + } + return failure(); +} + void XXPlusYYOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional XXPlusYYOp::getUnitaryMatrix() { From f2ac639482b298dc46977207a6ff88505b431299 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 13:52:32 +0200 Subject: [PATCH 39/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Expand=20QTensor=20c?= =?UTF-8?q?ommutation=20canonicalizations=20to=20also=20handle=20slices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 2 + .../QTensor/IR/Operations/ExtractOp.cpp | 110 ++++++--- .../QTensor/IR/Operations/ExtractSliceOp.cpp | 223 +++++++++++++++++- .../QTensor/IR/Operations/InsertOp.cpp | 44 ++++ .../QTensor/IR/Operations/InsertSliceOp.cpp | 177 +++++++++++++- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 72 ++++++ 6 files changed, 597 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 29227d12a8..d7233f8460 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -176,6 +176,7 @@ def ExtractSliceOp let builders = [OpBuilder<(ins "Value":$tensor, "Value":$offset, "Value":$size, CArg<"ArrayRef", "{}">:$attrs)>]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -245,6 +246,7 @@ def InsertSliceOp attr-dict `:` type($source) `into` type($dest) }]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index e17f6e6172..0d4b081dec 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -19,6 +19,8 @@ #include #include +#include + using namespace mlir; using namespace mlir::qtensor; @@ -37,6 +39,8 @@ LogicalResult ExtractOp::verify() { return success(); } +enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Unknown }; + /** * @brief Checks whether two index values are equivalent for matching. */ @@ -45,11 +49,32 @@ static bool areEquivalentIndices(Value lhs, Value rhs) { } /** - * @brief Tensor-transforming ops in a chain that can commute past - * `qtensor.extract` at a different index. + * @brief Classify the relation between a scalar index and a slice range. + */ +static AccessRelation classifyIndexAndRange(Value index, Value offset, + Value size) { + if (areEquivalentIndices(index, offset)) { + return AccessRelation::Overlap; + } + + const auto indexValue = getConstantIntValue(index); + const auto offsetValue = getConstantIntValue(offset); + const auto sizeValue = getConstantIntValue(size); + if (!indexValue || !offsetValue || !sizeValue) { + return AccessRelation::Unknown; + } + + if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { + return AccessRelation::Disjoint; + } + return AccessRelation::Overlap; +} + +/** + * @brief Tensor-transforming ops in a chain that can commute with extracts. */ static bool isTensorChainOp(Operation* op) { - return llvm::isa(op); + return llvm::isa(op); } /** @@ -62,9 +87,47 @@ static Value getTensorChainInput(Operation* op) { if (auto extractOp = llvm::dyn_cast(op)) { return extractOp.getTensor(); } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getDest(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getTensor(); + } return nullptr; } +/** + * @brief Returns the tensor output of a tensor-transforming op. + */ +static Value getTensorChainOutput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getResult(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getOutTensor(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getResult(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getOutTensor(); + } + return nullptr; +} + +/** + * @brief Rewire the tensor input of a tensor-transforming op. + */ +static void setTensorChainInput(Operation* op, Value tensor) { + if (llvm::isa(op)) { + op->setOperand(1, tensor); + return; + } + if (llvm::isa(op)) { + op->setOperand(0, tensor); + } +} + /** * @brief If an ExtractOp consumes an InsertOp with the same index, * return the scalar and the destTensor from the InsertOp directly. @@ -96,32 +159,12 @@ LogicalResult ExtractOp::fold(FoldAdaptor /*adaptor*/, namespace { /** - * @brief Remove matching insert-extract pairs through commuting tensor-chain - * operations on different indices. + * @brief Remove matching insert-extract pairs through commuting disjoint + * tensor-chain operations. */ struct RemoveInsertExtractPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - static Value getTensorChainOutput(Operation* op) { - if (auto insertOp = llvm::dyn_cast(op)) { - return insertOp.getResult(); - } - if (auto nestedExtractOp = llvm::dyn_cast(op)) { - return nestedExtractOp.getOutTensor(); - } - return nullptr; - } - - static void setTensorChainInput(Operation* op, Value tensor) { - if (llvm::isa(op)) { - op->setOperand(1, tensor); - return; - } - if (llvm::isa(op)) { - op->setOperand(0, tensor); - } - } - LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter& rewriter) const override { llvm::SmallVector traversedOps; @@ -138,13 +181,26 @@ struct RemoveInsertExtractPair final : OpRewritePattern { matchedInsertOp = insertOp; break; } - } else { - auto nestedExtractOp = llvm::cast(definingOp); + } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { if (areEquivalentIndices(nestedExtractOp.getIndex(), extractOp.getIndex())) { // Do not reorder reads from the same index. return failure(); } + } else if (auto insertSliceOp = + llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange( + extractOp.getIndex(), insertSliceOp.getOffset(), + insertSliceOp.getSize()) != AccessRelation::Disjoint) { + return failure(); + } + } else if (auto extractSliceOp = + llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange( + extractOp.getIndex(), extractSliceOp.getOffset(), + extractSliceOp.getSize()) != AccessRelation::Disjoint) { + return failure(); + } } traversedOps.push_back(definingOp); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp index 3c7648b04d..5d39329152 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp @@ -10,19 +10,151 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include #include #include +#include #include #include +#include #include #include +#include + using namespace mlir; using namespace mlir::qtensor; +enum class RangeRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; + +/** + * @brief Checks whether two index values are equivalent for matching. + */ +static bool areEquivalentIndices(Value lhs, Value rhs) { + return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); +} + +/** + * @brief Checks whether two slice ranges are equivalent for matching. + */ +static bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, + Value rhsSize) { + return areEquivalentIndices(lhsOffset, rhsOffset) && + areEquivalentIndices(lhsSize, rhsSize); +} + +/** + * @brief Classify the relation between two slice ranges. + */ +static RangeRelation classifyRanges(Value lhsOffset, Value lhsSize, + Value rhsOffset, Value rhsSize) { + if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { + return RangeRelation::Equal; + } + + const auto lhsOffsetValue = getConstantIntValue(lhsOffset); + const auto lhsSizeValue = getConstantIntValue(lhsSize); + const auto rhsOffsetValue = getConstantIntValue(rhsOffset); + const auto rhsSizeValue = getConstantIntValue(rhsSize); + if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { + if (areEquivalentIndices(lhsOffset, rhsOffset)) { + return RangeRelation::Overlap; + } + return RangeRelation::Unknown; + } + + const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; + const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; + if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { + return RangeRelation::Disjoint; + } + return RangeRelation::Overlap; +} + +/** + * @brief Classify the relation between a scalar index and a slice range. + */ +static RangeRelation classifyIndexAndRange(Value index, Value offset, + Value size) { + if (areEquivalentIndices(index, offset)) { + return RangeRelation::Overlap; + } + + const auto indexValue = getConstantIntValue(index); + const auto offsetValue = getConstantIntValue(offset); + const auto sizeValue = getConstantIntValue(size); + if (!indexValue || !offsetValue || !sizeValue) { + return RangeRelation::Unknown; + } + + if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { + return RangeRelation::Disjoint; + } + return RangeRelation::Overlap; +} + +/** + * @brief Tensor-transforming ops in a chain that can commute with slice + * extracts. + */ +static bool isTensorChainOp(Operation* op) { + return llvm::isa(op); +} + +/** + * @brief Returns the tensor input of a tensor-transforming op. + */ +static Value getTensorChainInput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getDest(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getTensor(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getDest(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getTensor(); + } + return nullptr; +} + +/** + * @brief Returns the tensor output of a tensor-transforming op. + */ +static Value getTensorChainOutput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getResult(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getOutTensor(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getResult(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getOutTensor(); + } + return nullptr; +} + +/** + * @brief Rewire the tensor input of a tensor-transforming op. + */ +static void setTensorChainInput(Operation* op, Value tensor) { + if (llvm::isa(op)) { + op->setOperand(1, tensor); + return; + } + if (llvm::isa(op)) { + op->setOperand(0, tensor); + } +} + void ExtractSliceOp::build(OpBuilder& b, OperationState& result, Value tensor, Value offset, Value size, ArrayRef attrs) { @@ -83,8 +215,8 @@ foldExtractAfterInsertSlice(ExtractSliceOp extractSliceOp) { auto insertSize = insertSliceOp.getSize(); auto extractSize = extractSliceOp.getSize(); - if (getAsOpFoldResult(insertOffset) != getAsOpFoldResult(extractOffset) || - getAsOpFoldResult(insertSize) != getAsOpFoldResult(extractSize)) { + if (!areEquivalentRanges(insertOffset, insertSize, extractOffset, + extractSize)) { return nullptr; } @@ -101,3 +233,90 @@ LogicalResult ExtractSliceOp::fold(FoldAdaptor /*adaptor*/, return failure(); } + +namespace { + +/** + * @brief Remove matching insert_slice-extract_slice pairs through commuting + * disjoint tensor-chain operations. + */ +struct RemoveInsertSliceExtractSlicePair final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp, + PatternRewriter& rewriter) const override { + llvm::SmallVector traversedOps; + Value currentTensor = extractSliceOp.getTensor(); + InsertSliceOp matchedInsertSliceOp = nullptr; + + while (auto* definingOp = currentTensor.getDefiningOp()) { + if (!isTensorChainOp(definingOp)) { + break; + } + + if (auto insertSliceOp = llvm::dyn_cast(definingOp)) { + const auto relation = classifyRanges( + insertSliceOp.getOffset(), insertSliceOp.getSize(), + extractSliceOp.getOffset(), extractSliceOp.getSize()); + if (relation == RangeRelation::Equal) { + matchedInsertSliceOp = insertSliceOp; + break; + } + if (relation != RangeRelation::Disjoint) { + return failure(); + } + } else if (auto insertOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange( + insertOp.getIndex(), extractSliceOp.getOffset(), + extractSliceOp.getSize()) != RangeRelation::Disjoint) { + return failure(); + } + } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange( + nestedExtractOp.getIndex(), extractSliceOp.getOffset(), + extractSliceOp.getSize()) != RangeRelation::Disjoint) { + return failure(); + } + } else if (auto nestedExtractSliceOp = + llvm::dyn_cast(definingOp)) { + if (classifyRanges( + nestedExtractSliceOp.getOffset(), + nestedExtractSliceOp.getSize(), extractSliceOp.getOffset(), + extractSliceOp.getSize()) != RangeRelation::Disjoint) { + return failure(); + } + } + + traversedOps.push_back(definingOp); + currentTensor = getTensorChainInput(definingOp); + } + + if (!matchedInsertSliceOp) { + return failure(); + } + + Value outTensor = matchedInsertSliceOp.getDest(); + if (!traversedOps.empty()) { + Operation* oldestCommutedOp = traversedOps.back(); + rewriter.modifyOpInPlace(oldestCommutedOp, [&]() { + setTensorChainInput(oldestCommutedOp, matchedInsertSliceOp.getDest()); + }); + outTensor = getTensorChainOutput(traversedOps.front()); + if (!outTensor) { + return failure(); + } + } + + rewriter.replaceOp(extractSliceOp, + {outTensor, matchedInsertSliceOp.getSource()}); + return success(); + } +}; + +} // namespace + +void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 31c07bca09..9327f676b0 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -20,9 +20,13 @@ #include #include +#include + using namespace mlir; using namespace mlir::qtensor; +enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Unknown }; + /** * @brief Checks whether two index values are equivalent for matching. */ @@ -30,6 +34,28 @@ static bool areEquivalentIndices(Value lhs, Value rhs) { return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); } +/** + * @brief Classify the relation between a scalar index and a slice range. + */ +static AccessRelation classifyIndexAndRange(Value index, Value offset, + Value size) { + if (areEquivalentIndices(index, offset)) { + return AccessRelation::Overlap; + } + + const auto indexValue = getConstantIntValue(index); + const auto offsetValue = getConstantIntValue(offset); + const auto sizeValue = getConstantIntValue(size); + if (!indexValue || !offsetValue || !sizeValue) { + return AccessRelation::Unknown; + } + + if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { + return AccessRelation::Disjoint; + } + return AccessRelation::Overlap; +} + /** * @brief Checks whether removing an extract-insert pair is linearity-safe. */ @@ -83,6 +109,15 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { current = nestedInsertOp.getDest(); continue; } + if (auto nestedInsertSliceOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange(index, nestedInsertSliceOp.getOffset(), + nestedInsertSliceOp.getSize()) != + AccessRelation::Disjoint) { + return nullptr; + } + current = nestedInsertSliceOp.getDest(); + continue; + } if (auto extractOp = llvm::dyn_cast(definingOp)) { if (areEquivalentIndices(extractOp.getIndex(), index)) { @@ -91,6 +126,15 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { current = extractOp.getTensor(); continue; } + if (auto extractSliceOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange(index, extractSliceOp.getOffset(), + extractSliceOp.getSize()) != + AccessRelation::Disjoint) { + return nullptr; + } + current = extractSliceOp.getTensor(); + continue; + } break; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp index 1a7b6526ab..009aaec400 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp @@ -10,16 +10,154 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include +#include #include #include +#include #include #include +#include + using namespace mlir; using namespace mlir::qtensor; +enum class RangeRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; + +/** + * @brief Checks whether two index values are equivalent for matching. + */ +static bool areEquivalentIndices(Value lhs, Value rhs) { + return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); +} + +/** + * @brief Checks whether two slice ranges are equivalent for matching. + */ +static bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, + Value rhsSize) { + return areEquivalentIndices(lhsOffset, rhsOffset) && + areEquivalentIndices(lhsSize, rhsSize); +} + +/** + * @brief Classify the relation between a scalar index and a slice range. + */ +static RangeRelation classifyIndexAndRange(Value index, Value offset, + Value size) { + if (areEquivalentIndices(index, offset)) { + return RangeRelation::Overlap; + } + + const auto indexValue = getConstantIntValue(index); + const auto offsetValue = getConstantIntValue(offset); + const auto sizeValue = getConstantIntValue(size); + if (!indexValue || !offsetValue || !sizeValue) { + return RangeRelation::Unknown; + } + + if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { + return RangeRelation::Disjoint; + } + return RangeRelation::Overlap; +} + +/** + * @brief Classify the relation between two slice ranges. + */ +static RangeRelation classifyRanges(Value lhsOffset, Value lhsSize, + Value rhsOffset, Value rhsSize) { + if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { + return RangeRelation::Equal; + } + + const auto lhsOffsetValue = getConstantIntValue(lhsOffset); + const auto lhsSizeValue = getConstantIntValue(lhsSize); + const auto rhsOffsetValue = getConstantIntValue(rhsOffset); + const auto rhsSizeValue = getConstantIntValue(rhsSize); + if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { + if (areEquivalentIndices(lhsOffset, rhsOffset)) { + return RangeRelation::Overlap; + } + return RangeRelation::Unknown; + } + + const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; + const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; + if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { + return RangeRelation::Disjoint; + } + return RangeRelation::Overlap; +} + +/** + * @brief Checks whether removing an extract_slice-insert_slice pair is + * linearity-safe. + */ +static bool +isRemovableExtractSliceInsertSlicePair(InsertSliceOp insertSliceOp, + ExtractSliceOp extractSliceOp) { + return insertSliceOp.getSource() == extractSliceOp.getResult() && + areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), + extractSliceOp.getOffset(), + extractSliceOp.getSize()); +} + +/** + * @brief Find a matching `qtensor.extract_slice` for an insert_slice range in + * a tensor chain by traversing scalar and slice tensor operations. + */ +static ExtractSliceOp +findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { + Value current = tensor; + while (Operation* definingOp = current.getDefiningOp()) { + if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange(nestedInsertOp.getIndex(), offset, size) != + RangeRelation::Disjoint) { + return nullptr; + } + current = nestedInsertOp.getDest(); + continue; + } + if (auto nestedInsertSliceOp = llvm::dyn_cast(definingOp)) { + if (classifyRanges(nestedInsertSliceOp.getOffset(), + nestedInsertSliceOp.getSize(), offset, + size) != RangeRelation::Disjoint) { + return nullptr; + } + current = nestedInsertSliceOp.getDest(); + continue; + } + if (auto extractOp = llvm::dyn_cast(definingOp)) { + if (classifyIndexAndRange(extractOp.getIndex(), offset, size) != + RangeRelation::Disjoint) { + return nullptr; + } + current = extractOp.getTensor(); + continue; + } + if (auto extractSliceOp = llvm::dyn_cast(definingOp)) { + const auto relation = classifyRanges( + extractSliceOp.getOffset(), extractSliceOp.getSize(), offset, size); + if (relation == RangeRelation::Equal) { + return extractSliceOp; + } + if (relation != RangeRelation::Disjoint) { + return nullptr; + } + current = extractSliceOp.getTensor(); + continue; + } + + break; + } + + return nullptr; +} + LogicalResult InsertSliceOp::verify() { auto srcDim = getSource().getType().getDimSize(0); auto dstDim = getDest().getType().getDimSize(0); @@ -69,8 +207,8 @@ static Value foldInsertAfterExtractSlice(InsertSliceOp insertSliceOp) { auto insertSize = insertSliceOp.getSize(); auto extractSize = extractSliceOp.getSize(); - if (getAsOpFoldResult(insertOffset) != getAsOpFoldResult(extractOffset) || - getAsOpFoldResult(insertSize) != getAsOpFoldResult(extractSize)) { + if (!areEquivalentRanges(insertOffset, insertSize, extractOffset, + extractSize)) { return nullptr; } @@ -84,3 +222,38 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor /*adaptor*/) { return {}; } + +namespace { + +/** + * @brief Remove matching `qtensor.insert_slice` and `qtensor.extract_slice` + * pairs through commuting disjoint tensor-chain operations. + */ +struct RemoveExtractSliceInsertSlicePair final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertSliceOp op, + PatternRewriter& rewriter) const override { + auto extractSliceOp = findMatchingExtractSliceInTensorChain( + op.getDest(), op.getOffset(), op.getSize()); + if (!extractSliceOp) { + return failure(); + } + + if (!isRemovableExtractSliceInsertSlicePair(op, extractSliceOp)) { + return failure(); + } + + rewriter.replaceOp(op, op.getDest()); + rewriter.replaceOp(extractSliceOp, {extractSliceOp.getTensor(), nullptr}); + return success(); + } +}; + +} // namespace + +void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 5173489c14..9fb458a9df 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -140,6 +140,40 @@ buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, return builder.finalize(); } +OwningOpRef +buildMixedScalarSliceInsertProgram(MLIRContext* context, + const bool reverseOrder, const bool overlap, + const bool mutateScalar) { + qco::QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(6); + auto [tensorAfterSliceExtract, slice] = + builder.qtensorExtractSlice(tensor, 1, 2); + const int64_t scalarIndex = overlap ? 1 : 5; + auto [tensorAfterScalarExtract, scalar] = + builder.qtensorExtract(tensorAfterSliceExtract, scalarIndex); + if (mutateScalar) { + scalar = builder.h(scalar); + } + + Value tensorAfterWrites = tensorAfterScalarExtract; + if (reverseOrder) { + tensorAfterWrites = + builder.qtensorInsertSlice(slice, tensorAfterWrites, 1, 2); + tensorAfterWrites = + builder.qtensorInsert(scalar, tensorAfterWrites, scalarIndex); + } else { + tensorAfterWrites = + builder.qtensorInsert(scalar, tensorAfterWrites, scalarIndex); + tensorAfterWrites = + builder.qtensorInsertSlice(slice, tensorAfterWrites, 1, 2); + } + + builder.qtensorDealloc(tensorAfterWrites); + return builder.finalize(); +} + OwningOpRef buildResetWithCommutingInsertProgram(MLIRContext* context, const bool withReset) { @@ -285,6 +319,44 @@ TEST_F(QCOTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +TEST_F(QCOTest, MixedScalarSliceInsertPermutationEquivalence) { + auto program = + buildMixedScalarSliceInsertProgram(context.get(), false, false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = + buildMixedScalarSliceInsertProgram(context.get(), true, false, false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QCOTest, MixedScalarSliceInsertOverlapNotEquivalent) { + auto program = + buildMixedScalarSliceInsertProgram(context.get(), false, true, true); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runCanonicalizationPasses(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = + buildMixedScalarSliceInsertProgram(context.get(), true, true, true); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runCanonicalizationPasses(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + TEST_F(QCOTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { auto program = buildResetWithCommutingInsertProgram(context.get(), true); ASSERT_TRUE(program); From a2ab458759b37fd22fdde5b7d2f88630ec481e34 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 14:04:48 +0200 Subject: [PATCH 40/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Extract=20common=20c?= =?UTF-8?q?ode=20from=20QTensor=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../mlir/Dialect/QTensor/IR/QTensorUtils.h | 152 ++++++++++++++++++ .../QTensor/IR/Operations/ExtractOp.cpp | 92 +---------- .../QTensor/IR/Operations/ExtractSliceOp.cpp | 150 ++--------------- .../QTensor/IR/Operations/InsertOp.cpp | 36 +---- .../QTensor/IR/Operations/InsertSliceOp.cpp | 91 ++--------- 5 files changed, 173 insertions(+), 348 deletions(-) create mode 100644 mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h new file mode 100644 index 0000000000..145e5e017b --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" + +#include +#include +#include + +#include + +namespace mlir::qtensor { + +/** + * @brief Relation of two tensor accesses. + */ +enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; + +/** + * @brief Checks whether two index values are equivalent for matching. + */ +inline bool areEquivalentIndices(Value lhs, Value rhs) { + return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); +} + +/** + * @brief Checks whether two slice ranges are equivalent for matching. + */ +inline bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, + Value rhsSize) { + return areEquivalentIndices(lhsOffset, rhsOffset) && + areEquivalentIndices(lhsSize, rhsSize); +} + +/** + * @brief Classify the relation between a scalar index and a slice range. + */ +inline AccessRelation classifyIndexAndRange(Value index, Value offset, + Value size) { + if (areEquivalentIndices(index, offset)) { + return AccessRelation::Overlap; + } + + const auto indexValue = getConstantIntValue(index); + const auto offsetValue = getConstantIntValue(offset); + const auto sizeValue = getConstantIntValue(size); + if (!indexValue || !offsetValue || !sizeValue) { + return AccessRelation::Unknown; + } + + if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { + return AccessRelation::Disjoint; + } + return AccessRelation::Overlap; +} + +/** + * @brief Classify the relation between two slice ranges. + */ +inline AccessRelation classifyRanges(Value lhsOffset, Value lhsSize, + Value rhsOffset, Value rhsSize) { + if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { + return AccessRelation::Equal; + } + + const auto lhsOffsetValue = getConstantIntValue(lhsOffset); + const auto lhsSizeValue = getConstantIntValue(lhsSize); + const auto rhsOffsetValue = getConstantIntValue(rhsOffset); + const auto rhsSizeValue = getConstantIntValue(rhsSize); + if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { + if (areEquivalentIndices(lhsOffset, rhsOffset)) { + return AccessRelation::Overlap; + } + return AccessRelation::Unknown; + } + + const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; + const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; + if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { + return AccessRelation::Disjoint; + } + return AccessRelation::Overlap; +} + +/** + * @brief Tensor-transforming ops in a chain that can commute by index/range. + */ +inline bool isTensorChainOp(Operation* op) { + return llvm::isa(op); +} + +/** + * @brief Returns the tensor input of a tensor-transforming op. + */ +inline Value getTensorChainInput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getDest(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getTensor(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getDest(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getTensor(); + } + return nullptr; +} + +/** + * @brief Returns the tensor output of a tensor-transforming op. + */ +inline Value getTensorChainOutput(Operation* op) { + if (auto insertOp = llvm::dyn_cast(op)) { + return insertOp.getResult(); + } + if (auto extractOp = llvm::dyn_cast(op)) { + return extractOp.getOutTensor(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + return insertSliceOp.getResult(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + return extractSliceOp.getOutTensor(); + } + return nullptr; +} + +/** + * @brief Rewire the tensor input of a tensor-transforming op. + */ +inline void setTensorChainInput(Operation* op, Value tensor) { + if (llvm::isa(op)) { + op->setOperand(1, tensor); + return; + } + if (llvm::isa(op)) { + op->setOperand(0, tensor); + } +} + +} // namespace mlir::qtensor diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 0d4b081dec..2b4874fb28 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include #include @@ -19,8 +20,6 @@ #include #include -#include - using namespace mlir; using namespace mlir::qtensor; @@ -39,95 +38,6 @@ LogicalResult ExtractOp::verify() { return success(); } -enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Unknown }; - -/** - * @brief Checks whether two index values are equivalent for matching. - */ -static bool areEquivalentIndices(Value lhs, Value rhs) { - return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); -} - -/** - * @brief Classify the relation between a scalar index and a slice range. - */ -static AccessRelation classifyIndexAndRange(Value index, Value offset, - Value size) { - if (areEquivalentIndices(index, offset)) { - return AccessRelation::Overlap; - } - - const auto indexValue = getConstantIntValue(index); - const auto offsetValue = getConstantIntValue(offset); - const auto sizeValue = getConstantIntValue(size); - if (!indexValue || !offsetValue || !sizeValue) { - return AccessRelation::Unknown; - } - - if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { - return AccessRelation::Disjoint; - } - return AccessRelation::Overlap; -} - -/** - * @brief Tensor-transforming ops in a chain that can commute with extracts. - */ -static bool isTensorChainOp(Operation* op) { - return llvm::isa(op); -} - -/** - * @brief Returns the tensor input of a tensor-transforming op. - */ -static Value getTensorChainInput(Operation* op) { - if (auto insertOp = llvm::dyn_cast(op)) { - return insertOp.getDest(); - } - if (auto extractOp = llvm::dyn_cast(op)) { - return extractOp.getTensor(); - } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getDest(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getTensor(); - } - return nullptr; -} - -/** - * @brief Returns the tensor output of a tensor-transforming op. - */ -static Value getTensorChainOutput(Operation* op) { - if (auto insertOp = llvm::dyn_cast(op)) { - return insertOp.getResult(); - } - if (auto extractOp = llvm::dyn_cast(op)) { - return extractOp.getOutTensor(); - } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getResult(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getOutTensor(); - } - return nullptr; -} - -/** - * @brief Rewire the tensor input of a tensor-transforming op. - */ -static void setTensorChainInput(Operation* op, Value tensor) { - if (llvm::isa(op)) { - op->setOperand(1, tensor); - return; - } - if (llvm::isa(op)) { - op->setOperand(0, tensor); - } -} - /** * @brief If an ExtractOp consumes an InsertOp with the same index, * return the scalar and the destTensor from the InsertOp directly. diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp index 5d39329152..c3fdb0be8e 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include #include @@ -23,138 +24,9 @@ #include #include -#include - using namespace mlir; using namespace mlir::qtensor; -enum class RangeRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; - -/** - * @brief Checks whether two index values are equivalent for matching. - */ -static bool areEquivalentIndices(Value lhs, Value rhs) { - return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); -} - -/** - * @brief Checks whether two slice ranges are equivalent for matching. - */ -static bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, - Value rhsSize) { - return areEquivalentIndices(lhsOffset, rhsOffset) && - areEquivalentIndices(lhsSize, rhsSize); -} - -/** - * @brief Classify the relation between two slice ranges. - */ -static RangeRelation classifyRanges(Value lhsOffset, Value lhsSize, - Value rhsOffset, Value rhsSize) { - if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { - return RangeRelation::Equal; - } - - const auto lhsOffsetValue = getConstantIntValue(lhsOffset); - const auto lhsSizeValue = getConstantIntValue(lhsSize); - const auto rhsOffsetValue = getConstantIntValue(rhsOffset); - const auto rhsSizeValue = getConstantIntValue(rhsSize); - if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { - if (areEquivalentIndices(lhsOffset, rhsOffset)) { - return RangeRelation::Overlap; - } - return RangeRelation::Unknown; - } - - const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; - const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; - if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { - return RangeRelation::Disjoint; - } - return RangeRelation::Overlap; -} - -/** - * @brief Classify the relation between a scalar index and a slice range. - */ -static RangeRelation classifyIndexAndRange(Value index, Value offset, - Value size) { - if (areEquivalentIndices(index, offset)) { - return RangeRelation::Overlap; - } - - const auto indexValue = getConstantIntValue(index); - const auto offsetValue = getConstantIntValue(offset); - const auto sizeValue = getConstantIntValue(size); - if (!indexValue || !offsetValue || !sizeValue) { - return RangeRelation::Unknown; - } - - if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { - return RangeRelation::Disjoint; - } - return RangeRelation::Overlap; -} - -/** - * @brief Tensor-transforming ops in a chain that can commute with slice - * extracts. - */ -static bool isTensorChainOp(Operation* op) { - return llvm::isa(op); -} - -/** - * @brief Returns the tensor input of a tensor-transforming op. - */ -static Value getTensorChainInput(Operation* op) { - if (auto insertOp = llvm::dyn_cast(op)) { - return insertOp.getDest(); - } - if (auto extractOp = llvm::dyn_cast(op)) { - return extractOp.getTensor(); - } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getDest(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getTensor(); - } - return nullptr; -} - -/** - * @brief Returns the tensor output of a tensor-transforming op. - */ -static Value getTensorChainOutput(Operation* op) { - if (auto insertOp = llvm::dyn_cast(op)) { - return insertOp.getResult(); - } - if (auto extractOp = llvm::dyn_cast(op)) { - return extractOp.getOutTensor(); - } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getResult(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getOutTensor(); - } - return nullptr; -} - -/** - * @brief Rewire the tensor input of a tensor-transforming op. - */ -static void setTensorChainInput(Operation* op, Value tensor) { - if (llvm::isa(op)) { - op->setOperand(1, tensor); - return; - } - if (llvm::isa(op)) { - op->setOperand(0, tensor); - } -} - void ExtractSliceOp::build(OpBuilder& b, OperationState& result, Value tensor, Value offset, Value size, ArrayRef attrs) { @@ -210,13 +82,9 @@ foldExtractAfterInsertSlice(ExtractSliceOp extractSliceOp) { return nullptr; } - auto insertOffset = insertSliceOp.getOffset(); - auto extractOffset = extractSliceOp.getOffset(); - auto insertSize = insertSliceOp.getSize(); - auto extractSize = extractSliceOp.getSize(); - - if (!areEquivalentRanges(insertOffset, insertSize, extractOffset, - extractSize)) { + if (!areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), + extractSliceOp.getOffset(), + extractSliceOp.getSize())) { return nullptr; } @@ -259,23 +127,23 @@ struct RemoveInsertSliceExtractSlicePair final const auto relation = classifyRanges( insertSliceOp.getOffset(), insertSliceOp.getSize(), extractSliceOp.getOffset(), extractSliceOp.getSize()); - if (relation == RangeRelation::Equal) { + if (relation == AccessRelation::Equal) { matchedInsertSliceOp = insertSliceOp; break; } - if (relation != RangeRelation::Disjoint) { + if (relation != AccessRelation::Disjoint) { return failure(); } } else if (auto insertOp = llvm::dyn_cast(definingOp)) { if (classifyIndexAndRange( insertOp.getIndex(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != RangeRelation::Disjoint) { + extractSliceOp.getSize()) != AccessRelation::Disjoint) { return failure(); } } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { if (classifyIndexAndRange( nestedExtractOp.getIndex(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != RangeRelation::Disjoint) { + extractSliceOp.getSize()) != AccessRelation::Disjoint) { return failure(); } } else if (auto nestedExtractSliceOp = @@ -283,7 +151,7 @@ struct RemoveInsertSliceExtractSlicePair final if (classifyRanges( nestedExtractSliceOp.getOffset(), nestedExtractSliceOp.getSize(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != RangeRelation::Disjoint) { + extractSliceOp.getSize()) != AccessRelation::Disjoint) { return failure(); } } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 9327f676b0..c8a8839492 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include #include @@ -20,42 +21,9 @@ #include #include -#include - using namespace mlir; using namespace mlir::qtensor; -enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Unknown }; - -/** - * @brief Checks whether two index values are equivalent for matching. - */ -static bool areEquivalentIndices(Value lhs, Value rhs) { - return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); -} - -/** - * @brief Classify the relation between a scalar index and a slice range. - */ -static AccessRelation classifyIndexAndRange(Value index, Value offset, - Value size) { - if (areEquivalentIndices(index, offset)) { - return AccessRelation::Overlap; - } - - const auto indexValue = getConstantIntValue(index); - const auto offsetValue = getConstantIntValue(offset); - const auto sizeValue = getConstantIntValue(size); - if (!indexValue || !offsetValue || !sizeValue) { - return AccessRelation::Unknown; - } - - if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { - return AccessRelation::Disjoint; - } - return AccessRelation::Overlap; -} - /** * @brief Checks whether removing an extract-insert pair is linearity-safe. */ @@ -96,7 +64,7 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { /** * @brief Find a matching `qtensor.extract` for an insert index in a tensor - * chain by traversing nested `qtensor.insert` and `qtensor.extract` ops. + * chain by traversing nested scalar and slice tensor ops. */ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { Value current = tensor; diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp index 009aaec400..6d2f6628a5 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include #include @@ -20,79 +21,9 @@ #include #include -#include - using namespace mlir; using namespace mlir::qtensor; -enum class RangeRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; - -/** - * @brief Checks whether two index values are equivalent for matching. - */ -static bool areEquivalentIndices(Value lhs, Value rhs) { - return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); -} - -/** - * @brief Checks whether two slice ranges are equivalent for matching. - */ -static bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, - Value rhsSize) { - return areEquivalentIndices(lhsOffset, rhsOffset) && - areEquivalentIndices(lhsSize, rhsSize); -} - -/** - * @brief Classify the relation between a scalar index and a slice range. - */ -static RangeRelation classifyIndexAndRange(Value index, Value offset, - Value size) { - if (areEquivalentIndices(index, offset)) { - return RangeRelation::Overlap; - } - - const auto indexValue = getConstantIntValue(index); - const auto offsetValue = getConstantIntValue(offset); - const auto sizeValue = getConstantIntValue(size); - if (!indexValue || !offsetValue || !sizeValue) { - return RangeRelation::Unknown; - } - - if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { - return RangeRelation::Disjoint; - } - return RangeRelation::Overlap; -} - -/** - * @brief Classify the relation between two slice ranges. - */ -static RangeRelation classifyRanges(Value lhsOffset, Value lhsSize, - Value rhsOffset, Value rhsSize) { - if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { - return RangeRelation::Equal; - } - - const auto lhsOffsetValue = getConstantIntValue(lhsOffset); - const auto lhsSizeValue = getConstantIntValue(lhsSize); - const auto rhsOffsetValue = getConstantIntValue(rhsOffset); - const auto rhsSizeValue = getConstantIntValue(rhsSize); - if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { - if (areEquivalentIndices(lhsOffset, rhsOffset)) { - return RangeRelation::Overlap; - } - return RangeRelation::Unknown; - } - - const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; - const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; - if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { - return RangeRelation::Disjoint; - } - return RangeRelation::Overlap; -} - /** * @brief Checks whether removing an extract_slice-insert_slice pair is * linearity-safe. @@ -116,7 +47,7 @@ findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { while (Operation* definingOp = current.getDefiningOp()) { if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { if (classifyIndexAndRange(nestedInsertOp.getIndex(), offset, size) != - RangeRelation::Disjoint) { + AccessRelation::Disjoint) { return nullptr; } current = nestedInsertOp.getDest(); @@ -125,7 +56,7 @@ findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { if (auto nestedInsertSliceOp = llvm::dyn_cast(definingOp)) { if (classifyRanges(nestedInsertSliceOp.getOffset(), nestedInsertSliceOp.getSize(), offset, - size) != RangeRelation::Disjoint) { + size) != AccessRelation::Disjoint) { return nullptr; } current = nestedInsertSliceOp.getDest(); @@ -133,7 +64,7 @@ findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { } if (auto extractOp = llvm::dyn_cast(definingOp)) { if (classifyIndexAndRange(extractOp.getIndex(), offset, size) != - RangeRelation::Disjoint) { + AccessRelation::Disjoint) { return nullptr; } current = extractOp.getTensor(); @@ -142,10 +73,10 @@ findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { if (auto extractSliceOp = llvm::dyn_cast(definingOp)) { const auto relation = classifyRanges( extractSliceOp.getOffset(), extractSliceOp.getSize(), offset, size); - if (relation == RangeRelation::Equal) { + if (relation == AccessRelation::Equal) { return extractSliceOp; } - if (relation != RangeRelation::Disjoint) { + if (relation != AccessRelation::Disjoint) { return nullptr; } current = extractSliceOp.getTensor(); @@ -202,13 +133,9 @@ static Value foldInsertAfterExtractSlice(InsertSliceOp insertSliceOp) { return nullptr; } - auto insertOffset = insertSliceOp.getOffset(); - auto extractOffset = extractSliceOp.getOffset(); - auto insertSize = insertSliceOp.getSize(); - auto extractSize = extractSliceOp.getSize(); - - if (!areEquivalentRanges(insertOffset, insertSize, extractOffset, - extractSize)) { + if (!areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), + extractSliceOp.getOffset(), + extractSliceOp.getSize())) { return nullptr; } From 3d800b990355a21773259a81b3d8885edfd1e3e7 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 14:38:43 +0200 Subject: [PATCH 41/71] =?UTF-8?q?=E2=8F=AA=20Revert=20changes=20to=20progr?= =?UTF-8?q?ams?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 27 ++++++------- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 38 +++++++++--------- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 39 ++++++++++--------- mlir/unittests/programs/qc_programs.cpp | 30 -------------- mlir/unittests/programs/qc_programs.h | 18 --------- mlir/unittests/programs/qco_programs.cpp | 5 --- mlir/unittests/programs/qco_programs.h | 3 -- mlir/unittests/programs/qir_programs.cpp | 24 ------------ mlir/unittests/programs/qir_programs.h | 20 ---------- 9 files changed, 51 insertions(+), 153 deletions(-) diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 58d98f3dad..20b8687f9f 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -119,17 +119,16 @@ INSTANTIATE_TEST_SUITE_P( QCToQIRBarrierOpTest, QCToQIRTest, testing::Values( QCToQIRTestCase{"Barrier", MQT_NAMED_BUILDER(qc::barrier), - MQT_NAMED_BUILDER(qir::barrierConverted)}, + MQT_NAMED_BUILDER(qir::emptyQIR)}, QCToQIRTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(qc::barrierTwoQubits), - MQT_NAMED_BUILDER(qir::barrierTwoQubitsConverted)}, + MQT_NAMED_BUILDER(qir::emptyQIR)}, QCToQIRTestCase{"BarrierMultipleQubits", MQT_NAMED_BUILDER(qc::barrierMultipleQubits), - MQT_NAMED_BUILDER(qir::barrierMultipleQubitsConverted)}, - QCToQIRTestCase{ - "SingleControlledBarrier", - MQT_NAMED_BUILDER(qc::singleControlledBarrier), - MQT_NAMED_BUILDER(qir::singleControlledBarrierConverted)})); + MQT_NAMED_BUILDER(qir::emptyQIR)}, + QCToQIRTestCase{"SingleControlledBarrier", + MQT_NAMED_BUILDER(qc::singleControlledBarrier), + MQT_NAMED_BUILDER(qir::emptyQIR)})); /// @} /// \name QCToQIR/Operations/StandardGates/DcxOp.cpp @@ -192,14 +191,12 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( QCToQIRTestCase{"Identity", MQT_NAMED_BUILDER(qc::identity), MQT_NAMED_BUILDER(qir::identity)}, - QCToQIRTestCase{ - "SingleControlledIdentity", - MQT_NAMED_BUILDER(qc::singleControlledIdentity), - MQT_NAMED_BUILDER(qir::singleControlledIdentityConverted)}, - QCToQIRTestCase{ - "MultipleControlledIdentity", - MQT_NAMED_BUILDER(qc::multipleControlledIdentity), - MQT_NAMED_BUILDER(qir::multipleControlledIdentityConverted)})); + QCToQIRTestCase{"SingleControlledIdentity", + MQT_NAMED_BUILDER(qc::singleControlledIdentity), + MQT_NAMED_BUILDER(qir::identity)}, + QCToQIRTestCase{"MultipleControlledIdentity", + MQT_NAMED_BUILDER(qc::multipleControlledIdentity), + MQT_NAMED_BUILDER(qir::identity)})); /// @} /// \name QCToQIR/Operations/StandardGates/IswapOp.cpp diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 094a7a5dbb..2f42ed4d06 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -177,19 +177,20 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCBarrierOpTest, QCTest, - testing::Values( - QCTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), - MQT_NAMED_BUILDER(barrier)}, - QCTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(barrierTwoQubits), - MQT_NAMED_BUILDER(barrierTwoQubits)}, - QCTestCase{"BarrierMultipleQubits", - MQT_NAMED_BUILDER(barrierMultipleQubits), - MQT_NAMED_BUILDER(barrierMultipleQubits)}, - QCTestCase{"SingleControlledBarrier", - MQT_NAMED_BUILDER(singleControlledBarrier), - MQT_NAMED_BUILDER(singleControlledBarrierCanonicalized)}, - QCTestCase{"InverseBarrier", MQT_NAMED_BUILDER(inverseBarrier), - MQT_NAMED_BUILDER(barrier)})); + testing::Values(QCTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), + MQT_NAMED_BUILDER(barrier)}, + QCTestCase{"BarrierTwoQubits", + MQT_NAMED_BUILDER(barrierTwoQubits), + MQT_NAMED_BUILDER(barrierTwoQubits)}, + QCTestCase{"BarrierMultipleQubits", + MQT_NAMED_BUILDER(barrierMultipleQubits), + MQT_NAMED_BUILDER(barrierMultipleQubits)}, + QCTestCase{"SingleControlledBarrier", + MQT_NAMED_BUILDER(singleControlledBarrier), + MQT_NAMED_BUILDER(barrier)}, + QCTestCase{"InverseBarrier", + MQT_NAMED_BUILDER(inverseBarrier), + MQT_NAMED_BUILDER(barrier)})); /// @} /// \name QC/Operations/StandardGates/DcxOp.cpp @@ -257,7 +258,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(multipleControlledP)}, QCTestCase{"NestedControlledGlobalPhase", MQT_NAMED_BUILDER(nestedControlledGlobalPhase), - MQT_NAMED_BUILDER(nestedControlledGlobalPhaseCanonicalized)}, + MQT_NAMED_BUILDER(singleControlledP)}, QCTestCase{"TrivialControlledGlobalPhase", MQT_NAMED_BUILDER(trivialControlledGlobalPhase), MQT_NAMED_BUILDER(globalPhase)}, @@ -299,13 +300,13 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(identity)}, QCTestCase{"SingleControlledIdentity", MQT_NAMED_BUILDER(singleControlledIdentity), - MQT_NAMED_BUILDER(singleControlledIdentityCanonicalized)}, + MQT_NAMED_BUILDER(identity)}, QCTestCase{"MultipleControlledIdentity", MQT_NAMED_BUILDER(multipleControlledIdentity), - MQT_NAMED_BUILDER(multipleControlledIdentityCanonicalized)}, + MQT_NAMED_BUILDER(identity)}, QCTestCase{"NestedControlledIdentity", MQT_NAMED_BUILDER(nestedControlledIdentity), - MQT_NAMED_BUILDER(nestedControlledIdentityCanonicalized)}, + MQT_NAMED_BUILDER(identity)}, QCTestCase{"TrivialControlledIdentity", MQT_NAMED_BUILDER(trivialControlledIdentity), MQT_NAMED_BUILDER(identity)}, @@ -313,8 +314,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(identity)}, QCTestCase{"InverseMultipleControlledIdentity", MQT_NAMED_BUILDER(inverseMultipleControlledIdentity), - MQT_NAMED_BUILDER( - inverseMultipleControlledIdentityCanonicalized)})); + MQT_NAMED_BUILDER(identity)})); /// @} /// \name QC/Operations/StandardGates/IswapOp.cpp diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 9fb458a9df..efa1a0f066 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -393,10 +393,10 @@ TEST_F(QCOTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly - qco::QCOProgramBuilder builder(context.get()); + QCOProgramBuilder builder(context.get()); builder.initialize(); - auto c0 = arith::ConstantOp::create(builder, builder.getIndexAttr(0)); - auto c1 = arith::ConstantOp::create(builder, builder.getIndexAttr(1)); + auto c0 = arith::ConstantIndexOp::create(builder, 0); + auto c1 = arith::ConstantIndexOp::create(builder, 1); auto r0 = qtensor::AllocOp::create(builder, c1); auto extractOp = qtensor::ExtractOp::create(builder, r0, c0); auto q1 = HOp::create(builder, extractOp.getResult()); @@ -405,7 +405,7 @@ TEST_F(QCOTest, DirectIfBuilder) { IfOp::create(builder, measureOp.getResult(), measureOp.getQubitOut(), [&](ValueRange qubits) -> llvm::SmallVector { auto innerQubit = XOp::create(builder, qubits[0]); - return llvm::SmallVector{innerQubit}; + return llvm::SmallVector{innerQubit}; }); auto r2 = qtensor::InsertOp::create(builder, ifOp.getResult(0), extractOp.getOutTensor(), c0); @@ -486,21 +486,22 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOBarrierOpTest, QCOTest, - testing::Values( - QCOTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), - MQT_NAMED_BUILDER(barrier)}, - QCOTestCase{"BarrierTwoQubits", MQT_NAMED_BUILDER(barrierTwoQubits), - MQT_NAMED_BUILDER(barrierTwoQubits)}, - QCOTestCase{"BarrierMultipleQubits", - MQT_NAMED_BUILDER(barrierMultipleQubits), - MQT_NAMED_BUILDER(barrierMultipleQubits)}, - QCOTestCase{"SingleControlledBarrier", - MQT_NAMED_BUILDER(singleControlledBarrier), - MQT_NAMED_BUILDER(singleControlledBarrierCanonicalized)}, - QCOTestCase{"InverseBarrier", MQT_NAMED_BUILDER(inverseBarrier), - MQT_NAMED_BUILDER(barrier)}, - QCOTestCase{"TwoBarrier", MQT_NAMED_BUILDER(twoBarrier), - MQT_NAMED_BUILDER(barrierTwoQubits)})); + testing::Values(QCOTestCase{"Barrier", MQT_NAMED_BUILDER(barrier), + MQT_NAMED_BUILDER(barrier)}, + QCOTestCase{"BarrierTwoQubits", + MQT_NAMED_BUILDER(barrierTwoQubits), + MQT_NAMED_BUILDER(barrierTwoQubits)}, + QCOTestCase{"BarrierMultipleQubits", + MQT_NAMED_BUILDER(barrierMultipleQubits), + MQT_NAMED_BUILDER(barrierMultipleQubits)}, + QCOTestCase{"SingleControlledBarrier", + MQT_NAMED_BUILDER(singleControlledBarrier), + MQT_NAMED_BUILDER(barrier)}, + QCOTestCase{"InverseBarrier", + MQT_NAMED_BUILDER(inverseBarrier), + MQT_NAMED_BUILDER(barrier)}, + QCOTestCase{"TwoBarrier", MQT_NAMED_BUILDER(twoBarrier), + MQT_NAMED_BUILDER(barrierTwoQubits)})); /// @} /// \name QCO/Operations/StandardGates/DcxOp.cpp diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 7cf468b898..2357ce2440 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -182,11 +182,6 @@ void nestedControlledGlobalPhase(QCProgramBuilder& b) { b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); } -void nestedControlledGlobalPhaseCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(3); - b.cp(0.123, q[0], q[1]); -} - void trivialControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.mcgphase(0.123, {}); @@ -211,31 +206,16 @@ void singleControlledIdentity(QCProgramBuilder& b) { b.cid(q[1], q[0]); } -void singleControlledIdentityCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(2); - b.id(q[0]); -} - void multipleControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); b.mcid({q[2], q[1]}, q[0]); } -void multipleControlledIdentityCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(3); - b.id(q[0]); -} - void nestedControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); } -void nestedControlledIdentityCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(3); - b.id(q[0]); -} - void trivialControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.mcid({}, q[0]); @@ -251,11 +231,6 @@ void inverseMultipleControlledIdentity(QCProgramBuilder& b) { b.inv([&]() { b.mcid({q[2], q[1]}, q[0]); }); } -void inverseMultipleControlledIdentityCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(3); - b.id(q[0]); -} - void x(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.x(q[0]); @@ -1232,11 +1207,6 @@ void singleControlledBarrier(QCProgramBuilder& b) { b.ctrl(q[1], [&] { b.barrier(q[0]); }); } -void singleControlledBarrierCanonicalized(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(2); - b.barrier(q[0]); -} - void inverseBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.inv([&]() { b.barrier(q[0]); }); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 5c4972b4a3..fac185a866 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -106,9 +106,6 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b); /// Creates a circuit with a nested controlled global phase gate. void nestedControlledGlobalPhase(QCProgramBuilder& b); -/// Canonicalized version of `nestedControlledGlobalPhase`. -void nestedControlledGlobalPhaseCanonicalized(QCProgramBuilder& b); - /// Creates a circuit with a trivial controlled global phase gate. void trivialControlledGlobalPhase(QCProgramBuilder& b); @@ -127,21 +124,12 @@ void identity(QCProgramBuilder& b); /// Creates a controlled identity gate with a single control qubit. void singleControlledIdentity(QCProgramBuilder& b); -/// Canonicalized version of `singleControlledIdentity`. -void singleControlledIdentityCanonicalized(QCProgramBuilder& b); - /// Creates a multi-controlled identity gate with multiple control qubits. void multipleControlledIdentity(QCProgramBuilder& b); -/// Canonicalized version of `multipleControlledIdentity`. -void multipleControlledIdentityCanonicalized(QCProgramBuilder& b); - /// Creates a circuit with a nested controlled identity gate. void nestedControlledIdentity(QCProgramBuilder& b); -/// Canonicalized version of `nestedControlledIdentity`. -void nestedControlledIdentityCanonicalized(QCProgramBuilder& b); - /// Creates a circuit with a trivial controlled identity gate. void trivialControlledIdentity(QCProgramBuilder& b); @@ -152,9 +140,6 @@ void inverseIdentity(QCProgramBuilder& b); /// gate. void inverseMultipleControlledIdentity(QCProgramBuilder& b); -/// Canonicalized version of `inverseMultipleControlledIdentity`. -void inverseMultipleControlledIdentityCanonicalized(QCProgramBuilder& b); - // --- XOp ------------------------------------------------------------------ // /// Creates a circuit with just an X gate. @@ -801,9 +786,6 @@ void barrierMultipleQubits(QCProgramBuilder& b); /// Creates a circuit with a single controlled barrier. void singleControlledBarrier(QCProgramBuilder& b); -/// Canonicalized version of `singleControlledBarrier`. -void singleControlledBarrierCanonicalized(QCProgramBuilder& b); - /// Creates a circuit with an inverse modifier applied to a barrier. void inverseBarrier(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index f01db152e5..22da726ef0 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -1926,11 +1926,6 @@ void singleControlledBarrier(QCOProgramBuilder& b) { }); } -void singleControlledBarrierCanonicalized(QCOProgramBuilder& b) { - auto q = b.allocQubitRegister(2); - b.barrier(q[0]); -} - void inverseBarrier(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.inv({q[0]}, [&](mlir::ValueRange qubits) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index abdf03c024..a14100102c 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -929,9 +929,6 @@ void barrierMultipleQubits(QCOProgramBuilder& b); /// Creates a circuit with a single controlled barrier. void singleControlledBarrier(QCOProgramBuilder& b); -/// Canonicalized version of `singleControlledBarrier`. -void singleControlledBarrierCanonicalized(QCOProgramBuilder& b); - /// Creates a circuit with an inverse modifier applied to a barrier. void inverseBarrier(QCOProgramBuilder& b); diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index e1691f26c0..82bc0e6d2c 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -170,21 +170,11 @@ void singleControlledIdentity(QIRProgramBuilder& b) { b.cid(q[0], q[1]); } -void singleControlledIdentityConverted(QIRProgramBuilder& b) { - auto q = b.allocQubitRegister(2); - b.id(q[0]); -} - void multipleControlledIdentity(QIRProgramBuilder& b) { auto q = b.allocQubitRegister(3); b.mcid({q[0], q[1]}, q[2]); } -void multipleControlledIdentityConverted(QIRProgramBuilder& b) { - auto q = b.allocQubitRegister(3); - b.id(q[0]); -} - void x(QIRProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.x(q[0]); @@ -595,18 +585,4 @@ void multipleControlledXxMinusYY(QIRProgramBuilder& b) { b.mcxx_minus_yy(0.123, 0.456, {q[0], q[1]}, q[2], q[3]); } -void barrierConverted(QIRProgramBuilder& b) { b.allocQubitRegister(1); } - -void barrierTwoQubitsConverted(QIRProgramBuilder& b) { - b.allocQubitRegister(2); -} - -void barrierMultipleQubitsConverted(QIRProgramBuilder& b) { - b.allocQubitRegister(3); -} - -void singleControlledBarrierConverted(QIRProgramBuilder& b) { - b.allocQubitRegister(2); -} - } // namespace mlir::qir diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index ed7d1a0dcb..2b9e591432 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -102,15 +102,9 @@ void identity(QIRProgramBuilder& b); /// Creates a controlled identity gate with a single control qubit. void singleControlledIdentity(QIRProgramBuilder& b); -/// Converted version of `qc::singleControlledIdentity`. -void singleControlledIdentityConverted(QIRProgramBuilder& b); - /// Creates a multi-controlled identity gate with multiple control qubits. void multipleControlledIdentity(QIRProgramBuilder& b); -/// Converted version of `qc::multipleControlledIdentity`. -void multipleControlledIdentityConverted(QIRProgramBuilder& b); - // --- XOp ------------------------------------------------------------------ // /// Creates a circuit with just an X gate. @@ -411,18 +405,4 @@ void singleControlledXxMinusYY(QIRProgramBuilder& b); /// Creates a circuit with a multi-controlled XXMinusYY gate. void multipleControlledXxMinusYY(QIRProgramBuilder& b); -// --- BarrierOp ------------------------------------------------------------ // - -/// Converted version of `qc::barrier`. -void barrierConverted(QIRProgramBuilder& b); - -/// Converted version of `qc::barrierTwoQubits`. -void barrierTwoQubitsConverted(QIRProgramBuilder& b); - -/// Converted version of `qc::barrierMultipleQubits`. -void barrierMultipleQubitsConverted(QIRProgramBuilder& b); - -/// Converted version of `qc::singleControlledBarrier`. -void singleControlledBarrierConverted(QIRProgramBuilder& b); - } // namespace mlir::qir From 4a666b83a225c5d40aceae30c4ca9a7e17abaa80 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 23:44:32 +0200 Subject: [PATCH 42/71] =?UTF-8?q?=E2=9C=A8=20Add=20QC=20pass=20for=20shrin?= =?UTF-8?q?king=20registers=20to=20fit=20accessed=20indices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QC/CMakeLists.txt | 1 + .../mlir/Dialect/QC/Transforms/CMakeLists.txt | 13 ++ .../mlir/Dialect/QC/Transforms/Passes.h | 31 ++++ .../mlir/Dialect/QC/Transforms/Passes.td | 27 +++ mlir/lib/Dialect/QC/CMakeLists.txt | 1 + mlir/lib/Dialect/QC/Transforms/CMakeLists.txt | 40 +++++ .../QC/Transforms/ShrinkQubitRegisters.cpp | 158 ++++++++++++++++++ 7 files changed, 271 insertions(+) create mode 100644 mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/QC/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/QC/Transforms/Passes.td create mode 100644 mlir/lib/Dialect/QC/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp diff --git a/mlir/include/mlir/Dialect/QC/CMakeLists.txt b/mlir/include/mlir/Dialect/QC/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/include/mlir/Dialect/QC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QC/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..115aeb67b7 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QC) +add_public_tablegen_target(MLIRQCTransformsIncGen) + +add_mlir_doc(Passes QCTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QC/Transforms/Passes.h b/mlir/include/mlir/Dialect/QC/Transforms/Passes.h new file mode 100644 index 0000000000..435be783d7 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/Passes.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QC/IR/QCDialect.h" + +#include +#include + +namespace mlir::qc { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qc diff --git a/mlir/include/mlir/Dialect/QC/Transforms/Passes.td b/mlir/include/mlir/Dialect/QC/Transforms/Passes.td new file mode 100644 index 0000000000..d717641820 --- /dev/null +++ b/mlir/include/mlir/Dialect/QC/Transforms/Passes.td @@ -0,0 +1,27 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def ShrinkQubitRegistersPass + : Pass<"qc-shrink-qubit-registers", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qc::QCDialect", "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect"]; + let summary = + "Shrink static qc::QubitType memref registers to accessed indices."; + let description = [{ + Shrinks one-dimensional static memref registers with element type + `!qc.qubit` by removing never-read indices and remapping `memref.load` + users accordingly. + }]; +} + +#endif // MLIR_DIALECT_QC_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QC/CMakeLists.txt b/mlir/lib/Dialect/QC/CMakeLists.txt index 49d4a2a9fc..c4d5cdbc80 100644 --- a/mlir/lib/Dialect/QC/CMakeLists.txt +++ b/mlir/lib/Dialect/QC/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(IR) add_subdirectory(Builder) add_subdirectory(Translation) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..e64d323c01 --- /dev/null +++ b/mlir/lib/Dialect/QC/Transforms/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQCTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRQCDialect + DEPENDS + MLIRQCTransformsIncGen) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QC/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QC/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQCTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp new file mode 100644 index 0000000000..fc8e9868c8 --- /dev/null +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/QC/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace mlir::qc { + +#define GEN_PASS_DEF_SHRINKQUBITREGISTERSPASS +#include "mlir/Dialect/QC/Transforms/Passes.h.inc" + +/** + * @brief Return the constant index of a one-dimensional memref load. + */ +[[nodiscard]] static std::optional +getLoadIndex(memref::LoadOp loadOp) { + if (loadOp.getIndices().size() != 1) { + return std::nullopt; + } + return getConstantIntValue(loadOp.getIndices().front()); +} + +namespace { +/** + * @brief Shrink static qubit registers to actually read indices. + */ +struct ShrinkQubitRegister final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DeallocOp op, + PatternRewriter& rewriter) const override { + auto allocOp = op.getMemref().getDefiningOp(); + if (!allocOp) { + return failure(); + } + + auto memRefType = llvm::dyn_cast(op.getMemref().getType()); + if (!memRefType || memRefType.getRank() != 1 || + !memRefType.hasStaticShape()) { + return failure(); + } + if (!llvm::isa(memRefType.getElementType())) { + return failure(); + } + + llvm::SmallVector loadOps; + llvm::SmallVector liveIndices; + llvm::DenseMap newIndexByOldIndex; + + for (auto* user : op.getMemref().getUsers()) { + if (user == op.getOperation()) { + continue; + } + auto loadOp = llvm::dyn_cast(user); + if (!loadOp) { + return failure(); + } + auto index = getLoadIndex(loadOp); + if (!index) { + return failure(); + } + loadOps.push_back(loadOp); + if (!loadOp.getResult().use_empty() && + !newIndexByOldIndex.contains(*index)) { + newIndexByOldIndex.try_emplace(*index, 0U); + liveIndices.push_back(*index); + } + } + + if (liveIndices.empty()) { + for (auto loadOp : loadOps) { + rewriter.eraseOp(loadOp); + } + rewriter.eraseOp(op); + rewriter.eraseOp(allocOp); + return success(); + } + + llvm::sort(liveIndices); + if (static_cast(liveIndices.size()) == memRefType.getDimSize(0) && + llvm::all_of(llvm::enumerate(liveIndices), [](const auto& indexed) { + return static_cast(indexed.index()) == indexed.value(); + })) { + return failure(); + } + + newIndexByOldIndex.clear(); + for (size_t i = 0; i < liveIndices.size(); ++i) { + newIndexByOldIndex.try_emplace(liveIndices[i], i); + } + + rewriter.setInsertionPoint(allocOp); + auto newMemRefType = + MemRefType::get({static_cast(liveIndices.size())}, + memRefType.getElementType()); + auto newAlloc = + memref::AllocOp::create(rewriter, allocOp.getLoc(), newMemRefType); + + for (auto loadOp : loadOps) { + if (loadOp.getResult().use_empty()) { + rewriter.eraseOp(loadOp); + continue; + } + + const auto oldIndex = *getLoadIndex(loadOp); + const auto newIndex = + static_cast(newIndexByOldIndex.lookup(oldIndex)); + rewriter.setInsertionPoint(loadOp); + auto indexConst = + arith::ConstantIndexOp::create(rewriter, loadOp.getLoc(), newIndex); + auto newLoad = memref::LoadOp::create(rewriter, loadOp.getLoc(), + newAlloc.getResult(), + ValueRange{indexConst.getResult()}); + rewriter.replaceOp(loadOp, newLoad); + } + + rewriter.setInsertionPoint(op); + memref::DeallocOp::create(rewriter, op.getLoc(), newAlloc.getResult()); + rewriter.eraseOp(op); + rewriter.eraseOp(allocOp); + return success(); + } +}; + +struct ShrinkQubitRegistersPass final + : impl::ShrinkQubitRegistersPassBase { +protected: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace +} // namespace mlir::qc From 3a94b90f83c607cdabe8ba9ed28e873485ecd818 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 23:45:55 +0200 Subject: [PATCH 43/71] =?UTF-8?q?=E2=9C=A8=20Add=20QCO=20pass=20for=20shri?= =?UTF-8?q?nking=20qtensors=20to=20fit=20accessed=20indices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../mlir/Dialect/QTensor/CMakeLists.txt | 1 + .../Dialect/QTensor/Transforms/CMakeLists.txt | 13 + .../mlir/Dialect/QTensor/Transforms/Passes.h | 31 ++ .../mlir/Dialect/QTensor/Transforms/Passes.td | 25 ++ mlir/lib/Dialect/QTensor/CMakeLists.txt | 1 + .../Dialect/QTensor/Transforms/CMakeLists.txt | 40 ++ .../QTensor/Transforms/ShrinkRegisters.cpp | 418 ++++++++++++++++++ 7 files changed, 529 insertions(+) create mode 100644 mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td create mode 100644 mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp diff --git a/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt b/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QTensor/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..de5795040d --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QTensor) +add_public_tablegen_target(MLIRQTensorTransformsIncGen) + +add_mlir_doc(Passes QTensorTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h new file mode 100644 index 0000000000..e33924b5a5 --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" + +#include +#include + +namespace mlir::qtensor { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qtensor diff --git a/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td new file mode 100644 index 0000000000..cfc322d3cf --- /dev/null +++ b/mlir/include/mlir/Dialect/QTensor/Transforms/Passes.td @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def ShrinkQTensorToFitPass : Pass<"qtensor-shrink-to-fit", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qtensor::QTensorDialect", + "mlir::arith::ArithDialect"]; + let summary = "Shrink static qtensors to their actually accessed indices."; + let description = [{ + Shrinks one-dimensional static qtensors by tracing linear tensor chains from + `qtensor.dealloc` to `qtensor.alloc` and rebuilding the chain on a compact + allocation that only keeps accessed indices. + }]; +} + +#endif // MLIR_DIALECT_QTENSOR_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QTensor/CMakeLists.txt b/mlir/lib/Dialect/QTensor/CMakeLists.txt index b181a84fed..3b0a561d0f 100644 --- a/mlir/lib/Dialect/QTensor/CMakeLists.txt +++ b/mlir/lib/Dialect/QTensor/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..2c110797bd --- /dev/null +++ b/mlir/lib/Dialect/QTensor/Transforms/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQTensorTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRQTensorDialect + DEPENDS + MLIRQTensorTransformsIncGen) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QTensor/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QTensor/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQTensorTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp new file mode 100644 index 0000000000..36030c7cee --- /dev/null +++ b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace mlir::qtensor { + +#define GEN_PASS_DEF_SHRINKQTENSORTOFITPASS +#include "mlir/Dialect/QTensor/Transforms/Passes.h.inc" + +/** + * @brief Return the unique user of a linear qtensor value. + */ +[[nodiscard]] static Operation* getLinearTensorUser(const Value tensor) { + assert(tensor.hasOneUse() && "Expected a linear tensor with exactly one use"); + return *tensor.getUsers().begin(); +} + +/** + * @brief Mark a single live index. + */ +[[nodiscard]] static LogicalResult markLiveIndex(const int64_t index, + llvm::BitVector& liveIndices) { + if (index < 0 || index >= static_cast(liveIndices.size())) { + return failure(); + } + liveIndices.set(static_cast(index)); + return success(); +} + +/** + * @brief Mark a contiguous live range. + */ +[[nodiscard]] static LogicalResult markLiveRange(const int64_t offset, + const int64_t size, + llvm::BitVector& liveIndices) { + if (offset < 0 || size <= 0 || + offset + size > static_cast(liveIndices.size())) { + return failure(); + } + for (int64_t index = offset; index < offset + size; ++index) { + liveIndices.set(static_cast(index)); + } + return success(); +} + +/** + * @brief Redirect the tensor operand from @p from to @p to. + */ +[[nodiscard]] static LogicalResult remapTensorOperand(Operation* op, Value from, + Value to) { + if (auto extractOp = llvm::dyn_cast(op)) { + if (extractOp.getTensor() != from) { + return failure(); + } + extractOp->setOperand(0, to); + return success(); + } + if (auto insertOp = llvm::dyn_cast(op)) { + if (insertOp.getDest() != from) { + return failure(); + } + insertOp->setOperand(1, to); + return success(); + } + if (auto extractSliceOp = llvm::dyn_cast(op)) { + if (extractSliceOp.getTensor() != from) { + return failure(); + } + extractSliceOp->setOperand(0, to); + return success(); + } + if (auto insertSliceOp = llvm::dyn_cast(op)) { + if (insertSliceOp.getDest() != from) { + return failure(); + } + insertSliceOp->setOperand(1, to); + return success(); + } + if (auto deallocOp = llvm::dyn_cast(op)) { + if (deallocOp.getTensor() != from) { + return failure(); + } + deallocOp->setOperand(0, to); + return success(); + } + return failure(); +} + +/** + * @brief Walk alloc->dealloc and collect all touched indices. + */ +[[nodiscard]] static LogicalResult collectLiveIndices(AllocOp allocOp, + llvm::BitVector& live, + DeallocOp& deallocOp) { + Value tensor = allocOp.getResult(); + while (true) { + auto* user = getLinearTensorUser(tensor); + if (!user) { + return failure(); + } + + if (auto currentDealloc = llvm::dyn_cast(user)) { + if (currentDealloc.getTensor() != tensor) { + return failure(); + } + deallocOp = currentDealloc; + return success(); + } + + if (auto extractOp = llvm::dyn_cast(user)) { + if (extractOp.getTensor() != tensor) { + return failure(); + } + auto index = getConstantIntValue(extractOp.getIndex()); + if (!index || failed(markLiveIndex(*index, live))) { + return failure(); + } + tensor = extractOp.getOutTensor(); + continue; + } + + if (auto insertOp = llvm::dyn_cast(user)) { + if (insertOp.getDest() != tensor) { + return failure(); + } + auto index = getConstantIntValue(insertOp.getIndex()); + if (!index || failed(markLiveIndex(*index, live))) { + return failure(); + } + tensor = insertOp.getResult(); + continue; + } + + if (auto extractSliceOp = llvm::dyn_cast(user)) { + if (extractSliceOp.getTensor() != tensor) { + return failure(); + } + auto offset = getConstantIntValue(extractSliceOp.getOffset()); + auto size = getConstantIntValue(extractSliceOp.getSize()); + if (!offset || !size || failed(markLiveRange(*offset, *size, live))) { + return failure(); + } + tensor = extractSliceOp.getOutTensor(); + continue; + } + + if (auto insertSliceOp = llvm::dyn_cast(user)) { + if (insertSliceOp.getDest() != tensor) { + return failure(); + } + auto offset = getConstantIntValue(insertSliceOp.getOffset()); + auto size = getConstantIntValue(insertSliceOp.getSize()); + if (!offset || !size || failed(markLiveRange(*offset, *size, live))) { + return failure(); + } + tensor = insertSliceOp.getResult(); + continue; + } + + return failure(); + } +} + +/** + * @brief Shrink static qtensors by removing never-accessed indices. + * @details QTensor is linear, so this rewrite follows a single use-def chain. + */ +struct ShrinkStaticQTensor final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocOp allocOp, + PatternRewriter& rewriter) const override { + auto oldSize = getConstantIntValue(allocOp.getSize()); + if (!oldSize || *oldSize <= 0) { + return failure(); + } + + llvm::BitVector live(static_cast(*oldSize), false); + DeallocOp oldDeallocOp{}; + if (failed(collectLiveIndices(allocOp, live, oldDeallocOp))) { + return failure(); + } + + if (!oldDeallocOp) { + return failure(); + } + + llvm::SmallVector newIndexByOldIndex(static_cast(*oldSize), + -1); + int64_t newSize = 0; + for (int64_t index = 0; index < *oldSize; ++index) { + if (live.test(static_cast(index))) { + newIndexByOldIndex[static_cast(index)] = newSize++; + } + } + + if (newSize <= 0 || newSize == *oldSize) { + return failure(); + } + + rewriter.setInsertionPoint(allocOp); + auto size = + arith::ConstantIndexOp::create(rewriter, allocOp.getLoc(), newSize); + auto newAlloc = + AllocOp::create(rewriter, allocOp.getLoc(), size.getResult()); + + Value oldTensor = allocOp.getResult(); + Value currentTensor = newAlloc.getResult(); + while (true) { + Operation* currentOp = getLinearTensorUser(oldTensor); + if (!currentOp) { + return failure(); + } + + if (auto deallocOp = llvm::dyn_cast(currentOp)) { + if (deallocOp != oldDeallocOp || deallocOp.getTensor() != oldTensor) { + return failure(); + } + rewriter.setInsertionPoint(deallocOp); + DeallocOp::create(rewriter, deallocOp.getLoc(), currentTensor); + rewriter.eraseOp(deallocOp); + break; + } + + if (auto extractOp = llvm::dyn_cast(currentOp)) { + if (extractOp.getTensor() != oldTensor) { + return failure(); + } + const auto oldIndex = *getConstantIntValue(extractOp.getIndex()); + if (oldIndex < 0 || + oldIndex >= static_cast(newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedIndex = + newIndexByOldIndex[static_cast(oldIndex)]; + if (mappedIndex < 0) { + return failure(); + } + Value oldOutTensor = extractOp.getOutTensor(); + Operation* nextOp = getLinearTensorUser(oldOutTensor); + if (!nextOp) { + return failure(); + } + + rewriter.setInsertionPoint(extractOp); + auto index = arith::ConstantIndexOp::create( + rewriter, extractOp.getLoc(), mappedIndex); + auto newExtract = ExtractOp::create(rewriter, extractOp.getLoc(), + currentTensor, index.getResult()); + rewriter.replaceAllUsesWith(extractOp.getResult(), + newExtract.getResult()); + + currentTensor = newExtract.getOutTensor(); + if (failed(remapTensorOperand(nextOp, oldOutTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(extractOp); + continue; + } + + if (auto insertOp = llvm::dyn_cast(currentOp)) { + if (insertOp.getDest() != oldTensor) { + return failure(); + } + const auto oldIndex = *getConstantIntValue(insertOp.getIndex()); + if (oldIndex < 0 || + oldIndex >= static_cast(newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedIndex = + newIndexByOldIndex[static_cast(oldIndex)]; + if (mappedIndex < 0) { + return failure(); + } + Value oldResultTensor = insertOp.getResult(); + Operation* nextOp = getLinearTensorUser(oldResultTensor); + if (!nextOp) { + return failure(); + } + + rewriter.setInsertionPoint(insertOp); + auto index = arith::ConstantIndexOp::create(rewriter, insertOp.getLoc(), + mappedIndex); + auto newInsert = + InsertOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(), + currentTensor, index.getResult()); + + currentTensor = newInsert.getResult(); + if (failed(remapTensorOperand(nextOp, oldResultTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(insertOp); + continue; + } + + if (auto extractSliceOp = llvm::dyn_cast(currentOp)) { + if (extractSliceOp.getTensor() != oldTensor) { + return failure(); + } + const auto oldOffset = *getConstantIntValue(extractSliceOp.getOffset()); + const auto oldSliceSize = + *getConstantIntValue(extractSliceOp.getSize()); + if (oldOffset < 0 || oldSliceSize <= 0 || + oldOffset + oldSliceSize > + static_cast(newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedOffset = + newIndexByOldIndex[static_cast(oldOffset)]; + if (mappedOffset < 0) { + return failure(); + } + Value oldOutTensor = extractSliceOp.getOutTensor(); + Operation* nextOp = getLinearTensorUser(oldOutTensor); + if (!nextOp) { + return failure(); + } + rewriter.setInsertionPoint(extractSliceOp); + auto newOffset = arith::ConstantIndexOp::create( + rewriter, extractSliceOp.getLoc(), mappedOffset); + auto newSliceSize = arith::ConstantIndexOp::create( + rewriter, extractSliceOp.getLoc(), oldSliceSize); + auto newExtractSlice = ExtractSliceOp::create( + rewriter, extractSliceOp.getLoc(), currentTensor, + newOffset.getResult(), newSliceSize.getResult()); + rewriter.replaceAllUsesWith(extractSliceOp.getResult(), + newExtractSlice.getResult()); + + currentTensor = newExtractSlice.getOutTensor(); + if (failed(remapTensorOperand(nextOp, oldOutTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(extractSliceOp); + continue; + } + + if (auto insertSliceOp = llvm::dyn_cast(currentOp)) { + if (insertSliceOp.getDest() != oldTensor) { + return failure(); + } + const auto oldOffset = *getConstantIntValue(insertSliceOp.getOffset()); + const auto oldSliceSize = *getConstantIntValue(insertSliceOp.getSize()); + if (oldOffset < 0 || oldSliceSize <= 0 || + oldOffset + oldSliceSize > + static_cast(newIndexByOldIndex.size())) { + return failure(); + } + const auto mappedOffset = + newIndexByOldIndex[static_cast(oldOffset)]; + if (mappedOffset < 0) { + return failure(); + } + Value oldResultTensor = insertSliceOp.getResult(); + Operation* nextOp = getLinearTensorUser(oldResultTensor); + if (!nextOp) { + return failure(); + } + + rewriter.setInsertionPoint(insertSliceOp); + auto newOffset = arith::ConstantIndexOp::create( + rewriter, insertSliceOp.getLoc(), mappedOffset); + auto newSliceSize = arith::ConstantIndexOp::create( + rewriter, insertSliceOp.getLoc(), oldSliceSize); + auto newInsertSlice = InsertSliceOp::create( + rewriter, insertSliceOp.getLoc(), insertSliceOp.getSource(), + currentTensor, newOffset.getResult(), newSliceSize.getResult()); + + currentTensor = newInsertSlice.getResult(); + if (failed(remapTensorOperand(nextOp, oldResultTensor, oldTensor))) { + return failure(); + } + rewriter.eraseOp(insertSliceOp); + continue; + } + + return failure(); + } + + rewriter.eraseOp(allocOp); + return success(); + } +}; + +struct ShrinkQTensorToFitPass final + : impl::ShrinkQTensorToFitPassBase { +protected: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::qtensor From 1aeaf7a0c98ad902cea4a81b2964ac8ca2697e85 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 23:46:47 +0200 Subject: [PATCH 44/71] =?UTF-8?q?=E2=9C=A8=20Add=20QIR=20pass=20for=20simp?= =?UTF-8?q?lifying=20qubit-array=20allocation/release=20pairs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/CMakeLists.txt | 1 + mlir/include/mlir/Dialect/QIR/CMakeLists.txt | 9 + .../Dialect/QIR/Transforms/CMakeLists.txt | 13 ++ .../mlir/Dialect/QIR/Transforms/Passes.h | 29 +++ .../mlir/Dialect/QIR/Transforms/Passes.td | 23 ++ mlir/lib/Dialect/QIR/CMakeLists.txt | 1 + .../lib/Dialect/QIR/Transforms/CMakeLists.txt | 43 ++++ .../lib/Dialect/QIR/Transforms/QIRCleanup.cpp | 212 ++++++++++++++++++ 8 files changed, 331 insertions(+) create mode 100644 mlir/include/mlir/Dialect/QIR/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/QIR/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/QIR/Transforms/Passes.td create mode 100644 mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index dfc77dc6e2..6714a978af 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -8,4 +8,5 @@ add_subdirectory(QC) add_subdirectory(QCO) +add_subdirectory(QIR) add_subdirectory(QTensor) diff --git a/mlir/include/mlir/Dialect/QIR/CMakeLists.txt b/mlir/include/mlir/Dialect/QIR/CMakeLists.txt new file mode 100644 index 0000000000..3c339729a9 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..6d81096520 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QIR) +add_public_tablegen_target(MLIRQIRTransformsIncGen) + +add_mlir_doc(Passes QIRTransforms Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h new file mode 100644 index 0000000000..e1d281d87a --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include + +namespace mlir::qir { + +#define GEN_PASS_DECL +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" // IWYU pragma: export + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" // IWYU pragma: export + +} // namespace mlir::qir diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td new file mode 100644 index 0000000000..efb7fea4b1 --- /dev/null +++ b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td @@ -0,0 +1,23 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def QIRCleanupPass : Pass<"qir-cleanup", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let summary = "Remove redundant QIR runtime bookkeeping."; + let description = [{ + Removes redundant QIR runtime qubit-array allocation/release pairs that do + not contribute to observable behavior, and keeps QIR modules compact. + }]; +} + +#endif // MLIR_DIALECT_QIR_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QIR/CMakeLists.txt b/mlir/lib/Dialect/QIR/CMakeLists.txt index e2d07911c5..2aa54c44ff 100644 --- a/mlir/lib/Dialect/QIR/CMakeLists.txt +++ b/mlir/lib/Dialect/QIR/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(Utils) add_subdirectory(Builder) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..95afa87e3f --- /dev/null +++ b/mlir/lib/Dialect/QIR/Transforms/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +file(GLOB_RECURSE PASSES_SOURCES *.cpp) + +add_mlir_library( + MLIRQIRTransforms + ${PASSES_SOURCES} + LINK_LIBS + PRIVATE + MLIRLLVMDialect + MLIRQIRUtils + DEPENDS + MLIRQIRTransformsIncGen) + +mqt_mlir_target_use_project_options(MLIRQIRTransforms) + +# collect header files +file(GLOB_RECURSE PASSES_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QIR/Transforms/*.h) +file(GLOB_RECURSE PASSES_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QIR/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQIRTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${PASSES_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp new file mode 100644 index 0000000000..7daabd2cf4 --- /dev/null +++ b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QIR/Transforms/Passes.h" +#include "mlir/Dialect/QIR/Utils/QIRUtils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::qir { + +#define GEN_PASS_DEF_QIRCLEANUPPASS +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" + +[[nodiscard]] static StringAttr getMetadataKey(const Attribute attr) { + auto pair = llvm::dyn_cast(attr); + if (!pair || pair.size() != 2) { + return {}; + } + auto key = llvm::dyn_cast(pair[0]); + if (!key || !llvm::isa(pair[1])) { + return {}; + } + return key; +} + +[[nodiscard]] static llvm::StringRef getCalleeName(LLVM::CallOp callOp) { + auto calleeAttr = callOp.getCalleeAttr(); + auto flatRef = llvm::dyn_cast_or_null(calleeAttr); + if (!flatRef) { + return {}; + } + return flatRef.getValue(); +} + +[[nodiscard]] static bool moduleHasDynamicQubitRuntimeCalls(ModuleOp module) { + return llvm::any_of(module.getOps(), [](LLVM::CallOp callOp) { + const auto callee = getCalleeName(callOp); + return callee == QIR_QUBIT_ALLOC || callee == QIR_QUBIT_ARRAY_ALLOC; + }); +} + +[[nodiscard]] static bool moduleHasDynamicResultRuntimeCalls(ModuleOp module) { + return llvm::any_of(module.getOps(), [](LLVM::CallOp callOp) { + const auto callee = getCalleeName(callOp); + return callee == QIR_RESULT_ALLOC || callee == QIR_RESULT_ARRAY_ALLOC; + }); +} + +static void dropUnusedExternalDeclarations(ModuleOp module) { + for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { + if (!funcOp.isExternal()) { + continue; + } + if (!SymbolTable::symbolKnownUseEmpty(funcOp, module)) { + continue; + } + funcOp.erase(); + } +} + +static void normalizeQIRMetadata(ModuleOp module) { + auto main = getMainFunction(module); + if (!main) { + return; + } + + auto passthroughAttr = main->getAttrOfType("passthrough"); + if (!passthroughAttr) { + return; + } + + const bool hasDynamicQubit = moduleHasDynamicQubitRuntimeCalls(module); + const bool hasDynamicResult = moduleHasDynamicResultRuntimeCalls(module); + if (hasDynamicQubit && hasDynamicResult) { + return; + } + + ArrayAttr requiredNumQubitsAttr = nullptr; + ArrayAttr requiredNumResultsAttr = nullptr; + for (const auto attr : passthroughAttr) { + const auto key = getMetadataKey(attr); + if (!key) { + continue; + } + if (key.getValue() == "required_num_qubits") { + requiredNumQubitsAttr = llvm::cast(attr); + } else if (key.getValue() == "required_num_results") { + requiredNumResultsAttr = llvm::cast(attr); + } + } + + OpBuilder builder(module.getContext()); + SmallVector updatedMetadata; + updatedMetadata.reserve(passthroughAttr.size() + 2); + + for (const auto attr : passthroughAttr) { + const auto key = getMetadataKey(attr); + if (!key) { + updatedMetadata.push_back(attr); + continue; + } + + if (key.getValue() == "dynamic_qubit_management" && !hasDynamicQubit) { + if (requiredNumQubitsAttr) { + updatedMetadata.push_back(requiredNumQubitsAttr); + } + continue; + } + if (key.getValue() == "dynamic_result_management" && !hasDynamicResult) { + if (requiredNumResultsAttr) { + updatedMetadata.push_back(requiredNumResultsAttr); + } + continue; + } + + updatedMetadata.push_back(attr); + } + + main->setAttr("passthrough", builder.getArrayAttr(updatedMetadata)); +} + +namespace { +/** + * @brief Remove dead QIR qubit-array allocation/release pairs. + * @details Matches an unused `__quantum__rt__qubit_array_allocate` / + * `__quantum__rt__qubit_array_release` pair on the same stack slot. + */ +struct RemoveDeadQubitArrayPair final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp releaseCall, + PatternRewriter& rewriter) const override { + if (getCalleeName(releaseCall) != QIR_QUBIT_ARRAY_RELEASE || + releaseCall.getNumOperands() < 2) { + return failure(); + } + + auto allocaOp = releaseCall.getOperand(1).getDefiningOp(); + if (!allocaOp) { + return failure(); + } + + LLVM::CallOp allocCall = nullptr; + for (Operation* user : allocaOp.getResult().getUsers()) { + auto callOp = llvm::dyn_cast(user); + if (!callOp) { + return failure(); + } + + if (callOp == releaseCall) { + continue; + } + + if (getCalleeName(callOp) != QIR_QUBIT_ARRAY_ALLOC || + callOp.getNumOperands() < 2 || + callOp.getOperand(1) != allocaOp.getResult()) { + return failure(); + } + if (allocCall != nullptr) { + return failure(); + } + allocCall = callOp; + } + + if (!allocCall) { + return failure(); + } + + rewriter.eraseOp(releaseCall); + rewriter.eraseOp(allocCall); + if (allocaOp->use_empty()) { + rewriter.eraseOp(allocaOp); + } + return success(); + } +}; + +struct QIRCleanupPass final : impl::QIRCleanupPassBase { +protected: + void runOnOperation() override { + auto module = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + return; + } + + dropUnusedExternalDeclarations(module); + normalizeQIRMetadata(module); + } +}; + +} // namespace + +} // namespace mlir::qir From 45b74aaef04c78d400e67eee04b804d9a3274700 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 23:47:35 +0200 Subject: [PATCH 45/71] =?UTF-8?q?=F0=9F=8E=A8=20Miscellaneous=20small=20fi?= =?UTF-8?q?xes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- .../lib/Dialect/QCO/Transforms/CMakeLists.txt | 1 - .../Dialect/QTensor/IR/Operations/AllocOp.cpp | 19 +++++++++---------- .../QTensor/IR/Operations/DeallocOp.cpp | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt index 3b965dce5e..2268584167 100644 --- a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt @@ -15,7 +15,6 @@ add_mlir_library( PRIVATE MLIRQCODialect MLIRQCOUtils - ${dialect_libs} DEPENDS MLIRQCOTransformsIncGen) diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp index 898b8b6412..05978b6f36 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp @@ -45,16 +45,15 @@ LogicalResult AllocOp::verify() { if (sizeValue && *sizeValue <= 0) { return emitOpError("Constant size operand must be positive"); } - if (sizeValue.has_value() == resultType.isDynamicDim(0)) { - return emitOpError("Size operand and result type must both be static or " - "both be dynamic, but got ") - << (sizeValue ? "static size with dynamic result" - : "dynamic size with static result"); - } - if (sizeValue && resultSize != *sizeValue) { - return emitOpError("Constant size operand (") - << *sizeValue << ") does not match static result size (" - << resultSize << ")"; + if (!resultType.isDynamicDim(0)) { + if (!sizeValue) { + return emitOpError("Static result type requires constant size operand"); + } + if (resultSize != *sizeValue) { + return emitOpError("Constant size operand (") + << *sizeValue << ") does not match static result size (" + << resultSize << ")"; + } } return success(); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp index 90f076ede1..bf62cd1df4 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp @@ -35,7 +35,7 @@ struct RemoveAllocDeallocPair final : OpRewritePattern { return failure(); } - // Remove the AllocOp and the DeallocOp + // Remove the AllocOp and the DeallocOp. rewriter.eraseOp(op); rewriter.eraseOp(allocOp); return success(); From c709efb2259e1ac2b218800293885f466de29d62 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Fri, 3 Apr 2026 23:51:19 +0200 Subject: [PATCH 46/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Adjust=20compiler=20?= =?UTF-8?q?pipeline=20and=20tests=20to=20use=20new=20passes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/include/mlir/Compiler/CompilerPipeline.h | 23 +++---- mlir/include/mlir/Support/Passes.h | 35 ++++++++++- mlir/lib/Compiler/CompilerPipeline.cpp | 60 +++++++++---------- mlir/lib/Support/CMakeLists.txt | 3 + mlir/lib/Support/Passes.cpp | 54 +++++++++++++++-- .../Compiler/test_compiler_pipeline.cpp | 4 +- .../JeffRoundTrip/test_jeff_round_trip.cpp | 8 +-- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 6 +- .../Conversion/QCToQCO/test_qc_to_qco.cpp | 6 +- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 6 +- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 4 +- .../test_quantum_computation_translation.cpp | 4 +- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 53 ++++++++-------- mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp | 4 +- 14 files changed, 166 insertions(+), 104 deletions(-) diff --git a/mlir/include/mlir/Compiler/CompilerPipeline.h b/mlir/include/mlir/Compiler/CompilerPipeline.h index ec7473a8cf..f43bac99d7 100644 --- a/mlir/include/mlir/Compiler/CompilerPipeline.h +++ b/mlir/include/mlir/Compiler/CompilerPipeline.h @@ -72,18 +72,18 @@ struct CompilationRecord { * * 1. QC dialect (reference semantics) - imported from * qc::QuantumComputation - * 2. Canonicalization + cleanup + * 2. QC cleanup pipeline * 3. QCO dialect (value semantics) - enables SSA-based optimizations - * 4. Canonicalization + cleanup + * 4. QCO cleanup pipeline * 5. Quantum optimization passes - * 6. Canonicalization + cleanup + * 6. QCO cleanup pipeline * 7. QC dialect - converted back for backend lowering - * 8. Canonicalization + cleanup + * 8. QC cleanup pipeline * 9. QIR (Quantum Intermediate Representation) - optional final lowering - * 10. Canonicalization + cleanup + * 10. QIR cleanup pipeline * - * Following MLIR best practices, canonicalization and dead value removal - * are always run after each major transformation stage. + * Following MLIR best practices, simplification and dead-value cleanup are + * run after each major transformation stage. */ class QuantumCompilerPipeline { public: @@ -111,15 +111,6 @@ class QuantumCompilerPipeline { CompilationRecord* record = nullptr) const; private: - /** - * @brief Add canonicalization and cleanup passes - * - * @details - * Always adds the standard MLIR canonicalization pass followed by common - * sub-expression elimination and dead value removal. - */ - static void addCleanupPasses(PassManager& pm); - /** * @brief Configure PassManager with diagnostic options * diff --git a/mlir/include/mlir/Support/Passes.h b/mlir/include/mlir/Support/Passes.h index c671ca6970..4b8057ced2 100644 --- a/mlir/include/mlir/Support/Passes.h +++ b/mlir/include/mlir/Support/Passes.h @@ -12,9 +12,38 @@ namespace mlir { class ModuleOp; -} +class PassManager; +} // namespace mlir /** - * @brief Run canonicalization and dead value removal on the given module. + * @brief Populate a QC-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and QC qubit-register shrinking. */ -void runCanonicalizationPasses(mlir::ModuleOp module); +void populateQCCleanupPipeline(mlir::PassManager& passManager); + +/** + * @brief Populate a QCO-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and qtensor shrink-to-fit. + */ +void populateQCOCleanupPipeline(mlir::PassManager& passManager); + +/** + * @brief Populate a QIR-oriented cleanup pipeline on the given pass manager. + * @details Adds generic cleanup and QIR-specific simplifications. + */ +void populateQIRCleanupPipeline(mlir::PassManager& passManager); + +/** + * @brief Run the QC-oriented cleanup pipeline on a module. + */ +void runQCCleanupPipeline(mlir::ModuleOp module); + +/** + * @brief Run the QCO-oriented cleanup pipeline on a module. + */ +void runQCOCleanupPipeline(mlir::ModuleOp module); + +/** + * @brief Run the QIR-oriented cleanup pipeline on a module. + */ +void runQIRCleanupPipeline(mlir::ModuleOp module); diff --git a/mlir/lib/Compiler/CompilerPipeline.cpp b/mlir/lib/Compiler/CompilerPipeline.cpp index 119e343755..0bab524e53 100644 --- a/mlir/lib/Compiler/CompilerPipeline.cpp +++ b/mlir/lib/Compiler/CompilerPipeline.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/QCOToQC/QCOToQC.h" #include "mlir/Conversion/QCToQCO/QCToQCO.h" #include "mlir/Conversion/QCToQIR/QCToQIR.h" +#include "mlir/Support/Passes.h" #include "mlir/Support/PrettyPrinting.h" #include @@ -20,7 +21,6 @@ #include #include #include -#include #include @@ -45,14 +45,6 @@ static void prettyPrintStage(ModuleOp module, const llvm::StringRef stageName, printProgram(module, stageHeader, llvm::errs()); } -void QuantumCompilerPipeline::addCleanupPasses(PassManager& pm) { - // Always run canonicalization, common sub-expression elimination, and dead - // value removal - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createRemoveDeadValuesPass()); -} - void QuantumCompilerPipeline::configurePassManager(PassManager& pm) const { // Enable timing statistics if requested if (config_.enableTiming) { @@ -85,15 +77,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, // Determine total number of stages for progress indication // 1. QC import - // 2. QC canonicalization + // 2. QC cleanup // 3. QC-to-QCO conversion - // 4. QCO canonicalization + // 4. QCO cleanup // 5. Optimization passes - // 6. QCO canonicalization + // 6. QCO cleanup // 7. QCO-to-QC conversion - // 8. QC canonicalization + // 8. QC cleanup // 9. QC-to-QIR conversion (optional) - // 10. QIR canonicalization (optional) + // 10. QIR cleanup (optional) auto totalStages = 8; if (config_.convertToQIR) { totalStages += 2; @@ -108,14 +100,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, } } - // Stage 2: QC canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 2: QC cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterInitialCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Initial QC Canonicalization", ++currentStage, + prettyPrintStage(module, "Initial QC Cleanup", ++currentStage, totalStages); } } @@ -130,20 +123,22 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 4: QCO canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 4: QCO cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQCOCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Initial QCO Canonicalization", ++currentStage, + prettyPrintStage(module, "Initial QCO Cleanup", ++currentStage, totalStages); } } // Stage 5: Optimization passes // TODO: Add optimization passes - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { @@ -153,14 +148,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 6: QCO canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 6: QCO cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterOptimizationCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Final QCO Canonicalization", ++currentStage, + prettyPrintStage(module, "Final QCO Cleanup", ++currentStage, totalStages); } } @@ -175,15 +171,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 8: QC canonicalization - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 8: QC cleanup + if (failed( + runStage([&](PassManager& pm) { populateQCCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQCCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "Final QC Canonicalization", ++currentStage, - totalStages); + prettyPrintStage(module, "Final QC Cleanup", ++currentStage, totalStages); } } // Stage 9: QC-to-QIR conversion (optional) @@ -199,15 +195,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - // Stage 10: QIR canonicalization (optional) - if (failed(runStage([&](PassManager& pm) { addCleanupPasses(pm); }))) { + // Stage 10: QIR cleanup (optional) + if (failed(runStage( + [&](PassManager& pm) { populateQIRCleanupPipeline(pm); }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { record->afterQIRCanon = captureIR(module); if (config_.printIRAfterAllStages) { - prettyPrintStage(module, "QIR Canonicalization", ++currentStage, - totalStages); + prettyPrintStage(module, "QIR Cleanup", ++currentStage, totalStages); } } } diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt index f63a83c87e..462d791329 100644 --- a/mlir/lib/Support/CMakeLists.txt +++ b/mlir/lib/Support/CMakeLists.txt @@ -24,6 +24,9 @@ add_mlir_library( MLIRLLVMDialect MLIRFuncDialect MLIRArithDialect + MLIRQCTransforms + MLIRQIRTransforms + MLIRQTensorTransforms MLIRQTensorDialect) mqt_mlir_target_use_project_options(MLIRSupportMQT) diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index 5e998761bc..254159d76b 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -10,6 +10,11 @@ #include "mlir/Support/Passes.h" +#include "mlir/Dialect/QC/Transforms/Passes.h" +#include "mlir/Dialect/QIR/Transforms/Passes.h" +#include "mlir/Dialect/QTensor/Transforms/Passes.h" + +#include #include #include #include @@ -17,12 +22,51 @@ using namespace mlir; -void runCanonicalizationPasses(ModuleOp module) { +static void addSimplificationPasses(PassManager& passManager) { + passManager.addPass(createCanonicalizerPass()); + passManager.addPass(createCSEPass()); +} + +static void +runWithPassManager(ModuleOp module, + const llvm::function_ref populatePasses, + const llvm::StringRef errorMessage) { PassManager pm(module.getContext()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createRemoveDeadValuesPass()); + populatePasses(pm); if (pm.run(module).failed()) { - llvm::errs() << "Failed to run canonicalization passes.\n"; + llvm::errs() << errorMessage << "\n"; } } + +void populateQCCleanupPipeline(PassManager& passManager) { + addSimplificationPasses(passManager); + passManager.addPass(qc::createShrinkQubitRegistersPass()); + passManager.addPass(createRemoveDeadValuesPass()); +} + +void populateQCOCleanupPipeline(PassManager& passManager) { + addSimplificationPasses(passManager); + passManager.addPass(qtensor::createShrinkQTensorToFitPass()); + passManager.addPass(createRemoveDeadValuesPass()); +} + +void populateQIRCleanupPipeline(PassManager& passManager) { + addSimplificationPasses(passManager); + passManager.addPass(qir::createQIRCleanupPass()); + passManager.addPass(createRemoveDeadValuesPass()); +} + +void runQCCleanupPipeline(ModuleOp module) { + runWithPassManager(module, populateQCCleanupPipeline, + "Failed to run QC cleanup pipeline."); +} + +void runQCOCleanupPipeline(ModuleOp module) { + runWithPassManager(module, populateQCOCleanupPipeline, + "Failed to run QCO cleanup pipeline."); +} + +void runQIRCleanupPipeline(ModuleOp module) { + runWithPassManager(module, populateQIRCleanupPipeline, + "Failed to run QIR cleanup pipeline."); +} diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index f24a4bbffb..169ec794a1 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -99,7 +99,7 @@ class CompilerPipelineTest [[nodiscard]] mlir::OwningOpRef buildQCReference(const QCProgramBuilderFn builder) const { auto module = mlir::qc::QCProgramBuilder::build(context.get(), builder.fn); - runCanonicalizationPasses(module.get()); + runQCCleanupPipeline(module.get()); return module; } @@ -107,7 +107,7 @@ class CompilerPipelineTest buildQIRReference(const QIRProgramBuilderFn builder) const { auto module = mlir::qir::QIRProgramBuilder::build(context.get(), builder.fn); - runCanonicalizationPasses(module.get()); + runQIRCleanupPipeline(module.get()); return module; } diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index 48b0b22926..5b046e5e98 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -93,7 +93,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -101,7 +101,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -109,7 +109,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -119,7 +119,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 96dc828dab..4d3afbe03f 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -90,7 +90,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -98,7 +98,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -108,7 +108,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index b55d1295e0..6c595e9b5b 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -89,7 +89,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -97,7 +97,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -107,7 +107,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 20b8687f9f..994637e348 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -87,7 +87,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -95,7 +95,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQIRCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -105,7 +105,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQIRCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 2f42ed4d06..cb10225730 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -78,7 +78,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -87,7 +87,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index ef41b04bb5..d2a544b1cc 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -81,7 +81,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(translated.get(), "Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); - runCanonicalizationPasses(translated.get()); + runQCCleanupPipeline(translated.get()); printer.record(translated.get(), "Canonicalized Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); @@ -91,7 +91,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index efa1a0f066..3b7b1c9d59 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -68,8 +68,9 @@ class QCOTest : public testing::TestWithParam { context->loadAllAvailableDialects(); } }; +} // namespace -OwningOpRef +static OwningOpRef buildTwoQubitInsertChainProgram(MLIRContext* context, const bool reverseInsertOrder, const bool swapInsertTargets) { @@ -97,7 +98,7 @@ buildTwoQubitInsertChainProgram(MLIRContext* context, return builder.finalize(); } -OwningOpRef +static OwningOpRef buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, const bool swapInsertTargets) { qco::QCOProgramBuilder builder(context); @@ -140,7 +141,7 @@ buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, return builder.finalize(); } -OwningOpRef +static OwningOpRef buildMixedScalarSliceInsertProgram(MLIRContext* context, const bool reverseOrder, const bool overlap, const bool mutateScalar) { @@ -174,7 +175,7 @@ buildMixedScalarSliceInsertProgram(MLIRContext* context, return builder.finalize(); } -OwningOpRef +static OwningOpRef buildResetWithCommutingInsertProgram(MLIRContext* context, const bool withReset) { qco::QCOProgramBuilder builder(context); @@ -195,7 +196,7 @@ buildResetWithCommutingInsertProgram(MLIRContext* context, return builder.finalize(); } -OwningOpRef +static OwningOpRef buildResetWithSameIndexInsertProgram(MLIRContext* context, const bool withReset) { qco::QCOProgramBuilder builder(context); @@ -222,8 +223,6 @@ buildResetWithSameIndexInsertProgram(MLIRContext* context, return builder.finalize(); } -} // namespace - TEST_P(QCOTest, ProgramEquivalence) { const auto& [_, programBuilder, referenceBuilder] = GetParam(); const auto name = " (" + GetParam().name + ")"; @@ -234,7 +233,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -243,7 +242,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); @@ -255,13 +254,13 @@ TEST_F(QCOTest, InsertChainPermutationEquivalence) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -272,13 +271,13 @@ TEST_F(QCOTest, InsertChainDifferentAssignmentsNotEquivalent) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -289,13 +288,13 @@ TEST_F(QCOTest, MixedExtractInsertPermutationEquivalence) { auto program = buildMixedExtractInsertProgram(context.get(), false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedExtractInsertProgram(context.get(), true, false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -306,13 +305,13 @@ TEST_F(QCOTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { auto program = buildMixedExtractInsertProgram(context.get(), false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedExtractInsertProgram(context.get(), true, true); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -324,14 +323,14 @@ TEST_F(QCOTest, MixedScalarSliceInsertPermutationEquivalence) { buildMixedScalarSliceInsertProgram(context.get(), false, false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedScalarSliceInsertProgram(context.get(), true, false, false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -343,14 +342,14 @@ TEST_F(QCOTest, MixedScalarSliceInsertOverlapNotEquivalent) { buildMixedScalarSliceInsertProgram(context.get(), false, true, true); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedScalarSliceInsertProgram(context.get(), true, true, true); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -361,13 +360,13 @@ TEST_F(QCOTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { auto program = buildResetWithCommutingInsertProgram(context.get(), true); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildResetWithCommutingInsertProgram(context.get(), false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -378,13 +377,13 @@ TEST_F(QCOTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { auto program = buildResetWithSameIndexInsertProgram(context.get(), true); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQCOCleanupPipeline(program.get()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQCOCleanupPipeline(reference.get()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -414,14 +413,14 @@ TEST_F(QCOTest, DirectIfBuilder) { auto directBuilder = builder.finalize(); ASSERT_TRUE(directBuilder); EXPECT_TRUE(verify(*directBuilder).succeeded()); - runCanonicalizationPasses(directBuilder.get()); + runQCOCleanupPipeline(directBuilder.get()); EXPECT_TRUE(verify(*directBuilder).succeeded()); auto refBuilder = QCOProgramBuilder::build(context.get(), MQT_NAMED_BUILDER(simpleIf).fn); ASSERT_TRUE(refBuilder); EXPECT_TRUE(verify(*refBuilder).succeeded()); - runCanonicalizationPasses(refBuilder.get()); + runQCOCleanupPipeline(refBuilder.get()); EXPECT_TRUE(verify(*refBuilder).succeeded()); EXPECT_TRUE(areModulesEquivalentWithPermutations(directBuilder.get(), diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index 9d4b38f9c8..6d6c66d7f9 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -71,7 +71,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runCanonicalizationPasses(program.get()); + runQIRCleanupPipeline(program.get()); printer.record(program.get(), "Canonicalized QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -80,7 +80,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runCanonicalizationPasses(reference.get()); + runQIRCleanupPipeline(reference.get()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); From 28d6c313eaa37f8091c3598e78222c0d21925c08 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Sun, 5 Apr 2026 17:43:05 +0200 Subject: [PATCH 47/71] =?UTF-8?q?=E2=9A=A1=20Reduce=20redundant=20lookups?= =?UTF-8?q?=20in=20QCToQCO.cpp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 58 ++++++++++++++++--------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 0cde5c22be..75ebfc360d 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -161,18 +161,26 @@ currentModifierFrame(LoweringState& state) { return state.modifierFrames.back(); } -/** @brief Finds the nearest region-local map containing @p reference. */ -[[nodiscard]] static llvm::DenseMap* +/** + * @brief Finds the nearest region-local map containing @p reference and + * returns the pair containing the map and a mutable reference to the value in + * the map. + */ +[[nodiscard]] static std::pair*, Value*> findRegionLocalMap(llvm::DenseMap>& map, Operation* anchor, Value reference) { for (auto* current = anchor->getParentRegion(); current != nullptr; current = current->getParentRegion()) { - auto it = map.find(current); - if (it != map.end() && it->second.contains(reference)) { - return &it->second; + if (auto it = map.find(current); it != map.end()) { + auto& regionMap = it->second; + if (auto valueIt = regionMap.find(reference); + valueIt != regionMap.end()) { + return {®ionMap, &valueIt->second}; + } + return {®ionMap, nullptr}; } } - return nullptr; + return {nullptr, nullptr}; } /** @brief Resolves the latest QCO SSA value for a QC qubit reference. */ @@ -186,21 +194,20 @@ findRegionLocalMap(llvm::DenseMap>& map, } } - auto* qubitMap = findRegionLocalMap(state.qubitMap, anchor, qcQubit); - assert(qubitMap != nullptr && "QC qubit not found"); - auto it = qubitMap->find(qcQubit); - assert(it != qubitMap->end() && "QC qubit not found"); - return it->second; + const auto& [qubitMap, qubitValue] = + findRegionLocalMap(state.qubitMap, anchor, qcQubit); + assert(qubitMap != nullptr && qubitValue != nullptr && "QC qubit not found"); + return *qubitValue; } /** @brief Resolves the latest QTensor SSA value for a QC register. */ [[nodiscard]] static Value lookupMappedTensor(LoweringState& state, Operation* anchor, Value memref) { - auto* tensorMap = findRegionLocalMap(state.tensorMap, anchor, memref); - assert(tensorMap != nullptr && "QC register not found"); - auto it = tensorMap->find(memref); - assert(it != tensorMap->end() && "QC register not found"); - return it->second; + const auto& [tensorMap, tensorValue] = + findRegionLocalMap(state.tensorMap, anchor, memref); + assert(tensorMap != nullptr && tensorValue != nullptr && + "QC register not found"); + return *tensorValue; } /** @brief Updates the latest QCO SSA value for a QC qubit reference. */ @@ -215,22 +222,33 @@ static void assignMappedQubit(LoweringState& state, Operation* anchor, } } - if (auto* qubitMap = findRegionLocalMap(state.qubitMap, anchor, qcQubit)) { + auto [qubitMap, qubitValue] = + findRegionLocalMap(state.qubitMap, anchor, qcQubit); + if (qubitValue != nullptr) { + *qubitValue = qcoQubit; + return; + } + if (qubitMap != nullptr) { (*qubitMap)[qcQubit] = qcoQubit; return; } - state.qubitMap[anchor->getParentRegion()][qcQubit] = qcoQubit; } /** @brief Updates the latest QTensor SSA value for a QC register. */ static void assignMappedTensor(LoweringState& state, Operation* anchor, Value memref, Value tensor) { - if (auto* tensorMap = findRegionLocalMap(state.tensorMap, anchor, memref)) { + auto [tensorMap, tensorValue] = + findRegionLocalMap(state.tensorMap, anchor, memref); + + if (tensorValue != nullptr) { + *tensorValue = tensor; + return; + } + if (tensorMap != nullptr) { (*tensorMap)[memref] = tensor; return; } - state.tensorMap[anchor->getParentRegion()][memref] = tensor; } From 87cebaad29efe43af83d4c6957830a9ebd119b57 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Sun, 5 Apr 2026 18:22:04 +0200 Subject: [PATCH 48/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Streamline=20the=20Q?= =?UTF-8?q?C=20to=20QCO=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 66 ++++++++----------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 75ebfc360d..a7470901d2 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -408,23 +408,19 @@ struct ConvertMemRefAllocOp final return failure(); } - auto& state = getState(); - auto* operation = op.getOperation(); - auto memref = op.getResult(); - Value qtensor; if (shape[0] == ShapedType::kDynamic) { qtensor = rewriter.replaceOpWithNewOp( op, adaptor.getDynamicSizes()[0]); } else { - auto size = arith::ConstantOp::create(rewriter, op.getLoc(), - rewriter.getIndexAttr(shape[0])); + auto size = + arith::ConstantIndexOp::create(rewriter, op.getLoc(), shape[0]); qtensor = rewriter.replaceOpWithNewOp(op, size.getResult()); } - - assignMappedTensor(state, operation, memref, qtensor); + auto& state = getState(); + assignMappedTensor(state, qtensor.getDefiningOp(), memref, qtensor); return success(); } @@ -448,7 +444,8 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - if (!llvm::isa(op.getMemref().getType().getElementType())) { + auto memref = op.getMemref(); + if (!llvm::isa(memref.getType().getElementType())) { return failure(); } @@ -457,7 +454,6 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { auto* operation = op.getOperation(); // Look up latest QTensor value for this QC register - auto memref = op.getMemref(); auto qtensor = lookupMappedTensor(state, operation, memref); auto index = adaptor.getIndices()[0]; @@ -471,11 +467,11 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { assignMappedTensor(state, operation, memref, extract.getOutTensor()); QubitInfo info{.reg = memref, .index = index}; - if (auto it = qubitInfoMap.find(operation->getParentRegion()); - it != qubitInfoMap.end()) { + auto* parentRegion = operation->getParentRegion(); + if (auto it = qubitInfoMap.find(parentRegion); it != qubitInfoMap.end()) { it->second[qcQubit] = info; } else { - qubitInfoMap[operation->getParentRegion()][qcQubit] = info; + qubitInfoMap[parentRegion][qcQubit] = info; } rewriter.eraseOp(op); @@ -510,7 +506,8 @@ struct ConvertMemRefDeallocOp final LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - if (!llvm::isa(op.getMemref().getType().getElementType())) { + auto memref = op.getMemref(); + if (!llvm::isa(memref.getType().getElementType())) { return failure(); } @@ -520,45 +517,24 @@ struct ConvertMemRefDeallocOp final auto& qubitInfoMap = state.qubitInfoMap[op->getParentRegion()]; // Look up latest QTensor value for this QC register - auto memref = op.getMemref(); auto qtensor = lookupMappedTensor(state, op.getOperation(), memref); // Filter out qubits belonging to this tensor - llvm::SmallVector> toInsert; - toInsert.reserve(qubitMap.size()); - for (auto [qcQubit, qcoQubit] : qubitMap) { - auto& info = qubitInfoMap[qcQubit]; - if (info.reg != memref) { + for (auto it = qubitMap.begin(); it != qubitMap.end(); ++it) { + auto& [qcQubit, qcoQubit] = *it; + auto& [reg, index] = qubitInfoMap[qcQubit]; + if (reg != memref) { continue; } - toInsert.emplace_back(qcQubit, qcoQubit); - } - - // Sort qubits for deterministic output - llvm::sort(toInsert, [](const auto& a, const auto& b) { - auto* opA = a.first.getDefiningOp(); - auto* opB = b.first.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.first.getAsOpaquePointer() < b.first.getAsOpaquePointer(); - } - return opA->isBeforeInBlock(opB); - }); - - // Insert qubits - for (auto [qcQubit, qcoQubit] : toInsert) { - auto& info = qubitInfoMap[qcQubit]; - auto index = info.index; - auto insert = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, - qtensor, index); - qtensor = insert.getResult(); - qubitMap.erase(qcQubit); + qtensor = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, + qtensor, index) + .getResult(); + qubitMap.erase(it); qubitInfoMap.erase(qcQubit); } - - rewriter.replaceOpWithNewOp(op, qtensor); - tensorMap.erase(memref); + rewriter.replaceOpWithNewOp(op, qtensor); return success(); } }; @@ -670,7 +646,7 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { auto qcQubit = op.getQubit(); auto qcoOp = rewriter.replaceOpWithNewOp(op, op.getIndex()); - assignMappedQubit(state, operation, qcQubit, qcoOp.getQubit()); + assignMappedQubit(state, qcoOp, qcQubit, qcoOp.getQubit()); return success(); } From f0849fe9ffc21d07daf89710d3503f8592af46d6 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Mon, 6 Apr 2026 20:22:32 +0200 Subject: [PATCH 49/71] =?UTF-8?q?=F0=9F=94=A5=20Remove=20`qtensor.insert?= =?UTF-8?q?=5Fslice`=20and=20`qtensor.extract=5Fslice`=20operations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Neither of these was actually produced by the conversion from QC or Jeff, so there is no real reason to keep them around given the complexity that they cause. Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 57 --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 79 ---- .../mlir/Dialect/QTensor/IR/QTensorUtils.h | 86 +---- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 37 -- .../QTensor/IR/Operations/ExtractOp.cpp | 15 - .../QTensor/IR/Operations/ExtractSliceOp.cpp | 190 ---------- .../QTensor/IR/Operations/InsertOp.cpp | 24 +- .../QTensor/IR/Operations/InsertSliceOp.cpp | 186 ---------- .../QTensor/Transforms/ShrinkRegisters.cpp | 136 ------- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 346 ------------------ mlir/unittests/programs/qco_programs.cpp | 66 ---- mlir/unittests/programs/qco_programs.h | 27 -- 12 files changed, 5 insertions(+), 1244 deletions(-) delete mode 100644 mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp delete mode 100644 mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 9579f9e98c..6e86dd513f 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -274,34 +274,6 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ std::pair qtensorExtract(Value tensor, const int64_t index); - /** - * @brief Extract a qubit slice from a tensor - * - * @details - * Extracts a slice from a one-dimensional tensor of qubits at the given - * offset and size and returns the updated input tensor and the extracted - * tensor. The extracted tensor is added to the qubit tensor tracking and the - * tracking for the input tensor is updated. - * - * @param tensor Source tensor (must be valid/unconsumed) - * @param offset The offset from where the slice is extracted - * @param size The size of the extracted slice - * @return Pair of (outTensor, extractedSlice) - * - * @par Example: - * ```c++ - * auto [outTensor, extractedSlice] = builder.qtensorExtractSlice(tensor, 0, - * 2); - * ``` - * ```mlir - * %outTensor, %extractedSlice = qtensor.extract_slice %tensor[%c0][%c2] - * : tensor<3x!qco.qubit> to tensor<2x!qco.qubit> - * ``` - */ - std::pair - qtensorExtractSlice(Value tensor, const std::variant& offset, - const std::variant& size); - /** * @brief Insert a qubit into a tensor * @@ -327,35 +299,6 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { Value qtensorInsert(Value scalar, Value tensor, const std::variant& index); - /** - * @brief Insert a qubit slice into a tensor - * - * @details - * Inserts a one-dimensional tensor of qubits into another one-dimensional - * tensor of qubits at the given offset and size. The inserted tensor slice is - * consumed and removed from the tracking, while the tracking for the - * destination tensor is updated. - * - * @param sourceTensor The slice that is inserted (must be valid/unconsumed) - * @param destTensor The tensor where the slice is inserted (must be - * valid/unconsumed) - * @param offset The offset into where the slice is inserted - * @param size The size of the inserted slice - * @return The output tensor - * - * @par Example: - * ```c++ - * auto outTensor = builder.qtensorInsertSlice(slicedTensor, tensor, 0, 2); - * ``` - * ```mlir - * %outTensor = qtensor.insert_slice %slicedTensor into %tensor[%c0][%c2] - * : tensor<2x!qco.qubit> into tensor<3x!qco.qubit> - * ``` - */ - Value qtensorInsertSlice(Value sourceTensor, Value destTensor, - const std::variant& offset, - const std::variant& size); - /** * @brief Explicitly deallocate a tensor * diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index d7233f8460..1a0b98d6ea 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -139,48 +139,6 @@ def ExtractOp let hasVerifier = 1; } -def ExtractSliceOp - : QTensorOp<"extract_slice", - [Pure, - TypesMatchWith<"returned tensor type matches input tensor", - "tensor", "out_tensor", "$_self">]> { - let summary = "Extract slice from tensor"; - let description = [{ - The `qtensor.extract_slice` operation is the modified version of the standard `tensor.extract_slice` - operation of the tensor dialect. It reads a one-dimensional qubit tensor and returns the extracted tensor of qubits specified by the - offset and size argument. In addition, it also returns the updated input tensor as result. - - The extract_slice operation supports the following arguments: - - - tensor: the "base" tensor from which to extract a slice. - - offset: the starting position in the base tensor from which the slice is - extracted. - - size: the length of the slice to extract from the base tensor. - - Example: - ```mlir - %outTensor, %extractedSlice = qtensor.extract_slice %tensor[%c0][%c2] : tensor<3x!qco.qubit> to tensor<2x!qco.qubit> - ``` - }]; - - let arguments = (ins 1DTensorOf<[QubitType]>:$tensor, Index:$offset, - Index:$size); - let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, - 1DTensorOf<[QubitType]>:$result); - - let assemblyFormat = [{ - $tensor `[`$offset `]` `[`$size`]` - attr-dict `:` type($tensor) `to` type($result) - }]; - - let builders = [OpBuilder<(ins "Value":$tensor, "Value":$offset, - "Value":$size, CArg<"ArrayRef", "{}">:$attrs)>]; - - let hasCanonicalizer = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - def InsertOp : QTensorOp< "insert", @@ -214,41 +172,4 @@ def InsertOp let hasVerifier = 1; } -def InsertSliceOp - : QTensorOp<"insert_slice", - [Pure, TypesMatchWith<"expected result type to match dest type", - "dest", "result", "$_self">]> { - let summary = "Insert slice into tensor"; - let description = [{ - The `qtensor.insert_slice` operation is a modified version of the `tensor.insert_slice` operation - of the tensor dialect. The operation inserts a tensor `source` into another tensor `dest` as specified by the - operation's offset and size arguments. This insertion consumes the `source` tensor of qubits. - - The insert_slice operation supports the following arguments: - - - source: the tensor that is inserted. The source is consumed after the insertion. - - dest: the tensor into which the source tensor is inserted. - - offset: the starting index in the `dest` tensor where the slice is inserted. - - size: the number of elements in the slice, which must match the size of the - source tensor type. - - Example: - ```mlir - %outTensor = qtensor.insert_slice %slicedTensor into %tensor[%c0][%c2] : tensor<2x!qco.qubit> into tensor<3x!qco.qubit> - ``` - }]; - - let arguments = (ins 1DTensorOf<[QubitType]>:$source, - 1DTensorOf<[QubitType]>:$dest, Index:$offset, Index:$size); - let results = (outs 1DTensorOf<[QubitType]>:$result); - let assemblyFormat = [{ - $source `into` $dest `[`$offset `]` `[`$size`]` - attr-dict `:` type($source) `into` type($dest) - }]; - - let hasCanonicalizer = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - #endif // MLIR_DIALECT_QTENSOR_IR_QTENSOROPS_TD diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h index 145e5e017b..b46f853cea 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -16,15 +16,8 @@ #include #include -#include - namespace mlir::qtensor { -/** - * @brief Relation of two tensor accesses. - */ -enum class AccessRelation : std::uint8_t { Disjoint, Overlap, Equal, Unknown }; - /** * @brief Checks whether two index values are equivalent for matching. */ @@ -33,69 +26,10 @@ inline bool areEquivalentIndices(Value lhs, Value rhs) { } /** - * @brief Checks whether two slice ranges are equivalent for matching. - */ -inline bool areEquivalentRanges(Value lhsOffset, Value lhsSize, Value rhsOffset, - Value rhsSize) { - return areEquivalentIndices(lhsOffset, rhsOffset) && - areEquivalentIndices(lhsSize, rhsSize); -} - -/** - * @brief Classify the relation between a scalar index and a slice range. - */ -inline AccessRelation classifyIndexAndRange(Value index, Value offset, - Value size) { - if (areEquivalentIndices(index, offset)) { - return AccessRelation::Overlap; - } - - const auto indexValue = getConstantIntValue(index); - const auto offsetValue = getConstantIntValue(offset); - const auto sizeValue = getConstantIntValue(size); - if (!indexValue || !offsetValue || !sizeValue) { - return AccessRelation::Unknown; - } - - if (*indexValue < *offsetValue || *indexValue >= *offsetValue + *sizeValue) { - return AccessRelation::Disjoint; - } - return AccessRelation::Overlap; -} - -/** - * @brief Classify the relation between two slice ranges. - */ -inline AccessRelation classifyRanges(Value lhsOffset, Value lhsSize, - Value rhsOffset, Value rhsSize) { - if (areEquivalentRanges(lhsOffset, lhsSize, rhsOffset, rhsSize)) { - return AccessRelation::Equal; - } - - const auto lhsOffsetValue = getConstantIntValue(lhsOffset); - const auto lhsSizeValue = getConstantIntValue(lhsSize); - const auto rhsOffsetValue = getConstantIntValue(rhsOffset); - const auto rhsSizeValue = getConstantIntValue(rhsSize); - if (!lhsOffsetValue || !lhsSizeValue || !rhsOffsetValue || !rhsSizeValue) { - if (areEquivalentIndices(lhsOffset, rhsOffset)) { - return AccessRelation::Overlap; - } - return AccessRelation::Unknown; - } - - const auto lhsEnd = *lhsOffsetValue + *lhsSizeValue; - const auto rhsEnd = *rhsOffsetValue + *rhsSizeValue; - if (lhsEnd <= *rhsOffsetValue || rhsEnd <= *lhsOffsetValue) { - return AccessRelation::Disjoint; - } - return AccessRelation::Overlap; -} - -/** - * @brief Tensor-transforming ops in a chain that can commute by index/range. + * @brief Tensor-transforming ops in a scalar extract/insert chain. */ inline bool isTensorChainOp(Operation* op) { - return llvm::isa(op); + return llvm::isa(op); } /** @@ -108,12 +42,6 @@ inline Value getTensorChainInput(Operation* op) { if (auto extractOp = llvm::dyn_cast(op)) { return extractOp.getTensor(); } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getDest(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getTensor(); - } return nullptr; } @@ -127,12 +55,6 @@ inline Value getTensorChainOutput(Operation* op) { if (auto extractOp = llvm::dyn_cast(op)) { return extractOp.getOutTensor(); } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - return insertSliceOp.getResult(); - } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - return extractSliceOp.getOutTensor(); - } return nullptr; } @@ -140,11 +62,11 @@ inline Value getTensorChainOutput(Operation* op) { * @brief Rewire the tensor input of a tensor-transforming op. */ inline void setTensorChainInput(Operation* op, Value tensor) { - if (llvm::isa(op)) { + if (llvm::isa(op)) { op->setOperand(1, tensor); return; } - if (llvm::isa(op)) { + if (llvm::isa(op)) { op->setOperand(0, tensor); } } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index d6d40ad94b..4317a73d9a 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -246,24 +246,6 @@ std::pair QCOProgramBuilder::qtensorExtract(Value tensor, return {outTensor, qubit}; } -std::pair QCOProgramBuilder::qtensorExtractSlice( - Value tensor, const std::variant& offset, - const std::variant& size) { - checkFinalized(); - - auto offsetValue = variantToValue(*this, getLoc(), offset); - auto sizesValue = variantToValue(*this, getLoc(), size); - auto extractSliceOp = - qtensor::ExtractSliceOp::create(*this, tensor, offsetValue, sizesValue); - auto slicedTensor = extractSliceOp.getResult(); - auto outTensor = extractSliceOp.getOutTensor(); - - validTensors.try_emplace(slicedTensor, TensorInfo{tensorCounter++}); - updateTensorTracking(tensor, outTensor); - - return {outTensor, slicedTensor}; -} - Value QCOProgramBuilder::qtensorInsert( Value scalar, Value tensor, const std::variant& index) { checkFinalized(); @@ -280,25 +262,6 @@ Value QCOProgramBuilder::qtensorInsert( return outTensor; } -Value QCOProgramBuilder::qtensorInsertSlice( - Value source, Value dest, const std::variant& offset, - const std::variant& size) { - checkFinalized(); - - auto offsetValue = variantToValue(*this, getLoc(), offset); - auto sizeValue = variantToValue(*this, getLoc(), size); - auto insertSliceOp = qtensor::InsertSliceOp::create(*this, source, dest, - offsetValue, sizeValue); - - auto outTensor = insertSliceOp.getResult(); - - validateTensorValue(source); - validTensors.erase(source); - updateTensorTracking(dest, outTensor); - - return outTensor; -} - QCOProgramBuilder& QCOProgramBuilder::qtensorDealloc(Value tensor) { checkFinalized(); diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 2b4874fb28..6313ec142b 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -97,22 +97,7 @@ struct RemoveInsertExtractPair final : OpRewritePattern { // Do not reorder reads from the same index. return failure(); } - } else if (auto insertSliceOp = - llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange( - extractOp.getIndex(), insertSliceOp.getOffset(), - insertSliceOp.getSize()) != AccessRelation::Disjoint) { - return failure(); - } - } else if (auto extractSliceOp = - llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange( - extractOp.getIndex(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != AccessRelation::Disjoint) { - return failure(); - } } - traversedOps.push_back(definingOp); currentTensor = getTensorChainInput(definingOp); } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp deleted file mode 100644 index c3fdb0be8e..0000000000 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractSliceOp.cpp +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM - * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH - * All rights reserved. - * - * SPDX-License-Identifier: MIT - * - * Licensed under the MIT License - */ - -#include "mlir/Dialect/QTensor/IR/QTensorOps.h" -#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::qtensor; - -void ExtractSliceOp::build(OpBuilder& b, OperationState& result, Value tensor, - Value offset, Value size, - ArrayRef attrs) { - auto tensorType = cast(tensor.getType()); - auto sizeValue = getConstantIntValue(size); - auto resultType = RankedTensorType::get( - {sizeValue ? *sizeValue : ShapedType::kDynamic}, - tensorType.getElementType(), tensorType.getEncoding()); - - result.addAttributes(attrs); - build(b, result, {tensor.getType(), resultType}, tensor, offset, size); -} - -LogicalResult ExtractSliceOp::verify() { - auto tensorDim = getTensor().getType().getDimSize(0); - auto resultDim = getResult().getType().getDimSize(0); - auto constOffset = getConstantIntValue(getOffset()); - auto constSize = getConstantIntValue(getSize()); - - if (constOffset && *constOffset < 0) { - return emitOpError("Offset must be non-negative"); - } - - if (constSize && *constSize <= 0) { - return emitOpError("Size must be positive"); - } - - if (constOffset && constSize && !ShapedType::isDynamic(tensorDim)) { - if (*constOffset + *constSize > tensorDim) { - return emitOpError("Offset + Size exceeds source dimension"); - } - } - - if (constSize && !ShapedType::isDynamic(resultDim)) { - if (resultDim != *constSize) { - return emitOpError("Result tensor dimension must match size operand"); - } - } - - return success(); -} - -/** - * @brief If an ExtractSliceOp consumes an InsertSliceOp with the same offset - * and size, return the sourceTensor and the destTensor from the InsertSliceOp - * directly. - */ -static InsertSliceOp -foldExtractAfterInsertSlice(ExtractSliceOp extractSliceOp) { - auto insertSliceOp = - extractSliceOp.getTensor().getDefiningOp(); - if (!insertSliceOp) { - return nullptr; - } - - if (!areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), - extractSliceOp.getOffset(), - extractSliceOp.getSize())) { - return nullptr; - } - - return insertSliceOp; -} - -LogicalResult ExtractSliceOp::fold(FoldAdaptor /*adaptor*/, - SmallVectorImpl& results) { - if (auto insertOp = foldExtractAfterInsertSlice(*this)) { - results.emplace_back(insertOp.getDest()); - results.emplace_back(insertOp.getSource()); - return success(); - } - - return failure(); -} - -namespace { - -/** - * @brief Remove matching insert_slice-extract_slice pairs through commuting - * disjoint tensor-chain operations. - */ -struct RemoveInsertSliceExtractSlicePair final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp, - PatternRewriter& rewriter) const override { - llvm::SmallVector traversedOps; - Value currentTensor = extractSliceOp.getTensor(); - InsertSliceOp matchedInsertSliceOp = nullptr; - - while (auto* definingOp = currentTensor.getDefiningOp()) { - if (!isTensorChainOp(definingOp)) { - break; - } - - if (auto insertSliceOp = llvm::dyn_cast(definingOp)) { - const auto relation = classifyRanges( - insertSliceOp.getOffset(), insertSliceOp.getSize(), - extractSliceOp.getOffset(), extractSliceOp.getSize()); - if (relation == AccessRelation::Equal) { - matchedInsertSliceOp = insertSliceOp; - break; - } - if (relation != AccessRelation::Disjoint) { - return failure(); - } - } else if (auto insertOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange( - insertOp.getIndex(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != AccessRelation::Disjoint) { - return failure(); - } - } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange( - nestedExtractOp.getIndex(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != AccessRelation::Disjoint) { - return failure(); - } - } else if (auto nestedExtractSliceOp = - llvm::dyn_cast(definingOp)) { - if (classifyRanges( - nestedExtractSliceOp.getOffset(), - nestedExtractSliceOp.getSize(), extractSliceOp.getOffset(), - extractSliceOp.getSize()) != AccessRelation::Disjoint) { - return failure(); - } - } - - traversedOps.push_back(definingOp); - currentTensor = getTensorChainInput(definingOp); - } - - if (!matchedInsertSliceOp) { - return failure(); - } - - Value outTensor = matchedInsertSliceOp.getDest(); - if (!traversedOps.empty()) { - Operation* oldestCommutedOp = traversedOps.back(); - rewriter.modifyOpInPlace(oldestCommutedOp, [&]() { - setTensorChainInput(oldestCommutedOp, matchedInsertSliceOp.getDest()); - }); - outTensor = getTensorChainOutput(traversedOps.front()); - if (!outTensor) { - return failure(); - } - } - - rewriter.replaceOp(extractSliceOp, - {outTensor, matchedInsertSliceOp.getSource()}); - return success(); - } -}; - -} // namespace - -void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* context) { - results.add(context); -} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index c8a8839492..88f066f750 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -58,13 +58,12 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { if (auto result = foldInsertAfterExtract(*this)) { return result; } - return {}; } /** * @brief Find a matching `qtensor.extract` for an insert index in a tensor - * chain by traversing nested scalar and slice tensor ops. + * chain by traversing nested scalar tensor ops. */ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { Value current = tensor; @@ -77,16 +76,6 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { current = nestedInsertOp.getDest(); continue; } - if (auto nestedInsertSliceOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange(index, nestedInsertSliceOp.getOffset(), - nestedInsertSliceOp.getSize()) != - AccessRelation::Disjoint) { - return nullptr; - } - current = nestedInsertSliceOp.getDest(); - continue; - } - if (auto extractOp = llvm::dyn_cast(definingOp)) { if (areEquivalentIndices(extractOp.getIndex(), index)) { return extractOp; @@ -94,19 +83,8 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { current = extractOp.getTensor(); continue; } - if (auto extractSliceOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange(index, extractSliceOp.getOffset(), - extractSliceOp.getSize()) != - AccessRelation::Disjoint) { - return nullptr; - } - current = extractSliceOp.getTensor(); - continue; - } - break; } - return nullptr; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp deleted file mode 100644 index 6d2f6628a5..0000000000 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertSliceOp.cpp +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM - * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH - * All rights reserved. - * - * SPDX-License-Identifier: MIT - * - * Licensed under the MIT License - */ - -#include "mlir/Dialect/QTensor/IR/QTensorOps.h" -#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::qtensor; - -/** - * @brief Checks whether removing an extract_slice-insert_slice pair is - * linearity-safe. - */ -static bool -isRemovableExtractSliceInsertSlicePair(InsertSliceOp insertSliceOp, - ExtractSliceOp extractSliceOp) { - return insertSliceOp.getSource() == extractSliceOp.getResult() && - areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), - extractSliceOp.getOffset(), - extractSliceOp.getSize()); -} - -/** - * @brief Find a matching `qtensor.extract_slice` for an insert_slice range in - * a tensor chain by traversing scalar and slice tensor operations. - */ -static ExtractSliceOp -findMatchingExtractSliceInTensorChain(Value tensor, Value offset, Value size) { - Value current = tensor; - while (Operation* definingOp = current.getDefiningOp()) { - if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange(nestedInsertOp.getIndex(), offset, size) != - AccessRelation::Disjoint) { - return nullptr; - } - current = nestedInsertOp.getDest(); - continue; - } - if (auto nestedInsertSliceOp = llvm::dyn_cast(definingOp)) { - if (classifyRanges(nestedInsertSliceOp.getOffset(), - nestedInsertSliceOp.getSize(), offset, - size) != AccessRelation::Disjoint) { - return nullptr; - } - current = nestedInsertSliceOp.getDest(); - continue; - } - if (auto extractOp = llvm::dyn_cast(definingOp)) { - if (classifyIndexAndRange(extractOp.getIndex(), offset, size) != - AccessRelation::Disjoint) { - return nullptr; - } - current = extractOp.getTensor(); - continue; - } - if (auto extractSliceOp = llvm::dyn_cast(definingOp)) { - const auto relation = classifyRanges( - extractSliceOp.getOffset(), extractSliceOp.getSize(), offset, size); - if (relation == AccessRelation::Equal) { - return extractSliceOp; - } - if (relation != AccessRelation::Disjoint) { - return nullptr; - } - current = extractSliceOp.getTensor(); - continue; - } - - break; - } - - return nullptr; -} - -LogicalResult InsertSliceOp::verify() { - auto srcDim = getSource().getType().getDimSize(0); - auto dstDim = getDest().getType().getDimSize(0); - auto constOffset = getConstantIntValue(getOffset()); - auto constSize = getConstantIntValue(getSize()); - - if (constOffset && *constOffset < 0) { - return emitOpError("Offset must be non-negative"); - } - - if (constSize && *constSize <= 0) { - return emitOpError("Size must be positive"); - } - - if (constSize && !ShapedType::isDynamic(srcDim)) { - if (*constSize != srcDim) { - return emitOpError("Size must match source dimension"); - } - } - - if (constOffset && constSize && !ShapedType::isDynamic(dstDim)) { - if (*constSize > dstDim || *constOffset > dstDim - *constSize) { - return emitOpError("Offset + Size exceeds destination dimension"); - } - } - - return success(); -} - -/** - * @brief If an InsertSliceOp consumes an ExtractSliceOp with the same offset - * and size, return the sourceTensor from the extractSliceOp directly. - */ -static Value foldInsertAfterExtractSlice(InsertSliceOp insertSliceOp) { - auto extractSliceOp = - insertSliceOp.getSource().getDefiningOp(); - if (!extractSliceOp) { - return nullptr; - } - - if (extractSliceOp.getOutTensor() != insertSliceOp.getDest()) { - return nullptr; - } - - if (!areEquivalentRanges(insertSliceOp.getOffset(), insertSliceOp.getSize(), - extractSliceOp.getOffset(), - extractSliceOp.getSize())) { - return nullptr; - } - - return extractSliceOp.getTensor(); -} - -OpFoldResult InsertSliceOp::fold(FoldAdaptor /*adaptor*/) { - if (auto result = foldInsertAfterExtractSlice(*this)) { - return result; - } - - return {}; -} - -namespace { - -/** - * @brief Remove matching `qtensor.insert_slice` and `qtensor.extract_slice` - * pairs through commuting disjoint tensor-chain operations. - */ -struct RemoveExtractSliceInsertSlicePair final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertSliceOp op, - PatternRewriter& rewriter) const override { - auto extractSliceOp = findMatchingExtractSliceInTensorChain( - op.getDest(), op.getOffset(), op.getSize()); - if (!extractSliceOp) { - return failure(); - } - - if (!isRemovableExtractSliceInsertSlicePair(op, extractSliceOp)) { - return failure(); - } - - rewriter.replaceOp(op, op.getDest()); - rewriter.replaceOp(extractSliceOp, {extractSliceOp.getTensor(), nullptr}); - return success(); - } -}; - -} // namespace - -void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* context) { - results.add(context); -} diff --git a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp index 36030c7cee..499c57abc4 100644 --- a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp +++ b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp @@ -48,22 +48,6 @@ namespace mlir::qtensor { return success(); } -/** - * @brief Mark a contiguous live range. - */ -[[nodiscard]] static LogicalResult markLiveRange(const int64_t offset, - const int64_t size, - llvm::BitVector& liveIndices) { - if (offset < 0 || size <= 0 || - offset + size > static_cast(liveIndices.size())) { - return failure(); - } - for (int64_t index = offset; index < offset + size; ++index) { - liveIndices.set(static_cast(index)); - } - return success(); -} - /** * @brief Redirect the tensor operand from @p from to @p to. */ @@ -83,20 +67,6 @@ namespace mlir::qtensor { insertOp->setOperand(1, to); return success(); } - if (auto extractSliceOp = llvm::dyn_cast(op)) { - if (extractSliceOp.getTensor() != from) { - return failure(); - } - extractSliceOp->setOperand(0, to); - return success(); - } - if (auto insertSliceOp = llvm::dyn_cast(op)) { - if (insertSliceOp.getDest() != from) { - return failure(); - } - insertSliceOp->setOperand(1, to); - return success(); - } if (auto deallocOp = llvm::dyn_cast(op)) { if (deallocOp.getTensor() != from) { return failure(); @@ -152,32 +122,6 @@ namespace mlir::qtensor { continue; } - if (auto extractSliceOp = llvm::dyn_cast(user)) { - if (extractSliceOp.getTensor() != tensor) { - return failure(); - } - auto offset = getConstantIntValue(extractSliceOp.getOffset()); - auto size = getConstantIntValue(extractSliceOp.getSize()); - if (!offset || !size || failed(markLiveRange(*offset, *size, live))) { - return failure(); - } - tensor = extractSliceOp.getOutTensor(); - continue; - } - - if (auto insertSliceOp = llvm::dyn_cast(user)) { - if (insertSliceOp.getDest() != tensor) { - return failure(); - } - auto offset = getConstantIntValue(insertSliceOp.getOffset()); - auto size = getConstantIntValue(insertSliceOp.getSize()); - if (!offset || !size || failed(markLiveRange(*offset, *size, live))) { - return failure(); - } - tensor = insertSliceOp.getResult(); - continue; - } - return failure(); } } @@ -314,86 +258,6 @@ struct ShrinkStaticQTensor final : OpRewritePattern { continue; } - if (auto extractSliceOp = llvm::dyn_cast(currentOp)) { - if (extractSliceOp.getTensor() != oldTensor) { - return failure(); - } - const auto oldOffset = *getConstantIntValue(extractSliceOp.getOffset()); - const auto oldSliceSize = - *getConstantIntValue(extractSliceOp.getSize()); - if (oldOffset < 0 || oldSliceSize <= 0 || - oldOffset + oldSliceSize > - static_cast(newIndexByOldIndex.size())) { - return failure(); - } - const auto mappedOffset = - newIndexByOldIndex[static_cast(oldOffset)]; - if (mappedOffset < 0) { - return failure(); - } - Value oldOutTensor = extractSliceOp.getOutTensor(); - Operation* nextOp = getLinearTensorUser(oldOutTensor); - if (!nextOp) { - return failure(); - } - rewriter.setInsertionPoint(extractSliceOp); - auto newOffset = arith::ConstantIndexOp::create( - rewriter, extractSliceOp.getLoc(), mappedOffset); - auto newSliceSize = arith::ConstantIndexOp::create( - rewriter, extractSliceOp.getLoc(), oldSliceSize); - auto newExtractSlice = ExtractSliceOp::create( - rewriter, extractSliceOp.getLoc(), currentTensor, - newOffset.getResult(), newSliceSize.getResult()); - rewriter.replaceAllUsesWith(extractSliceOp.getResult(), - newExtractSlice.getResult()); - - currentTensor = newExtractSlice.getOutTensor(); - if (failed(remapTensorOperand(nextOp, oldOutTensor, oldTensor))) { - return failure(); - } - rewriter.eraseOp(extractSliceOp); - continue; - } - - if (auto insertSliceOp = llvm::dyn_cast(currentOp)) { - if (insertSliceOp.getDest() != oldTensor) { - return failure(); - } - const auto oldOffset = *getConstantIntValue(insertSliceOp.getOffset()); - const auto oldSliceSize = *getConstantIntValue(insertSliceOp.getSize()); - if (oldOffset < 0 || oldSliceSize <= 0 || - oldOffset + oldSliceSize > - static_cast(newIndexByOldIndex.size())) { - return failure(); - } - const auto mappedOffset = - newIndexByOldIndex[static_cast(oldOffset)]; - if (mappedOffset < 0) { - return failure(); - } - Value oldResultTensor = insertSliceOp.getResult(); - Operation* nextOp = getLinearTensorUser(oldResultTensor); - if (!nextOp) { - return failure(); - } - - rewriter.setInsertionPoint(insertSliceOp); - auto newOffset = arith::ConstantIndexOp::create( - rewriter, insertSliceOp.getLoc(), mappedOffset); - auto newSliceSize = arith::ConstantIndexOp::create( - rewriter, insertSliceOp.getLoc(), oldSliceSize); - auto newInsertSlice = InsertSliceOp::create( - rewriter, insertSliceOp.getLoc(), insertSliceOp.getSource(), - currentTensor, newOffset.getResult(), newSliceSize.getResult()); - - currentTensor = newInsertSlice.getResult(); - if (failed(remapTensorOperand(nextOp, oldResultTensor, oldTensor))) { - return failure(); - } - rewriter.eraseOp(insertSliceOp); - continue; - } - return failure(); } diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 3b7b1c9d59..f7b3235451 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -70,159 +70,6 @@ class QCOTest : public testing::TestWithParam { }; } // namespace -static OwningOpRef -buildTwoQubitInsertChainProgram(MLIRContext* context, - const bool reverseInsertOrder, - const bool swapInsertTargets) { - qco::QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterFirstExtract, qubit0] = builder.qtensorExtract(tensor, 0); - auto [baseTensor, qubit1] = - builder.qtensorExtract(tensorAfterFirstExtract, 1); - - const int64_t qubit0Target = swapInsertTargets ? 1 : 0; - const int64_t qubit1Target = swapInsertTargets ? 0 : 1; - - Value currentTensor = baseTensor; - if (reverseInsertOrder) { - currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); - currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); - } else { - currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); - currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); - } - - builder.qtensorDealloc(currentTensor); - return builder.finalize(); -} - -static OwningOpRef -buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, - const bool swapInsertTargets) { - qco::QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(3); - Value tensorAfterReads = tensor; - Value qubit0 = nullptr; - Value qubit1 = nullptr; - - if (reverseOrder) { - std::tie(tensorAfterReads, qubit1) = - builder.qtensorExtract(tensorAfterReads, 1); - std::tie(tensorAfterReads, qubit0) = - builder.qtensorExtract(tensorAfterReads, 0); - } else { - std::tie(tensorAfterReads, qubit0) = - builder.qtensorExtract(tensorAfterReads, 0); - std::tie(tensorAfterReads, qubit1) = - builder.qtensorExtract(tensorAfterReads, 1); - } - - const int64_t q0Target = 0; - const int64_t q1Target = swapInsertTargets ? 2 : 1; - - Value tensorAfterWrites = tensorAfterReads; - if (reverseOrder) { - tensorAfterWrites = - builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); - tensorAfterWrites = - builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); - } else { - tensorAfterWrites = - builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); - tensorAfterWrites = - builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); - } - - builder.qtensorDealloc(tensorAfterWrites); - return builder.finalize(); -} - -static OwningOpRef -buildMixedScalarSliceInsertProgram(MLIRContext* context, - const bool reverseOrder, const bool overlap, - const bool mutateScalar) { - qco::QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(6); - auto [tensorAfterSliceExtract, slice] = - builder.qtensorExtractSlice(tensor, 1, 2); - const int64_t scalarIndex = overlap ? 1 : 5; - auto [tensorAfterScalarExtract, scalar] = - builder.qtensorExtract(tensorAfterSliceExtract, scalarIndex); - if (mutateScalar) { - scalar = builder.h(scalar); - } - - Value tensorAfterWrites = tensorAfterScalarExtract; - if (reverseOrder) { - tensorAfterWrites = - builder.qtensorInsertSlice(slice, tensorAfterWrites, 1, 2); - tensorAfterWrites = - builder.qtensorInsert(scalar, tensorAfterWrites, scalarIndex); - } else { - tensorAfterWrites = - builder.qtensorInsert(scalar, tensorAfterWrites, scalarIndex); - tensorAfterWrites = - builder.qtensorInsertSlice(slice, tensorAfterWrites, 1, 2); - } - - builder.qtensorDealloc(tensorAfterWrites); - return builder.finalize(); -} - -static OwningOpRef -buildResetWithCommutingInsertProgram(MLIRContext* context, - const bool withReset) { - qco::QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); - auto tensorAfterInsert0 = - builder.qtensorInsert(qubit0, tensorAfterExtract0, 0); - auto [tensorAfterExtract1, qubit1] = - builder.qtensorExtract(tensorAfterInsert0, 1); - if (withReset) { - qubit1 = builder.reset(qubit1); - } - auto tensorFinal = builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); - builder.qtensorDealloc(tensorFinal); - - return builder.finalize(); -} - -static OwningOpRef -buildResetWithSameIndexInsertProgram(MLIRContext* context, - const bool withReset) { - qco::QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); - auto [tensorAfterExtract1, qubit1] = - builder.qtensorExtract(tensorAfterExtract0, 1); - qubit1 = builder.h(qubit1); - auto tensorAfterInsert1 = - builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); - auto [tensorAfterReadBack1, qubit1ReadBack] = - builder.qtensorExtract(tensorAfterInsert1, 1); - if (withReset) { - qubit1ReadBack = builder.reset(qubit1ReadBack); - } - auto tensorAfterInsert1ReadBack = - builder.qtensorInsert(qubit1ReadBack, tensorAfterReadBack1, 1); - auto tensorFinal = - builder.qtensorInsert(qubit0, tensorAfterInsert1ReadBack, 0); - builder.qtensorDealloc(tensorFinal); - - return builder.finalize(); -} - TEST_P(QCOTest, ProgramEquivalence) { const auto& [_, programBuilder, referenceBuilder] = GetParam(); const auto name = " (" + GetParam().name + ")"; @@ -250,146 +97,6 @@ TEST_P(QCOTest, ProgramEquivalence) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } -TEST_F(QCOTest, InsertChainPermutationEquivalence) { - auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, InsertChainDifferentAssignmentsNotEquivalent) { - auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_FALSE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, MixedExtractInsertPermutationEquivalence) { - auto program = buildMixedExtractInsertProgram(context.get(), false, false); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildMixedExtractInsertProgram(context.get(), true, false); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { - auto program = buildMixedExtractInsertProgram(context.get(), false, false); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildMixedExtractInsertProgram(context.get(), true, true); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_FALSE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, MixedScalarSliceInsertPermutationEquivalence) { - auto program = - buildMixedScalarSliceInsertProgram(context.get(), false, false, false); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = - buildMixedScalarSliceInsertProgram(context.get(), true, false, false); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, MixedScalarSliceInsertOverlapNotEquivalent) { - auto program = - buildMixedScalarSliceInsertProgram(context.get(), false, true, true); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = - buildMixedScalarSliceInsertProgram(context.get(), true, true, true); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_FALSE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { - auto program = buildResetWithCommutingInsertProgram(context.get(), true); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildResetWithCommutingInsertProgram(context.get(), false); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -TEST_F(QCOTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { - auto program = buildResetWithSameIndexInsertProgram(context.get(), true); - ASSERT_TRUE(program); - EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); - ASSERT_TRUE(reference); - EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_FALSE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly QCOProgramBuilder builder(context.get()); @@ -1378,56 +1085,3 @@ INSTANTIATE_TEST_SUITE_P( QCOTestCase{"AllocSinkPair", MQT_NAMED_BUILDER(allocSinkPair), MQT_NAMED_BUILDER(emptyQCO)})); /// @} - -/// \name QTensor/QTensor.cpp -/// @{ -INSTANTIATE_TEST_SUITE_P( - QTensorTest, QCOTest, - testing::Values( - QCOTestCase{"QTensorAlloc", MQT_NAMED_BUILDER(qtensorAlloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorAllocDealloc", MQT_NAMED_BUILDER(qtensorDealloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorFromElements", - MQT_NAMED_BUILDER(qtensorFromElements), - MQT_NAMED_BUILDER(qtensorFromElements)}, - QCOTestCase{"QTensorExtract", MQT_NAMED_BUILDER(qtensorExtract), - MQT_NAMED_BUILDER(qtensorExtract)}, - QCOTestCase{"QTensorInsert", MQT_NAMED_BUILDER(qtensorInsert), - MQT_NAMED_BUILDER(qtensorInsert)}, - QCOTestCase{"QTensorExtractSlice", - MQT_NAMED_BUILDER(qtensorExtractSlice), - MQT_NAMED_BUILDER(qtensorExtractSlice)}, - QCOTestCase{"QTensorInsertSlice", MQT_NAMED_BUILDER(qtensorInsertSlice), - MQT_NAMED_BUILDER(qtensorInsertSlice)}, - QCOTestCase{"QTensorExtractInsertSameIndex", - MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{"QTensorExtractInsertIndexMismatch", - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, - QCOTestCase{"QTensorInsertExtractSameIndex", - MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), - MQT_NAMED_BUILDER(qtensorInsert)}, - QCOTestCase{"QTensorInsertExtractIndexMismatch", - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)}, - QCOTestCase{"QTensorExtractSliceInsertSliceSameOffset", - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceSameOffset), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QCOTestCase{ - "QTensorExtractSliceInsertSliceOffsetMismatch", - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceOffsetMismatch), - MQT_NAMED_BUILDER(qtensorExtractSliceInsertSliceOffsetMismatch)}, - QCOTestCase{"QTensorInsertSliceExtractSliceSameOffset", - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceSameOffset), - MQT_NAMED_BUILDER(qtensorInsertSlice)}, - QCOTestCase{ - "QTensorInsertSliceExtractSliceOffsetMismatch", - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceOffsetMismatch), - MQT_NAMED_BUILDER(qtensorInsertSliceExtractSliceOffsetMismatch)}, - QCOTestCase{ - "QTensorExtractSliceExtractInsertInsertSlice", - MQT_NAMED_BUILDER(qtensorExtractSliceExtractInsertInsertSlice), - MQT_NAMED_BUILDER(qtensorAlloc)})); -/// @} diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 22da726ef0..0f86e41595 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -2193,21 +2193,6 @@ void qtensorInsert(QCOProgramBuilder& b) { b.qtensorInsert(q1, extractOutTensor, 0); } -void qtensorExtractSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - b.qtensorExtractSlice(qtensor, 0, 2); -} - -void qtensorInsertSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); -} - void qtensorExtractInsertIndexMismatch(QCOProgramBuilder& b) { auto qtensor = b.qtensorAlloc(3); auto [extractOutTensor, q0] = b.qtensorExtract(qtensor, 0); @@ -2220,20 +2205,6 @@ void qtensorExtractInsertSameIndex(QCOProgramBuilder& b) { b.qtensorInsert(q0, extractOutTensor, 0); } -void qtensorExtractSliceInsertSliceOffsetMismatch(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - b.qtensorInsertSlice(slicedTensor, extractSliceOutTensor, 1, 2); -} - -void qtensorExtractSliceInsertSliceSameOffset(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - b.qtensorInsertSlice(slicedTensor, extractSliceOutTensor, 0, 2); -} - void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b) { auto qtensor = b.qtensorAlloc(3); auto [extractOutTensor, q0] = b.qtensorExtract(qtensor, 0); @@ -2252,41 +2223,4 @@ void qtensorInsertExtractSameIndex(QCOProgramBuilder& b) { b.qtensorInsert(q2, extractOutTensor1, 0); } -void qtensorInsertSliceExtractSliceOffsetMismatch(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - auto insertSliceOutTensor = - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); - auto [extractSliceOutTensor1, slicedTensor1] = - b.qtensorExtractSlice(insertSliceOutTensor, 1, 2); - b.qtensorInsertSlice(slicedTensor1, extractSliceOutTensor1, 0, 2); -} - -void qtensorInsertSliceExtractSliceSameOffset(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto q1 = b.h(q0); - auto insertOutTensor = b.qtensorInsert(q1, extractOutTensor, 0); - auto insertSliceOutTensor = - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); - auto [extractSliceOutTensor1, slicedTensor1] = - b.qtensorExtractSlice(insertSliceOutTensor, 0, 2); - b.qtensorInsertSlice(slicedTensor1, extractSliceOutTensor1, 0, 2); -} - -void qtensorExtractSliceExtractInsertInsertSlice(QCOProgramBuilder& b) { - auto qtensor = b.qtensorAlloc(3); - auto [extractSliceOutTensor, slicedTensor] = - b.qtensorExtractSlice(qtensor, 0, 2); - auto [extractOutTensor, q0] = b.qtensorExtract(slicedTensor, 0); - auto insertOutTensor = b.qtensorInsert(q0, extractOutTensor, 0); - b.qtensorInsertSlice(insertOutTensor, extractSliceOutTensor, 0, 2); -} - } // namespace mlir::qco diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index a14100102c..3d503bc472 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -1007,12 +1007,6 @@ void qtensorExtract(QCOProgramBuilder& b); /// Inserts a qubit into a tensor. void qtensorInsert(QCOProgramBuilder& b); -/// Extracts a slice from a tensor. -void qtensorExtractSlice(QCOProgramBuilder& b); - -/// Inserts a slice into a tensor. -void qtensorInsertSlice(QCOProgramBuilder& b); - /// Extracts a qubit from a tensor and inserts it immediately at a different /// index. void qtensorExtractInsertIndexMismatch(QCOProgramBuilder& b); @@ -1027,25 +1021,4 @@ void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b); /// Inserts a qubit into a tensor and extracts it immediately at the same index. void qtensorInsertExtractSameIndex(QCOProgramBuilder& b); -/// Extracts a slice of qubits from a tensor and inserts it immediately at a -/// different offset. -void qtensorExtractSliceInsertSliceOffsetMismatch(QCOProgramBuilder& b); - -/// Extracts a slice of qubits from a tensor and inserts it immediately at the -/// same offset. -void qtensorExtractSliceInsertSliceSameOffset(QCOProgramBuilder& b); - -/// Inserts a slice of qubits into a tensor and extracts it immediately at a -/// different offset. -void qtensorInsertSliceExtractSliceOffsetMismatch(QCOProgramBuilder& b); - -/// Inserts a slice of qubits into a tensor and extracts it immediately at the -/// same offset. -void qtensorInsertSliceExtractSliceSameOffset(QCOProgramBuilder& b); - -/// Extracts a slice of qubits, extracts a qubit from the slice, inserts the -/// qubit back into the slice, and inserts the slice back into the tensor -/// immediately at the same index and offset. -void qtensorExtractSliceExtractInsertInsertSlice(QCOProgramBuilder& b); - } // namespace mlir::qco From 7be6f12c5619d329fcb64b65e2508755362ab171 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Mon, 6 Apr 2026 20:23:01 +0200 Subject: [PATCH 50/71] =?UTF-8?q?=E2=9C=85=20Add=20dedicated=20test=20suit?= =?UTF-8?q?e=20for=20the=20QTensor=20dialect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assisted-by: gpt-5.3-codex (high) via Codex Signed-off-by: burgholzer --- mlir/unittests/Dialect/CMakeLists.txt | 1 + mlir/unittests/Dialect/QTensor/CMakeLists.txt | 9 + .../Dialect/QTensor/IR/CMakeLists.txt | 15 + .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 883 ++++++++++++++++++ 4 files changed, 908 insertions(+) create mode 100644 mlir/unittests/Dialect/QTensor/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 41a55f38fe..2c62a1e81e 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -9,4 +9,5 @@ add_subdirectory(QC) add_subdirectory(QCO) add_subdirectory(QIR) +add_subdirectory(QTensor) add_subdirectory(Utils) diff --git a/mlir/unittests/Dialect/QTensor/CMakeLists.txt b/mlir/unittests/Dialect/QTensor/CMakeLists.txt new file mode 100644 index 0000000000..b181a84fed --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(IR) diff --git a/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt b/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt new file mode 100644 index 0000000000..b7e2227682 --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(qtensor_ir_target mqt-core-mlir-unittest-qtensor-ir) +add_executable(${qtensor_ir_target} test_qtensor_ir.cpp) +target_link_libraries(${qtensor_ir_target} PRIVATE GTest::gtest_main MLIRParser MLIRSupportMQT + MLIRQCOProgramBuilder MLIRQCOPrograms) +mqt_mlir_configure_unittest_target(${qtensor_ir_target}) + +gtest_discover_tests(${qtensor_ir_target} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp new file mode 100644 index 0000000000..3708d6517b --- /dev/null +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -0,0 +1,883 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +/** + * @file test_qtensor_ir.cpp + * @brief Dedicated unit-test suite for the QTensor MLIR dialect. + */ + +#include "TestCaseUtils.h" +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorDialect.h" +#include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include "mlir/Support/IRVerification.h" +#include "mlir/Support/Passes.h" +#include "qco_programs.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qtensor; +using namespace mlir::qco; + +namespace { +// ============================================================================ +// Shared fixture — sets up an MLIR context with QTensor/QCO/Arith dialects +// and provides a QCOProgramBuilder for creating test programs. +// ============================================================================ + +class QTensorTest : public ::testing::Test { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } + + /// Build a module using the QCOProgramBuilder and run a lightweight cleanup + /// pipeline (canonicalizer + CSE + symbol DCE + canonicalizer). + [[nodiscard]] OwningOpRef + buildAndCanonicalize(void (*buildFn)(QCOProgramBuilder&)) const { + auto module = QCOProgramBuilder::build(context.get(), buildFn); + if (!module) { + return {}; + } + + PassManager pm(context.get()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createSymbolDCEPass()); + pm.addPass(createCanonicalizerPass()); + if (pm.run(*module).failed()) { + return {}; + } + return module; + } + + /// Count occurrences of a specific op kind inside a module. + template + [[nodiscard]] static std::size_t countOps(ModuleOp module) { + std::size_t count = 0; + module.walk([&](OpT) { ++count; }); + return count; + } +}; + +// ============================================================================ +// 1. QTensorUtils — direct tests of scalar chain helpers +// ============================================================================ + +TEST_F(QTensorTest, AreEquivalentIndices_SameValueIsEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto c2 = arith::ConstantIndexOp::create(builder, 2); + EXPECT_TRUE(areEquivalentIndices(c2.getResult(), c2.getResult())); +} + +TEST_F(QTensorTest, AreEquivalentIndices_DifferentConstantsAreNotEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto c0 = arith::ConstantIndexOp::create(builder, 0); + auto c1 = arith::ConstantIndexOp::create(builder, 1); + EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); +} + +TEST_F(QTensorTest, TensorChainHelpers_InsertAndExtractAreRecognized) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor = builder.qtensorAlloc(3); + auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); + auto insert = builder.qtensorInsert(q0, outTensor, 0).getDefiningOp(); + auto extract = outTensor.getDefiningOp(); + + ASSERT_NE(insert, nullptr); + ASSERT_NE(extract, nullptr); + EXPECT_TRUE(isTensorChainOp(insert)); + EXPECT_TRUE(isTensorChainOp(extract)); + EXPECT_EQ(getTensorChainOutput(insert), insert->getResult(0)); + EXPECT_EQ(getTensorChainInput(extract), tensor); +} + +TEST_F(QTensorTest, TensorChainHelpers_SetTensorChainInputRewiresOperand) { + auto module = + QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { + auto t1 = b.qtensorAlloc(3); + auto [out1, q1] = b.qtensorExtract(t1, 1); + auto t0 = b.qtensorAlloc(3); + auto [out0, q0] = b.qtensorExtract(t0, 0); + auto insert = InsertOp::create( + b, q0, out0, arith::ConstantIndexOp::create(b, 0).getResult()); + setTensorChainInput(insert.getOperation(), out1); + (void)InsertOp::create( + b, q0, out1, arith::ConstantIndexOp::create(b, 1).getResult()); + (void)q1; + (void)out0; + }); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); +} + +// ============================================================================ +// 2. AllocOp — verify() tests +// ============================================================================ + +/// A valid static alloc should pass verification. +TEST_F(QTensorTest, AllocOp_ValidStaticAllocVerifies) { + auto module = QCOProgramBuilder::build( + context.get(), [](QCOProgramBuilder& b) { b.qtensorAlloc(3); }); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); +} + +/// AllocOp with a constant size ≤ 0 must fail verification. +/// Note: The builder asserts on zero/negative, so we verify the verifier +/// by constructing the op manually bypassing the builder assertion. +TEST_F(QTensorTest, AllocOp_ZeroSizeFailsVerification) { + // Build a module manually to bypass builder-level assertion. + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + // Create a constant 0 for the size operand. + auto c0 = arith::ConstantIndexOp::create(b, 0); + // Construct the result type that would match a size-0 tensor (which is + // invalid per the verifier). We use kDynamic so the type-level constraint + // won't block construction, but the constant operand (0) triggers the + // verifier. + auto qubitType = qco::QubitType::get(context.get()); + auto dynType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + AllocOp::create(b, dynType, c0.getResult()); + + // The verifier should catch `sizeValue <= 0`. + EXPECT_TRUE(verify(module).failed()); +} + +/// AllocOp where static result type dim ≠ constant size must fail. +TEST_F(QTensorTest, AllocOp_StaticTypeMismatchFailsVerification) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + auto c2 = arith::ConstantIndexOp::create(b, 2); // size operand = 2 + auto qubitType = qco::QubitType::get(context.get()); + // result type says dimension = 3, but size operand = 2 → mismatch + auto staticType = RankedTensorType::get({3}, qubitType); + AllocOp::create(b, staticType, c2.getResult()); + + EXPECT_TRUE(verify(module).failed()); +} + +/// AllocOp with a dynamic result type but a constant size operand is valid. +TEST_F(QTensorTest, AllocOp_DynamicTypeWithConstantSizeVerifies) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + b.setInsertionPointToStart(module.getBody()); + + auto c3 = arith::ConstantIndexOp::create(b, 3); + auto qubitType = qco::QubitType::get(context.get()); + auto dynType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + AllocOp::create(b, dynType, c3.getResult()); + + // Dynamic result dim with constant positive size → valid. + EXPECT_TRUE(verify(module).succeeded()); +} + +/// AllocOp with a static result type but a non-constant (dynamic) size +/// operand must fail verification. +TEST_F(QTensorTest, AllocOp_StaticTypeWithDynamicSizeOperandFailsVerification) { + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + ImplicitLocOpBuilder b(loc, context.get()); + // We need a block argument to act as a non-constant size. + auto* block = module.getBody(); + block->addArgument(IndexType::get(context.get()), loc); + Value dynSizeVal = block->getArgument(0); + + b.setInsertionPointToEnd(block); + auto qubitType = qco::QubitType::get(context.get()); + // Static result type dim = 3, but size operand is dynamic → error + auto staticType = RankedTensorType::get({3}, qubitType); + AllocOp::create(b, staticType, dynSizeVal); + + EXPECT_TRUE(verify(module).failed()); +} + +// ============================================================================ +// 3. DeallocOp — canonicalization (RemoveAllocDeallocPair) +// ============================================================================ + +/// An alloc immediately followed by dealloc should be eliminated entirely. +TEST_F(QTensorTest, DeallocOp_AllocDeallocPairIsRemoved) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto t = b.qtensorAlloc(3); + b.qtensorDealloc(t); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // Both AllocOp and DeallocOp should have been erased. + EXPECT_EQ(countOps(*canonicalized), 0U); + EXPECT_EQ(countOps(*canonicalized), 0U); +} + +/// A dealloc whose operand is not directly an AllocOp should not be removed. +TEST_F(QTensorTest, DeallocOp_DeallocOfNonAllocIsNotRemoved) { + // Extract then insert to create a different tensor SSA value before dealloc. + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + auto q1 = b.h(q0); + auto afterInsert = b.qtensorInsert(q1, outTensor, 0); + b.qtensorDealloc(afterInsert); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // After canonicalization the extract/insert pair simplifies, but there + // should still be either an alloc+dealloc pair or both get eliminated + // through further folding — just check the module is valid. + // The important invariant: DeallocOp count is not negative, i.e., the + // transform did not crash. +} + +// ============================================================================ +// 4. ExtractOp — verify(), fold, and canonicalization +// ============================================================================ + +/// A valid extract at index 0 from a size-1 tensor must pass verification. +TEST_F(QTensorTest, ExtractOp_ValidIndexVerifies) { + auto module = + QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { + auto t = b.qtensorAlloc(1); + b.qtensorExtract(t, 0); + }); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); +} + +/// An extract at a negative constant index must fail verification. +TEST_F(QTensorTest, ExtractOp_NegativeIndexFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + Value tensor = builder.qtensorAlloc(3); + auto negIdx = arith::ConstantIndexOp::create(builder, -1); + ExtractOp::create(builder, tensor, negIdx.getResult()); + auto module = builder.finalize(); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// An extract at an index equal to the tensor dimension must fail (out of +/// bounds). +TEST_F(QTensorTest, ExtractOp_IndexAtDimFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + Value tensor = builder.qtensorAlloc(3); + // index = 3, tensor has dim 3 → out of bounds + auto idx3 = arith::ConstantIndexOp::create(builder, 3); + ExtractOp::create(builder, tensor, idx3.getResult()); + auto module = builder.finalize(); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// An extract at an index one less than the dimension must pass. +TEST_F(QTensorTest, ExtractOp_IndexAtDimMinusOneVerifies) { + // Build inside a proper func.func body via QCOProgramBuilder. + QCOProgramBuilder builder(context.get()); + builder.initialize(); + // qtensorAlloc(3) creates tensor<3x!qco.qubit> and tracks it. + Value tensor = builder.qtensorAlloc(3); + auto idx2 = arith::ConstantIndexOp::create(builder, 2); + // Create extract at index 2 — last valid index for dim 3. + // (Use the raw op creator to bypass builder tracking.) + ExtractOp::create(builder, tensor, idx2.getResult()); + // finalize() will dealloc the still-tracked tensor (two uses of %tensor is + // valid in MLIR SSA). Dead extract results (%tOut, %q) are fine. + auto module = builder.finalize(); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); +} + +/// foldExtractAfterInsert: extract(insert(t, q, i), i) → (t, q) +/// The fold must eliminate the round-trip at the same index. +TEST_F(QTensorTest, ExtractOp_FoldExtractAfterInsertSameIndex) { + // Use the full QCO pipeline so that both the fold and subsequent DCE of the + // dead InsertOp run to convergence (single canonicalizer pass may leave + // unreachable Pure ops if DCE and folding don't interleave). + auto module = + QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + auto q1 = b.h(q0); + auto afterInsert = b.qtensorInsert(q1, outTensor, 0); + // Immediately extract the same qubit back — should fold away. + b.qtensorExtract(afterInsert, 0); + }); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); + runQCOCleanupPipeline(module.get()); + EXPECT_TRUE(verify(*module).succeeded()); + // The extra extract at the same index should fold away. + EXPECT_EQ(countOps(*module), 1U); // original extract +} + +/// foldExtractAfterInsert: extract at a different index must NOT fold. +TEST_F(QTensorTest, ExtractOp_NoFoldExtractAfterInsertDifferentIndex) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + auto q1 = b.h(q0); + auto afterInsert = b.qtensorInsert(q1, outTensor, 0); + // Extract at index 1 — different from the insert's index 0 + b.qtensorExtract(afterInsert, 1); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // The insert should still be present (not folded). + EXPECT_GE(countOps(*canonicalized), 1U); +} + +/// RemoveInsertExtractPair: extract through a disjoint InsertOp should find +/// the original extract and eliminate both. +TEST_F(QTensorTest, ExtractOp_RemoveInsertExtractPairThroughDisjointInsert) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + // Extract qubit 0. + auto [t1, q0] = b.qtensorExtract(tensor, 0); + // Extract qubit 1 (disjoint from qubit 0). + auto [t2, q1] = b.qtensorExtract(t1, 1); + // Insert qubit 1 back at index 1, then extract it again — same index. + // The canonicalizer should eliminate both the insert and the re-extract. + auto afterInsert = b.qtensorInsert(q1, t2, 1); + b.qtensorExtract(afterInsert, 1); + // Use q0 so it isn't dead. + b.h(q0); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); +} + +/// RemoveInsertExtractPair: a nested ExtractOp at the same index must block +/// re-ordering (linearity guard). +TEST_F(QTensorTest, + ExtractOp_RemoveInsertExtractPairBlockedByNestedExtractAtSameIndex) { + // Pattern: insert q0 at 0, then extract-at-0 twice (would violate linearity + // if the first extraction were skipped). + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [t1, q0] = b.qtensorExtract(tensor, 0); + // Re-insert q0 at index 0, producing a new tensor. + auto q0h = b.h(q0); + auto afterInsert = b.qtensorInsert(q0h, t1, 0); + // Attempt to extract index 0 again — the chain already has an + // extract-at-0 in it, blocking the RemoveInsertExtractPair pattern. + auto [t3, q0again] = b.qtensorExtract(afterInsert, 0); + b.h(q0again); + (void)t3; + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); +} + +// ============================================================================ +// 5. InsertOp — verify(), fold, and canonicalization +// ============================================================================ + +/// A valid insert at index 0 into a size-3 tensor must pass verification. +TEST_F(QTensorTest, InsertOp_ValidIndexVerifies) { + auto module = + QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { + auto t = b.qtensorAlloc(3); + auto [out, q] = b.qtensorExtract(t, 0); + b.qtensorInsert(q, out, 0); + }); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); +} + +/// An insert at a negative constant index must fail verification. +TEST_F(QTensorTest, InsertOp_NegativeIndexFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + // Extract qubit 0 to get both a tracked tensor and a qubit. + auto tensor = builder.qtensorAlloc(3); + auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); + // Insert at index -1 — raw op creation bypasses builder tracking. + auto negIdx = arith::ConstantIndexOp::create(builder, -1); + InsertOp::create(builder, q0, outTensor, negIdx.getResult()); + // finalize() will dealloc outTensor and sink q0 (both still tracked, both + // reused — valid in SSA). + auto module = builder.finalize(); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// An insert at an index equal to the destination dimension must fail. +TEST_F(QTensorTest, InsertOp_IndexAtDimFailsVerification) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto tensor = builder.qtensorAlloc(3); + auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); + auto idx3 = arith::ConstantIndexOp::create(builder, 3); // == dim + InsertOp::create(builder, q0, outTensor, idx3.getResult()); + auto module = builder.finalize(); + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).failed()); +} + +/// foldInsertAfterExtract: insert(extract(t, i).qubit, extract(t, i).out, i) +/// should fold to `t`. +TEST_F(QTensorTest, InsertOp_FoldInsertAfterExtractSameIndex) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + // Insert the extracted qubit back at the same index without modification. + b.qtensorInsert(q0, outTensor, 0); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // The extract-insert pair should have been eliminated entirely. + EXPECT_EQ(countOps(*canonicalized), 0U); + EXPECT_EQ(countOps(*canonicalized), 0U); +} + +/// foldInsertAfterExtract: inserting the qubit at a different index must NOT +/// fold. +TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentIndex) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + // Insert at index 1 instead of 0 + b.qtensorInsert(q0, outTensor, 1); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + EXPECT_GE(countOps(*canonicalized), 1U); +} + +/// foldInsertAfterExtract: inserting into a different tensor (not the extract's +/// out_tensor) must NOT fold. +TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentDest) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto t1 = b.qtensorAlloc(3); + auto t2 = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(t1, 0); + // q0 came from t1, but we insert into t2's out-tensor + auto [t2out, q1] = b.qtensorExtract(t2, 1); + b.qtensorInsert(q0, t2out, 0); + b.h(q1); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + EXPECT_GE(countOps(*canonicalized), 1U); +} + +/// RemoveExtractInsertPair: an insert-after-extract that has been modified +/// (qubit passed through an H gate) must NOT be eliminated. +TEST_F(QTensorTest, InsertOp_NoRemoveExtractInsertPairAfterMutation) { + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + auto q1 = b.h(q0); // mutation — scalar ≠ extract.getResult() + b.qtensorInsert(q1, outTensor, 0); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // The HOp mutates the qubit, so the pair cannot be collapsed. + EXPECT_GE(countOps(*canonicalized), 1U); +} + +/// RemoveExtractInsertPair: insert shadowed by an earlier InsertOp at the same +/// index must not be eliminated. +TEST_F(QTensorTest, InsertOp_RemoveExtractInsertPairBlockedByShadowingInsert) { + // Pattern: + // t1, q0 = extract(alloc, 0) + // t2 = insert(q0, t1, 0) ← overwrites index 0 + // t3 = insert(q0_h, t2, 0) ← another write to index 0 + // Trying to find the matching extract for the second insert should be + // blocked by the first insert at the same index. + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [t1, q0] = b.qtensorExtract(tensor, 0); + // First insert q0 at 0. + auto t2 = b.qtensorInsert(q0, t1, 0); + // Second insert at 0 (different qubit — from another extract). + auto [t2out, q1] = b.qtensorExtract(t2, 0); + auto q1h = b.h(q1); + b.qtensorInsert(q1h, t2out, 0); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); +} + +/// RemoveExtractInsertPair: insert blocked by a disjoint insert (different +/// index) should still succeed in finding the original extract. +TEST_F(QTensorTest, InsertOp_RemoveExtractInsertPairThroughDisjointInsert) { + // Pattern: + // t1, q0 = extract(alloc, 0) + // t2, q1 = extract(t1, 1) ← disjoint from index 0 + // t3 = insert(q1, t2, 1) ← insert at index 1 (disjoint) + // t4 = insert(q0, t3, 0) ← insert matches the extract at 0 + // Both the q0 extract-insert and q1 extract-insert should collapse. + auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { + auto tensor = b.qtensorAlloc(3); + auto [t1, q0] = b.qtensorExtract(tensor, 0); + auto [t2, q1] = b.qtensorExtract(t1, 1); + auto t3 = b.qtensorInsert(q1, t2, 1); + b.qtensorInsert(q0, t3, 0); + }); + ASSERT_TRUE(canonicalized); + EXPECT_TRUE(verify(*canonicalized).succeeded()); + // Both pairs should collapse. + EXPECT_EQ(countOps(*canonicalized), 0U); + EXPECT_EQ(countOps(*canonicalized), 0U); +} + +// ============================================================================ +// 6. Integration +// +// These tests use the full QCO cleanup pipeline and compare canonicalized +// modules for structural equivalence with permutations. +// ============================================================================ + +struct QTensorIntegrationTestCase { + std::string name; + mqt::test::NamedBuilder programBuilder; + mqt::test::NamedBuilder referenceBuilder; + + friend std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info) { + return os << "QTensor{" << info.name << "}"; + } +}; + +class QTensorIntegrationTest + : public testing::TestWithParam { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } +}; + +TEST_P(QTensorIntegrationTest, ProgramEquivalence) { + const auto& [_, programBuilder, referenceBuilder] = GetParam(); + const auto name = " (" + GetParam().name + ")"; + mqt::test::DeferredPrinter printer; + + auto program = QCOProgramBuilder::build(context.get(), programBuilder.fn); + ASSERT_TRUE(program); + printer.record(program.get(), "Original QTensor IR" + name); + EXPECT_TRUE(verify(*program).succeeded()); + + runQCOCleanupPipeline(program.get()); + printer.record(program.get(), "Canonicalized QTensor IR" + name); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = QCOProgramBuilder::build(context.get(), referenceBuilder.fn); + ASSERT_TRUE(reference); + printer.record(reference.get(), "Reference QTensor IR" + name); + EXPECT_TRUE(verify(*reference).succeeded()); + + runQCOCleanupPipeline(reference.get()); + printer.record(reference.get(), "Canonicalized Reference QTensor IR" + name); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +/// @name QTensor/QTensor.cpp (relocated from QCO test suite) +/// @{ +INSTANTIATE_TEST_SUITE_P( + QTensorOpsTest, QTensorIntegrationTest, + testing::Values( + QTensorIntegrationTestCase{"QTensorAlloc", + MQT_NAMED_BUILDER(qtensorAlloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorAllocDealloc", + MQT_NAMED_BUILDER(qtensorDealloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorFromElements", + MQT_NAMED_BUILDER(qtensorFromElements), + MQT_NAMED_BUILDER(qtensorFromElements)}, + QTensorIntegrationTestCase{"QTensorExtract", + MQT_NAMED_BUILDER(qtensorExtract), + MQT_NAMED_BUILDER(qtensorExtract)}, + QTensorIntegrationTestCase{"QTensorInsert", + MQT_NAMED_BUILDER(qtensorInsert), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertSameIndex", + MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertIndexMismatch", + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractSameIndex", + MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractIndexMismatch", + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)})); +/// @} +} // namespace + +// ============================================================================ +// 7. Integration — multi-qubit permutation equivalence tests +// ============================================================================ + +static OwningOpRef +buildTwoQubitInsertChainProgram(MLIRContext* context, + const bool reverseInsertOrder, + const bool swapInsertTargets) { + QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterFirstExtract, qubit0] = builder.qtensorExtract(tensor, 0); + auto [baseTensor, qubit1] = + builder.qtensorExtract(tensorAfterFirstExtract, 1); + + const int64_t qubit0Target = swapInsertTargets ? 1 : 0; + const int64_t qubit1Target = swapInsertTargets ? 0 : 1; + + Value currentTensor = baseTensor; + if (reverseInsertOrder) { + currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); + currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); + } else { + currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); + currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); + } + + builder.qtensorDealloc(currentTensor); + return builder.finalize(); +} + +static OwningOpRef +buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, + const bool swapInsertTargets) { + QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(3); + Value tensorAfterReads = tensor; + Value qubit0 = nullptr; + Value qubit1 = nullptr; + + if (reverseOrder) { + std::tie(tensorAfterReads, qubit1) = + builder.qtensorExtract(tensorAfterReads, 1); + std::tie(tensorAfterReads, qubit0) = + builder.qtensorExtract(tensorAfterReads, 0); + } else { + std::tie(tensorAfterReads, qubit0) = + builder.qtensorExtract(tensorAfterReads, 0); + std::tie(tensorAfterReads, qubit1) = + builder.qtensorExtract(tensorAfterReads, 1); + } + + const int64_t q0Target = 0; + const int64_t q1Target = swapInsertTargets ? 2 : 1; + + Value tensorAfterWrites = tensorAfterReads; + if (reverseOrder) { + tensorAfterWrites = + builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); + tensorAfterWrites = + builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); + } else { + tensorAfterWrites = + builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); + tensorAfterWrites = + builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); + } + + builder.qtensorDealloc(tensorAfterWrites); + return builder.finalize(); +} + +static OwningOpRef +buildResetWithCommutingInsertProgram(MLIRContext* context, + const bool withReset) { + QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); + auto tensorAfterInsert0 = + builder.qtensorInsert(qubit0, tensorAfterExtract0, 0); + auto [tensorAfterExtract1, qubit1] = + builder.qtensorExtract(tensorAfterInsert0, 1); + if (withReset) { + qubit1 = builder.reset(qubit1); + } + auto tensorFinal = builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); + builder.qtensorDealloc(tensorFinal); + + return builder.finalize(); +} + +static OwningOpRef +buildResetWithSameIndexInsertProgram(MLIRContext* context, + const bool withReset) { + QCOProgramBuilder builder(context); + builder.initialize(); + + auto tensor = builder.qtensorAlloc(2); + auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); + auto [tensorAfterExtract1, qubit1] = + builder.qtensorExtract(tensorAfterExtract0, 1); + qubit1 = builder.h(qubit1); + auto tensorAfterInsert1 = + builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); + auto [tensorAfterReadBack1, qubit1ReadBack] = + builder.qtensorExtract(tensorAfterInsert1, 1); + if (withReset) { + qubit1ReadBack = builder.reset(qubit1ReadBack); + } + auto tensorAfterInsert1ReadBack = + builder.qtensorInsert(qubit1ReadBack, tensorAfterReadBack1, 1); + auto tensorFinal = + builder.qtensorInsert(qubit0, tensorAfterInsert1ReadBack, 0); + builder.qtensorDealloc(tensorFinal); + + return builder.finalize(); +} + +namespace { +TEST_F(QTensorTest, InsertChainPermutationEquivalence) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); + ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { + auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); + ASSERT_TRUE(program); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); + ASSERT_TRUE(reference); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, MixedExtractInsertPermutationEquivalence) { + auto program = buildMixedExtractInsertProgram(context.get(), false, false); + ASSERT_TRUE(program); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildMixedExtractInsertProgram(context.get(), true, false); + ASSERT_TRUE(reference); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { + auto program = buildMixedExtractInsertProgram(context.get(), false, false); + ASSERT_TRUE(program); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildMixedExtractInsertProgram(context.get(), true, true); + ASSERT_TRUE(reference); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { + auto program = buildResetWithCommutingInsertProgram(context.get(), true); + ASSERT_TRUE(program); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithCommutingInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +TEST_F(QTensorTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { + auto program = buildResetWithSameIndexInsertProgram(context.get(), true); + ASSERT_TRUE(program); + runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(verify(*program).succeeded()); + + auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); + ASSERT_TRUE(reference); + runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(verify(*reference).succeeded()); + + EXPECT_FALSE( + areModulesEquivalentWithPermutations(program.get(), reference.get())); +} + +} // namespace From b5cfbbcd776010b00b3f5d56d22d04f01c9bc38f Mon Sep 17 00:00:00 2001 From: burgholzer Date: Mon, 6 Apr 2026 20:58:01 +0200 Subject: [PATCH 51/71] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Do=20not=20default?= =?UTF-8?q?=20to=20dynamic=20shape=20for=20QTensor=20conversion=20from=20Q?= =?UTF-8?q?CO=20to=20QC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 421d7942b2..a9f83c3e88 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -100,6 +100,10 @@ namespace { * The primary conversion is from !qco.qubit to !qc.qubit, which * represents the semantic shift from value types to reference types. * + * Qubit tensor types preserve their shape during conversion: a statically + * shaped `tensor` becomes `memref`, while a + * dynamically shaped `tensor` becomes `memref`. + * * Other types (integers, booleans, etc.) pass through unchanged via * the identity conversion. */ @@ -116,7 +120,7 @@ class QCOToQCTypeConverter final : public TypeConverter { addConversion([ctx](RankedTensorType type) -> Type { if (llvm::isa(type.getElementType())) { - return MemRefType::get({ShapedType::kDynamic}, qc::QubitType::get(ctx)); + return MemRefType::get(type.getShape(), qc::QubitType::get(ctx)); } return type; }); @@ -142,8 +146,17 @@ struct ConvertQTensorAllocOp final : OpConversionPattern { matchAndRewrite(qtensor::AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto qubitType = qc::QubitType::get(op.getContext()); - auto memrefType = mlir::MemRefType::get({ShapedType::kDynamic}, qubitType); - rewriter.replaceOpWithNewOp(op, memrefType, op.getSize()); + auto tensorType = llvm::cast(op.getResult().getType()); + auto memrefType = MemRefType::get(tensorType.getShape(), qubitType); + + if (tensorType.hasStaticShape()) { + // Static size: no dynamic size operand needed + rewriter.replaceOpWithNewOp(op, memrefType); + } else { + // Dynamic size: forward the runtime size operand + rewriter.replaceOpWithNewOp(op, memrefType, + op.getSize()); + } return success(); } }; From dcea94423f6b36fbd91cc3cd0ac89fd347a35ff7 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Mon, 6 Apr 2026 21:07:51 +0200 Subject: [PATCH 52/71] =?UTF-8?q?=F0=9F=8E=A8=20Newline=20cosmetic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp index 2298f6ab47..1c958b346b 100644 --- a/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp +++ b/mlir/lib/Conversion/JeffToQCO/JeffToQCO.cpp @@ -458,10 +458,8 @@ struct ConvertJeffQuregInsertIndexOpToQCO final ConversionPatternRewriter& rewriter) const override { auto index = arith::IndexCastOp::create( rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getIndex()); - rewriter.replaceOpWithNewOp(op, adaptor.getInQubit(), - adaptor.getInQreg(), - - index.getResult()); + rewriter.replaceOpWithNewOp( + op, adaptor.getInQubit(), adaptor.getInQreg(), index.getResult()); return success(); } }; From 7409023a8f0d521e456ca40654019644ba8f5f4a Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 6 Apr 2026 23:13:04 +0200 Subject: [PATCH 53/71] Fix linter errors --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 1 - .../Dialect/QC/Builder/QCProgramBuilder.cpp | 1 - .../QC/Transforms/ShrinkQubitRegisters.cpp | 8 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 4 +- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 1 + .../QCO/IR/Operations/StandardGates/RXXOp.cpp | 1 + .../QCO/IR/Operations/StandardGates/RYYOp.cpp | 1 + .../QCO/IR/Operations/StandardGates/RZOp.cpp | 1 + .../QCO/IR/Operations/StandardGates/RZXOp.cpp | 1 + .../QCO/IR/Operations/StandardGates/RZZOp.cpp | 2 + .../Operations/StandardGates/XXMinusYYOp.cpp | 2 +- .../Operations/StandardGates/XXPlusYYOp.cpp | 2 +- .../lib/Dialect/QIR/Transforms/QIRCleanup.cpp | 7 ++ .../QTensor/IR/Operations/ExtractOp.cpp | 1 + .../QTensor/Transforms/ShrinkRegisters.cpp | 25 ++++-- mlir/lib/Support/IRVerification.cpp | 4 +- mlir/lib/Support/Passes.cpp | 1 + mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 2 - .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 76 ++++++++++--------- 19 files changed, 87 insertions(+), 54 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index a7470901d2..b18306cd9d 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -642,7 +642,6 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { matchAndRewrite(qc::StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto* operation = op.getOperation(); auto qcQubit = op.getQubit(); auto qcoOp = rewriter.replaceOpWithNewOp(op, op.getIndex()); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index b6c503f7eb..914b4ee29f 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" -#include #include #include #include diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp index fc8e9868c8..963bd82f82 100644 --- a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -8,19 +8,23 @@ * Licensed under the MIT License */ -#include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/Transforms/Passes.h" #include #include #include +#include #include #include #include -#include +#include #include +#include #include +#include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 4317a73d9a..c17da088d7 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -154,7 +154,7 @@ void QCOProgramBuilder::updateQubitTracking(Value inputQubit, validQubits.erase(it); // Add the output (new) value to tracking - validQubits.try_emplace(outputQubit, std::move(info)); + validQubits.try_emplace(outputQubit, info); } void QCOProgramBuilder::validateTensorValue(Value tensor) const { @@ -187,7 +187,7 @@ void QCOProgramBuilder::updateTensorTracking(Value inputTensor, validTensors.erase(it); // Add the output (new) value to tracking - validTensors.try_emplace(outputTensor, std::move(info)); + validTensors.try_emplace(outputTensor, info); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 3e29c68472..e8f57cd859 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp index 58b2cf1c7e..ca26cf0c78 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp index 18d069263a..6b9d540062 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp index bd789e770f..cad399846c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp index 70bbb999fb..be40c3de9a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp index 2132bfdd1f..5fb05c3379 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp @@ -17,8 +17,10 @@ #include #include #include +#include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp index 6e40c801c3..79dc169ae1 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -9,7 +9,6 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" -#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/Utils/Utils.h" #include @@ -20,6 +19,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp index 010913f98a..03d8853371 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -9,7 +9,6 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" -#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/Utils/Utils.h" #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp index 7daabd2cf4..c55efb3414 100644 --- a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp +++ b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp @@ -11,15 +11,22 @@ #include "mlir/Dialect/QIR/Transforms/Passes.h" #include "mlir/Dialect/QIR/Utils/QIRUtils.h" +#include #include +#include #include +#include #include #include #include #include #include +#include +#include #include +#include + namespace mlir::qir { #define GEN_PASS_DEF_QIRCLEANUPPASS diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 6313ec142b..3f0fa9880e 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp index 499c57abc4..d4bfcf3e5a 100644 --- a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp +++ b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp @@ -14,13 +14,18 @@ #include #include #include +#include #include #include -#include +#include #include +#include +#include #include #include +#include +#include #include namespace mlir::qtensor { @@ -41,7 +46,7 @@ namespace mlir::qtensor { */ [[nodiscard]] static LogicalResult markLiveIndex(const int64_t index, llvm::BitVector& liveIndices) { - if (index < 0 || index >= static_cast(liveIndices.size())) { + if (index < 0 || std::cmp_greater_equal(index, liveIndices.size())) { return failure(); } liveIndices.set(static_cast(index)); @@ -86,7 +91,7 @@ namespace mlir::qtensor { Value tensor = allocOp.getResult(); while (true) { auto* user = getLinearTensorUser(tensor); - if (!user) { + if (user == nullptr) { return failure(); } @@ -126,6 +131,8 @@ namespace mlir::qtensor { } } +namespace { + /** * @brief Shrink static qtensors by removing never-accessed indices. * @details QTensor is linear, so this rewrite follows a single use-def chain. @@ -173,7 +180,7 @@ struct ShrinkStaticQTensor final : OpRewritePattern { Value currentTensor = newAlloc.getResult(); while (true) { Operation* currentOp = getLinearTensorUser(oldTensor); - if (!currentOp) { + if (currentOp == nullptr) { return failure(); } @@ -193,7 +200,7 @@ struct ShrinkStaticQTensor final : OpRewritePattern { } const auto oldIndex = *getConstantIntValue(extractOp.getIndex()); if (oldIndex < 0 || - oldIndex >= static_cast(newIndexByOldIndex.size())) { + std::cmp_greater_equal(oldIndex, newIndexByOldIndex.size())) { return failure(); } const auto mappedIndex = @@ -203,7 +210,7 @@ struct ShrinkStaticQTensor final : OpRewritePattern { } Value oldOutTensor = extractOp.getOutTensor(); Operation* nextOp = getLinearTensorUser(oldOutTensor); - if (!nextOp) { + if (nextOp == nullptr) { return failure(); } @@ -229,7 +236,7 @@ struct ShrinkStaticQTensor final : OpRewritePattern { } const auto oldIndex = *getConstantIntValue(insertOp.getIndex()); if (oldIndex < 0 || - oldIndex >= static_cast(newIndexByOldIndex.size())) { + std::cmp_greater_equal(oldIndex, newIndexByOldIndex.size())) { return failure(); } const auto mappedIndex = @@ -239,7 +246,7 @@ struct ShrinkStaticQTensor final : OpRewritePattern { } Value oldResultTensor = insertOp.getResult(); Operation* nextOp = getLinearTensorUser(oldResultTensor); - if (!nextOp) { + if (nextOp == nullptr) { return failure(); } @@ -279,4 +286,6 @@ struct ShrinkQTensorToFitPass final } }; +} // namespace + } // namespace mlir::qtensor diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 3ade4cd5ce..c2d0f08f02 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -215,8 +215,8 @@ summarizeInsertGroup(llvm::ArrayRef ops, } auto& chain = chains[chainIdx]; - chain.writes.push_back( - InsertWrite{insertOp.getScalar(), insertOp.getIndex()}); + chain.writes.push_back(InsertWrite{.scalar = insertOp.getScalar(), + .index = insertOp.getIndex()}); if (!consumedInsertResults.contains(insertOp.getResult())) { if (chain.finalTensor) { diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index 254159d76b..dd29c96c1b 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/QTensor/Transforms/Passes.h" #include +#include #include #include #include diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index f7b3235451..4b62f96449 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -26,12 +26,10 @@ #include #include -#include #include #include #include #include -#include using namespace mlir; using namespace mlir::qco; diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 3708d6517b..53f8d76cd9 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -24,17 +24,20 @@ #include "qco_programs.h" #include -#include #include #include +#include #include #include #include +#include #include +#include #include #include #include +#include #include #include #include @@ -97,14 +100,14 @@ class QTensorTest : public ::testing::Test { // 1. QTensorUtils — direct tests of scalar chain helpers // ============================================================================ -TEST_F(QTensorTest, AreEquivalentIndices_SameValueIsEquivalent) { +TEST_F(QTensorTest, AreEquivalentIndicesSameValueIsEquivalent) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto c2 = arith::ConstantIndexOp::create(builder, 2); EXPECT_TRUE(areEquivalentIndices(c2.getResult(), c2.getResult())); } -TEST_F(QTensorTest, AreEquivalentIndices_DifferentConstantsAreNotEquivalent) { +TEST_F(QTensorTest, AreEquivalentIndicesDifferentConstantsAreNotEquivalent) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto c0 = arith::ConstantIndexOp::create(builder, 0); @@ -112,13 +115,13 @@ TEST_F(QTensorTest, AreEquivalentIndices_DifferentConstantsAreNotEquivalent) { EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); } -TEST_F(QTensorTest, TensorChainHelpers_InsertAndExtractAreRecognized) { +TEST_F(QTensorTest, TensorChainHelpersInsertAndExtractAreRecognized) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto tensor = builder.qtensorAlloc(3); auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); - auto insert = builder.qtensorInsert(q0, outTensor, 0).getDefiningOp(); - auto extract = outTensor.getDefiningOp(); + auto* insert = builder.qtensorInsert(q0, outTensor, 0).getDefiningOp(); + auto* extract = outTensor.getDefiningOp(); ASSERT_NE(insert, nullptr); ASSERT_NE(extract, nullptr); @@ -128,7 +131,7 @@ TEST_F(QTensorTest, TensorChainHelpers_InsertAndExtractAreRecognized) { EXPECT_EQ(getTensorChainInput(extract), tensor); } -TEST_F(QTensorTest, TensorChainHelpers_SetTensorChainInputRewiresOperand) { +TEST_F(QTensorTest, TensorChainHelpersSetTensorChainInputRewiresOperand) { auto module = QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { auto t1 = b.qtensorAlloc(3); @@ -152,7 +155,7 @@ TEST_F(QTensorTest, TensorChainHelpers_SetTensorChainInputRewiresOperand) { // ============================================================================ /// A valid static alloc should pass verification. -TEST_F(QTensorTest, AllocOp_ValidStaticAllocVerifies) { +TEST_F(QTensorTest, AllocOpValidStaticAllocVerifies) { auto module = QCOProgramBuilder::build( context.get(), [](QCOProgramBuilder& b) { b.qtensorAlloc(3); }); ASSERT_TRUE(module); @@ -162,7 +165,7 @@ TEST_F(QTensorTest, AllocOp_ValidStaticAllocVerifies) { /// AllocOp with a constant size ≤ 0 must fail verification. /// Note: The builder asserts on zero/negative, so we verify the verifier /// by constructing the op manually bypassing the builder assertion. -TEST_F(QTensorTest, AllocOp_ZeroSizeFailsVerification) { +TEST_F(QTensorTest, AllocOpZeroSizeFailsVerification) { // Build a module manually to bypass builder-level assertion. auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); @@ -184,7 +187,7 @@ TEST_F(QTensorTest, AllocOp_ZeroSizeFailsVerification) { } /// AllocOp where static result type dim ≠ constant size must fail. -TEST_F(QTensorTest, AllocOp_StaticTypeMismatchFailsVerification) { +TEST_F(QTensorTest, AllocOpStaticTypeMismatchFailsVerification) { auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); ImplicitLocOpBuilder b(loc, context.get()); @@ -200,7 +203,7 @@ TEST_F(QTensorTest, AllocOp_StaticTypeMismatchFailsVerification) { } /// AllocOp with a dynamic result type but a constant size operand is valid. -TEST_F(QTensorTest, AllocOp_DynamicTypeWithConstantSizeVerifies) { +TEST_F(QTensorTest, AllocOpDynamicTypeWithConstantSizeVerifies) { auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); ImplicitLocOpBuilder b(loc, context.get()); @@ -217,7 +220,7 @@ TEST_F(QTensorTest, AllocOp_DynamicTypeWithConstantSizeVerifies) { /// AllocOp with a static result type but a non-constant (dynamic) size /// operand must fail verification. -TEST_F(QTensorTest, AllocOp_StaticTypeWithDynamicSizeOperandFailsVerification) { +TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); ImplicitLocOpBuilder b(loc, context.get()); @@ -240,7 +243,7 @@ TEST_F(QTensorTest, AllocOp_StaticTypeWithDynamicSizeOperandFailsVerification) { // ============================================================================ /// An alloc immediately followed by dealloc should be eliminated entirely. -TEST_F(QTensorTest, DeallocOp_AllocDeallocPairIsRemoved) { +TEST_F(QTensorTest, DeallocOpAllocDeallocPairIsRemoved) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto t = b.qtensorAlloc(3); b.qtensorDealloc(t); @@ -253,7 +256,7 @@ TEST_F(QTensorTest, DeallocOp_AllocDeallocPairIsRemoved) { } /// A dealloc whose operand is not directly an AllocOp should not be removed. -TEST_F(QTensorTest, DeallocOp_DeallocOfNonAllocIsNotRemoved) { +TEST_F(QTensorTest, DeallocOpDeallocOfNonAllocIsNotRemoved) { // Extract then insert to create a different tensor SSA value before dealloc. auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); @@ -276,7 +279,7 @@ TEST_F(QTensorTest, DeallocOp_DeallocOfNonAllocIsNotRemoved) { // ============================================================================ /// A valid extract at index 0 from a size-1 tensor must pass verification. -TEST_F(QTensorTest, ExtractOp_ValidIndexVerifies) { +TEST_F(QTensorTest, ExtractOpValidIndexVerifies) { auto module = QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { auto t = b.qtensorAlloc(1); @@ -287,7 +290,7 @@ TEST_F(QTensorTest, ExtractOp_ValidIndexVerifies) { } /// An extract at a negative constant index must fail verification. -TEST_F(QTensorTest, ExtractOp_NegativeIndexFailsVerification) { +TEST_F(QTensorTest, ExtractOpNegativeIndexFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); Value tensor = builder.qtensorAlloc(3); @@ -300,7 +303,7 @@ TEST_F(QTensorTest, ExtractOp_NegativeIndexFailsVerification) { /// An extract at an index equal to the tensor dimension must fail (out of /// bounds). -TEST_F(QTensorTest, ExtractOp_IndexAtDimFailsVerification) { +TEST_F(QTensorTest, ExtractOpIndexAtDimFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); Value tensor = builder.qtensorAlloc(3); @@ -313,7 +316,7 @@ TEST_F(QTensorTest, ExtractOp_IndexAtDimFailsVerification) { } /// An extract at an index one less than the dimension must pass. -TEST_F(QTensorTest, ExtractOp_IndexAtDimMinusOneVerifies) { +TEST_F(QTensorTest, ExtractOpIndexAtDimMinusOneVerifies) { // Build inside a proper func.func body via QCOProgramBuilder. QCOProgramBuilder builder(context.get()); builder.initialize(); @@ -332,7 +335,7 @@ TEST_F(QTensorTest, ExtractOp_IndexAtDimMinusOneVerifies) { /// foldExtractAfterInsert: extract(insert(t, q, i), i) → (t, q) /// The fold must eliminate the round-trip at the same index. -TEST_F(QTensorTest, ExtractOp_FoldExtractAfterInsertSameIndex) { +TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertSameIndex) { // Use the full QCO pipeline so that both the fold and subsequent DCE of the // dead InsertOp run to convergence (single canonicalizer pass may leave // unreachable Pure ops if DCE and folding don't interleave). @@ -354,7 +357,7 @@ TEST_F(QTensorTest, ExtractOp_FoldExtractAfterInsertSameIndex) { } /// foldExtractAfterInsert: extract at a different index must NOT fold. -TEST_F(QTensorTest, ExtractOp_NoFoldExtractAfterInsertDifferentIndex) { +TEST_F(QTensorTest, ExtractOpNoFoldExtractAfterInsertDifferentIndex) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); auto [outTensor, q0] = b.qtensorExtract(tensor, 0); @@ -371,7 +374,7 @@ TEST_F(QTensorTest, ExtractOp_NoFoldExtractAfterInsertDifferentIndex) { /// RemoveInsertExtractPair: extract through a disjoint InsertOp should find /// the original extract and eliminate both. -TEST_F(QTensorTest, ExtractOp_RemoveInsertExtractPairThroughDisjointInsert) { +TEST_F(QTensorTest, ExtractOpRemoveInsertExtractPairThroughDisjointInsert) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); // Extract qubit 0. @@ -392,7 +395,7 @@ TEST_F(QTensorTest, ExtractOp_RemoveInsertExtractPairThroughDisjointInsert) { /// RemoveInsertExtractPair: a nested ExtractOp at the same index must block /// re-ordering (linearity guard). TEST_F(QTensorTest, - ExtractOp_RemoveInsertExtractPairBlockedByNestedExtractAtSameIndex) { + ExtractOpRemoveInsertExtractPairBlockedByNestedExtractAtSameIndex) { // Pattern: insert q0 at 0, then extract-at-0 twice (would violate linearity // if the first extraction were skipped). auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { @@ -416,7 +419,7 @@ TEST_F(QTensorTest, // ============================================================================ /// A valid insert at index 0 into a size-3 tensor must pass verification. -TEST_F(QTensorTest, InsertOp_ValidIndexVerifies) { +TEST_F(QTensorTest, InsertOpValidIndexVerifies) { auto module = QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { auto t = b.qtensorAlloc(3); @@ -428,7 +431,7 @@ TEST_F(QTensorTest, InsertOp_ValidIndexVerifies) { } /// An insert at a negative constant index must fail verification. -TEST_F(QTensorTest, InsertOp_NegativeIndexFailsVerification) { +TEST_F(QTensorTest, InsertOpNegativeIndexFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); // Extract qubit 0 to get both a tracked tensor and a qubit. @@ -445,7 +448,7 @@ TEST_F(QTensorTest, InsertOp_NegativeIndexFailsVerification) { } /// An insert at an index equal to the destination dimension must fail. -TEST_F(QTensorTest, InsertOp_IndexAtDimFailsVerification) { +TEST_F(QTensorTest, InsertOpIndexAtDimFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto tensor = builder.qtensorAlloc(3); @@ -459,7 +462,7 @@ TEST_F(QTensorTest, InsertOp_IndexAtDimFailsVerification) { /// foldInsertAfterExtract: insert(extract(t, i).qubit, extract(t, i).out, i) /// should fold to `t`. -TEST_F(QTensorTest, InsertOp_FoldInsertAfterExtractSameIndex) { +TEST_F(QTensorTest, InsertOpFoldInsertAfterExtractSameIndex) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); auto [outTensor, q0] = b.qtensorExtract(tensor, 0); @@ -475,7 +478,7 @@ TEST_F(QTensorTest, InsertOp_FoldInsertAfterExtractSameIndex) { /// foldInsertAfterExtract: inserting the qubit at a different index must NOT /// fold. -TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentIndex) { +TEST_F(QTensorTest, InsertOpNoFoldInsertAfterExtractDifferentIndex) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); auto [outTensor, q0] = b.qtensorExtract(tensor, 0); @@ -489,7 +492,7 @@ TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentIndex) { /// foldInsertAfterExtract: inserting into a different tensor (not the extract's /// out_tensor) must NOT fold. -TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentDest) { +TEST_F(QTensorTest, InsertOpNoFoldInsertAfterExtractDifferentDest) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto t1 = b.qtensorAlloc(3); auto t2 = b.qtensorAlloc(3); @@ -506,7 +509,7 @@ TEST_F(QTensorTest, InsertOp_NoFoldInsertAfterExtractDifferentDest) { /// RemoveExtractInsertPair: an insert-after-extract that has been modified /// (qubit passed through an H gate) must NOT be eliminated. -TEST_F(QTensorTest, InsertOp_NoRemoveExtractInsertPairAfterMutation) { +TEST_F(QTensorTest, InsertOpNoRemoveExtractInsertPairAfterMutation) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { auto tensor = b.qtensorAlloc(3); auto [outTensor, q0] = b.qtensorExtract(tensor, 0); @@ -521,7 +524,7 @@ TEST_F(QTensorTest, InsertOp_NoRemoveExtractInsertPairAfterMutation) { /// RemoveExtractInsertPair: insert shadowed by an earlier InsertOp at the same /// index must not be eliminated. -TEST_F(QTensorTest, InsertOp_RemoveExtractInsertPairBlockedByShadowingInsert) { +TEST_F(QTensorTest, InsertOpRemoveExtractInsertPairBlockedByShadowingInsert) { // Pattern: // t1, q0 = extract(alloc, 0) // t2 = insert(q0, t1, 0) ← overwrites index 0 @@ -544,7 +547,7 @@ TEST_F(QTensorTest, InsertOp_RemoveExtractInsertPairBlockedByShadowingInsert) { /// RemoveExtractInsertPair: insert blocked by a disjoint insert (different /// index) should still succeed in finding the original extract. -TEST_F(QTensorTest, InsertOp_RemoveExtractInsertPairThroughDisjointInsert) { +TEST_F(QTensorTest, InsertOpRemoveExtractInsertPairThroughDisjointInsert) { // Pattern: // t1, q0 = extract(alloc, 0) // t2, q1 = extract(t1, 1) ← disjoint from index 0 @@ -577,12 +580,17 @@ struct QTensorIntegrationTestCase { mqt::test::NamedBuilder programBuilder; mqt::test::NamedBuilder referenceBuilder; + // NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) friend std::ostream& operator<<(std::ostream& os, - const QTensorIntegrationTestCase& info) { - return os << "QTensor{" << info.name << "}"; - } + const QTensorIntegrationTestCase& info); }; +// NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) +std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info) { + return os << "QTensor{" << info.name << "}"; +} + class QTensorIntegrationTest : public testing::TestWithParam { protected: From c808bf1d67c89152a6ea99f615bdee08912c04df Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:01:18 +0200 Subject: [PATCH 54/71] Address some of the Rabbit's comments --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 22 ++++++++++++++----- .../QC/Transforms/ShrinkQubitRegisters.cpp | 2 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 12 +++++----- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 22 +++++++++---------- .../lib/Dialect/QIR/Transforms/QIRCleanup.cpp | 13 ++++++++--- .../QTensor/IR/Operations/DeallocOp.cpp | 2 +- .../QTensor/IR/Operations/ExtractOp.cpp | 7 ++++-- .../QTensor/IR/Operations/InsertOp.cpp | 6 ++--- .../QTensor/Transforms/ShrinkRegisters.cpp | 14 ++++++------ mlir/lib/Support/IRVerification.cpp | 2 +- .../JeffRoundTrip/test_jeff_round_trip.cpp | 11 +++++++++- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 18 +++++++-------- 12 files changed, 79 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index b18306cd9d..6323da023c 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -408,7 +408,6 @@ struct ConvertMemRefAllocOp final return failure(); } - auto memref = op.getResult(); Value qtensor; if (shape[0] == ShapedType::kDynamic) { qtensor = rewriter.replaceOpWithNewOp( @@ -419,7 +418,9 @@ struct ConvertMemRefAllocOp final qtensor = rewriter.replaceOpWithNewOp(op, size.getResult()); } + auto& state = getState(); + auto memref = op.getResult(); assignMappedTensor(state, qtensor.getDefiningOp(), memref, qtensor); return success(); @@ -520,17 +521,26 @@ struct ConvertMemRefDeallocOp final auto qtensor = lookupMappedTensor(state, op.getOperation(), memref); // Filter out qubits belonging to this tensor - for (auto it = qubitMap.begin(); it != qubitMap.end(); ++it) { - auto& [qcQubit, qcoQubit] = *it; - auto& [reg, index] = qubitInfoMap[qcQubit]; + for (auto it = qubitMap.begin(); it != qubitMap.end();) { + auto current = it++; + auto qcQubit = current->first; + auto qcoQubit = current->second; + + auto infoIt = qubitInfoMap.find(qcQubit); + if (infoIt == qubitInfoMap.end()) { + continue; + } + + auto& [reg, index] = infoIt->second; if (reg != memref) { continue; } + qtensor = qtensor::InsertOp::create(rewriter, op.getLoc(), qcoQubit, qtensor, index) .getResult(); - qubitMap.erase(it); - qubitInfoMap.erase(qcQubit); + qubitMap.erase(current); + qubitInfoMap.erase(infoIt); } tensorMap.erase(memref); diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp index 963bd82f82..9445559a41 100644 --- a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -80,7 +80,7 @@ struct ShrinkQubitRegister final : OpRewritePattern { return failure(); } auto index = getLoadIndex(loadOp); - if (!index) { + if (!index || *index < 0 || *index >= memRefType.getDimSize(0)) { return failure(); } loadOps.push_back(loadOp); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index c17da088d7..d0d2abb9f3 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -894,24 +894,22 @@ OwningOpRef QCOProgramBuilder::finalize() { validTensorIds.insert(info.regId); } - llvm::DenseMap registerQubits; + llvm::DenseMap>> + qubitsByRegister; for (auto [qubit, info] : validQubits) { if (info.regId == -1 || !validTensorIds.contains(info.regId)) { // Automatically deallocate all still-allocated qubits SinkOp::create(*this, qubit); } else { - registerQubits.try_emplace(qubit, info); + qubitsByRegister[info.regId].emplace_back(qubit, info); } } // Automatically deallocate all still-allocated tensors for (auto& [tensor, tensorInfo] : validTensors) { - Value currentTensor = tensor; + auto currentTensor = tensor; // Filter out qubits belonging to this tensor - for (auto& [qubit, qubitInfo] : registerQubits) { - if (qubitInfo.regId != tensorInfo.regId) { - continue; - } + for (auto& [qubit, qubitInfo] : qubitsByRegister[tensorInfo.regId]) { auto indexValue = constantFromScalar(*this, getLoc(), qubitInfo.regIndex); currentTensor = qtensor::InsertOp::create(*this, qubit, currentTensor, indexValue) diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index e8f57cd859..73927eb887 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -10,9 +10,9 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include -#include #include #include #include @@ -23,16 +23,15 @@ using namespace mlir; using namespace mlir::qco; /** - * @brief Check if a `qtensor.extract` operation is guaranteed to read from a - * `qtensor.alloc` chain. + * @brief Check if a `qtensor.extract` operation reads from a `qtensor.alloc` + * chain. * - * In QTensor's linear tensor model, reads/writes on different indices commute. - * We can therefore skip over `qtensor.insert` on other indices while tracing - * provenance. A write to the same index invalidates the proof. + * @details In QTensor's linear tensor model, reads/writes on different indices + * commute. We can therefore skip over `qtensor.insert` on other indices while + * tracing provenance. A write to the same index invalidates the proof. */ -static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { - Value currentTensor = extractOp.getTensor(); - const auto extractIndex = getAsOpFoldResult(extractOp.getIndex()); +static bool originatesFromQTensorAlloc(qtensor::ExtractOp extractOp) { + auto currentTensor = extractOp.getTensor(); while (auto* definingOp = currentTensor.getDefiningOp()) { if (llvm::isa(definingOp)) { @@ -45,7 +44,8 @@ static bool originatesFromAlloc(qtensor::ExtractOp extractOp) { } if (auto insertOp = llvm::dyn_cast(definingOp)) { - if (getAsOpFoldResult(insertOp.getIndex()) == extractIndex) { + if (qtensor::areEquivalentIndices(extractOp.getIndex(), + insertOp.getIndex())) { return false; } currentTensor = insertOp.getDest(); @@ -76,7 +76,7 @@ struct RemoveResetAfterExtract final : OpRewritePattern { } // Check if the tensor originates from an AllocOp - if (!originatesFromAlloc(extractOp)) { + if (!originatesFromQTensorAlloc(extractOp)) { return failure(); } diff --git a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp index c55efb3414..f1188ce19a 100644 --- a/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp +++ b/mlir/lib/Dialect/QIR/Transforms/QIRCleanup.cpp @@ -142,10 +142,12 @@ static void normalizeQIRMetadata(ModuleOp module) { } namespace { + /** - * @brief Remove dead QIR qubit-array allocation/release pairs. - * @details Matches an unused `__quantum__rt__qubit_array_allocate` / - * `__quantum__rt__qubit_array_release` pair on the same stack slot. + * @brief Remove matching allocation-release pairs of qubit arrays. + * @details Matches an unused + * `__quantum__rt__qubit_array_allocate`-`__quantum__rt__qubit_array_release` + * pair on the same stack slot. */ struct RemoveDeadQubitArrayPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -197,6 +199,11 @@ struct RemoveDeadQubitArrayPair final : OpRewritePattern { } }; +/** + * @brief Clean up QIR. + * @details Removes dead allocation-release pairs of qubit arrays, drops unused + * external declarations, and normalizes QIR metadata. + */ struct QIRCleanupPass final : impl::QIRCleanupPassBase { protected: void runOnOperation() override { diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp index bf62cd1df4..4abddab7cc 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/DeallocOp.cpp @@ -20,7 +20,7 @@ using namespace mlir::qtensor; namespace { /** - * @brief Remove matching allocation and deallocation pairs without operations + * @brief Remove matching allocation-deallocation pairs without operations * between them. */ struct RemoveAllocDeallocPair final : OpRewritePattern { diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 3f0fa9880e..0d30f26781 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -93,12 +93,15 @@ struct RemoveInsertExtractPair final : OpRewritePattern { break; } } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { - if (areEquivalentIndices(nestedExtractOp.getIndex(), - extractOp.getIndex())) { + if (areEquivalentIndices(extractOp.getIndex(), + nestedExtractOp.getIndex())) { // Do not reorder reads from the same index. return failure(); } + } else { + return failure(); } + traversedOps.push_back(definingOp); currentTensor = getTensorChainInput(definingOp); } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 88f066f750..12c9b85581 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -66,8 +66,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { * chain by traversing nested scalar tensor ops. */ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { - Value current = tensor; - while (Operation* definingOp = current.getDefiningOp()) { + auto current = tensor; + while (auto* definingOp = current.getDefiningOp()) { if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { // A more recent write to the same index shadows all older extracts. if (areEquivalentIndices(nestedInsertOp.getIndex(), index)) { @@ -91,7 +91,7 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { namespace { /** - * @brief Remove matching `qtensor.insert` and `qtensor.extract` pairs. + * @brief Remove matching `qtensor.insert`-`qtensor.extract` pairs. */ struct RemoveExtractInsertPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp index d4bfcf3e5a..feb64d6e15 100644 --- a/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp +++ b/mlir/lib/Dialect/QTensor/Transforms/ShrinkRegisters.cpp @@ -88,7 +88,7 @@ namespace mlir::qtensor { [[nodiscard]] static LogicalResult collectLiveIndices(AllocOp allocOp, llvm::BitVector& live, DeallocOp& deallocOp) { - Value tensor = allocOp.getResult(); + auto tensor = allocOp.getResult(); while (true) { auto* user = getLinearTensorUser(tensor); if (user == nullptr) { @@ -176,8 +176,8 @@ struct ShrinkStaticQTensor final : OpRewritePattern { auto newAlloc = AllocOp::create(rewriter, allocOp.getLoc(), size.getResult()); - Value oldTensor = allocOp.getResult(); - Value currentTensor = newAlloc.getResult(); + auto oldTensor = allocOp.getResult(); + auto currentTensor = newAlloc.getResult(); while (true) { Operation* currentOp = getLinearTensorUser(oldTensor); if (currentOp == nullptr) { @@ -208,8 +208,8 @@ struct ShrinkStaticQTensor final : OpRewritePattern { if (mappedIndex < 0) { return failure(); } - Value oldOutTensor = extractOp.getOutTensor(); - Operation* nextOp = getLinearTensorUser(oldOutTensor); + auto oldOutTensor = extractOp.getOutTensor(); + auto* nextOp = getLinearTensorUser(oldOutTensor); if (nextOp == nullptr) { return failure(); } @@ -244,8 +244,8 @@ struct ShrinkStaticQTensor final : OpRewritePattern { if (mappedIndex < 0) { return failure(); } - Value oldResultTensor = insertOp.getResult(); - Operation* nextOp = getLinearTensorUser(oldResultTensor); + auto oldResultTensor = insertOp.getResult(); + auto* nextOp = getLinearTensorUser(oldResultTensor); if (nextOp == nullptr) { return failure(); } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index c2d0f08f02..de8e1134a3 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -167,7 +167,7 @@ static bool isCommutableQTensorInsertDependency(Operation* dependent, } static Value getInsertChainBaseTensor(Value tensor, const OperationSet& group) { - Value current = tensor; + auto current = tensor; while (auto insertOp = current.getDefiningOp()) { if (!group.contains(insertOp.getOperation())) { break; diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index 5b046e5e98..c30dbaa90a 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include #include #include +#include #include #include @@ -101,7 +103,14 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + PassManager pm(context.get()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createRemoveDeadValuesPass()); + if (pm.run(program.get()).failed()) { + llvm::errs() << "Failed to run cleanup passes." << "\n"; + } + printer.record(program.get(), "Canonicalized Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 53f8d76cd9..793e8b2d42 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -227,7 +227,7 @@ TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { // We need a block argument to act as a non-constant size. auto* block = module.getBody(); block->addArgument(IndexType::get(context.get()), loc); - Value dynSizeVal = block->getArgument(0); + auto dynSizeVal = block->getArgument(0); b.setInsertionPointToEnd(block); auto qubitType = qco::QubitType::get(context.get()); @@ -267,8 +267,8 @@ TEST_F(QTensorTest, DeallocOpDeallocOfNonAllocIsNotRemoved) { }); ASSERT_TRUE(canonicalized); EXPECT_TRUE(verify(*canonicalized).succeeded()); - // After canonicalization the extract/insert pair simplifies, but there - // should still be either an alloc+dealloc pair or both get eliminated + // After canonicalization the extract-insert pair simplifies, but there + // should still be either an alloc-dealloc pair or both get eliminated // through further folding — just check the module is valid. // The important invariant: DeallocOp count is not negative, i.e., the // transform did not crash. @@ -293,7 +293,7 @@ TEST_F(QTensorTest, ExtractOpValidIndexVerifies) { TEST_F(QTensorTest, ExtractOpNegativeIndexFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); - Value tensor = builder.qtensorAlloc(3); + auto tensor = builder.qtensorAlloc(3); auto negIdx = arith::ConstantIndexOp::create(builder, -1); ExtractOp::create(builder, tensor, negIdx.getResult()); auto module = builder.finalize(); @@ -306,7 +306,7 @@ TEST_F(QTensorTest, ExtractOpNegativeIndexFailsVerification) { TEST_F(QTensorTest, ExtractOpIndexAtDimFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); - Value tensor = builder.qtensorAlloc(3); + auto tensor = builder.qtensorAlloc(3); // index = 3, tensor has dim 3 → out of bounds auto idx3 = arith::ConstantIndexOp::create(builder, 3); ExtractOp::create(builder, tensor, idx3.getResult()); @@ -321,7 +321,7 @@ TEST_F(QTensorTest, ExtractOpIndexAtDimMinusOneVerifies) { QCOProgramBuilder builder(context.get()); builder.initialize(); // qtensorAlloc(3) creates tensor<3x!qco.qubit> and tracks it. - Value tensor = builder.qtensorAlloc(3); + auto tensor = builder.qtensorAlloc(3); auto idx2 = arith::ConstantIndexOp::create(builder, 2); // Create extract at index 2 — last valid index for dim 3. // (Use the raw op creator to bypass builder tracking.) @@ -691,7 +691,7 @@ buildTwoQubitInsertChainProgram(MLIRContext* context, const int64_t qubit0Target = swapInsertTargets ? 1 : 0; const int64_t qubit1Target = swapInsertTargets ? 0 : 1; - Value currentTensor = baseTensor; + auto currentTensor = baseTensor; if (reverseInsertOrder) { currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); @@ -711,7 +711,7 @@ buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, builder.initialize(); auto tensor = builder.qtensorAlloc(3); - Value tensorAfterReads = tensor; + auto tensorAfterReads = tensor; Value qubit0 = nullptr; Value qubit1 = nullptr; @@ -730,7 +730,7 @@ buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, const int64_t q0Target = 0; const int64_t q1Target = swapInsertTargets ? 2 : 1; - Value tensorAfterWrites = tensorAfterReads; + auto tensorAfterWrites = tensorAfterReads; if (reverseOrder) { tensorAfterWrites = builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); From b5cebd3ef784505079f1275616929aeb57055a27 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:14:28 +0200 Subject: [PATCH 55/71] Address the Rabbit's comments --- mlir/include/mlir/Support/Passes.h | 6 ++-- .../QC/Transforms/ShrinkQubitRegisters.cpp | 6 ++++ mlir/lib/Support/Passes.cpp | 30 +++++++++---------- .../JeffRoundTrip/test_jeff_round_trip.cpp | 4 +-- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 22 ++++++++++---- 5 files changed, 41 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Support/Passes.h b/mlir/include/mlir/Support/Passes.h index 4b8057ced2..de9d076296 100644 --- a/mlir/include/mlir/Support/Passes.h +++ b/mlir/include/mlir/Support/Passes.h @@ -19,19 +19,19 @@ class PassManager; * @brief Populate a QC-oriented cleanup pipeline on the given pass manager. * @details Adds generic cleanup and QC qubit-register shrinking. */ -void populateQCCleanupPipeline(mlir::PassManager& passManager); +void populateQCCleanupPipeline(mlir::PassManager& pm); /** * @brief Populate a QCO-oriented cleanup pipeline on the given pass manager. * @details Adds generic cleanup and qtensor shrink-to-fit. */ -void populateQCOCleanupPipeline(mlir::PassManager& passManager); +void populateQCOCleanupPipeline(mlir::PassManager& pm); /** * @brief Populate a QIR-oriented cleanup pipeline on the given pass manager. * @details Adds generic cleanup and QIR-specific simplifications. */ -void populateQIRCleanupPipeline(mlir::PassManager& passManager); +void populateQIRCleanupPipeline(mlir::PassManager& pm); /** * @brief Run the QC-oriented cleanup pipeline on a module. diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp index 9445559a41..accd4330cd 100644 --- a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -66,6 +66,12 @@ struct ShrinkQubitRegister final : OpRewritePattern { if (!llvm::isa(memRefType.getElementType())) { return failure(); } + if (!memRefType.getLayout().isIdentity()) { + return failure(); + } + if (memRefType.getMemorySpace() != 0) { + return failure(); + } llvm::SmallVector loadOps; llvm::SmallVector liveIndices; diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index dd29c96c1b..5f1f2f73f3 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -23,9 +23,9 @@ using namespace mlir; -static void addSimplificationPasses(PassManager& passManager) { - passManager.addPass(createCanonicalizerPass()); - passManager.addPass(createCSEPass()); +static void addSimplificationPasses(PassManager& pm) { + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); } static void @@ -39,22 +39,22 @@ runWithPassManager(ModuleOp module, } } -void populateQCCleanupPipeline(PassManager& passManager) { - addSimplificationPasses(passManager); - passManager.addPass(qc::createShrinkQubitRegistersPass()); - passManager.addPass(createRemoveDeadValuesPass()); +void populateQCCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qc::createShrinkQubitRegistersPass()); + pm.addPass(createRemoveDeadValuesPass()); } -void populateQCOCleanupPipeline(PassManager& passManager) { - addSimplificationPasses(passManager); - passManager.addPass(qtensor::createShrinkQTensorToFitPass()); - passManager.addPass(createRemoveDeadValuesPass()); +void populateQCOCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qtensor::createShrinkQTensorToFitPass()); + pm.addPass(createRemoveDeadValuesPass()); } -void populateQIRCleanupPipeline(PassManager& passManager) { - addSimplificationPasses(passManager); - passManager.addPass(qir::createQIRCleanupPass()); - passManager.addPass(createRemoveDeadValuesPass()); +void populateQIRCleanupPipeline(PassManager& pm) { + addSimplificationPasses(pm); + pm.addPass(qir::createQIRCleanupPass()); + pm.addPass(createRemoveDeadValuesPass()); } void runQCCleanupPipeline(ModuleOp module) { diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index c30dbaa90a..7b1b08d2a1 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -107,9 +107,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); pm.addPass(createRemoveDeadValuesPass()); - if (pm.run(program.get()).failed()) { - llvm::errs() << "Failed to run cleanup passes." << "\n"; - } + EXPECT_TRUE(pm.run(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted Jeff IR" + name); EXPECT_TRUE(verify(*program).succeeded()); diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 793e8b2d42..194b85b9e5 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -141,6 +141,7 @@ TEST_F(QTensorTest, TensorChainHelpersSetTensorChainInputRewiresOperand) { auto insert = InsertOp::create( b, q0, out0, arith::ConstantIndexOp::create(b, 0).getResult()); setTensorChainInput(insert.getOperation(), out1); + EXPECT_EQ(getTensorChainInput(insert.getOperation()), out1); (void)InsertOp::create( b, q0, out1, arith::ConstantIndexOp::create(b, 1).getResult()); (void)q1; @@ -224,15 +225,24 @@ TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); ImplicitLocOpBuilder b(loc, context.get()); - // We need a block argument to act as a non-constant size. - auto* block = module.getBody(); - block->addArgument(IndexType::get(context.get()), loc); - auto dynSizeVal = block->getArgument(0); + b.setInsertionPointToStart(module.getBody()); + + // We need a block argument to act as a non-constant size + // Create a func.func to hold the block argument + auto funcType = + FunctionType::get(context.get(), {IndexType::get(context.get())}, {}); + auto func = func::FuncOp::create(b, "test", funcType); + + // Add a block with an index argument + auto& block = func.getBody().emplaceBlock(); + block.addArgument(IndexType::get(context.get()), loc); + + b.setInsertionPointToStart(&block); - b.setInsertionPointToEnd(block); - auto qubitType = qco::QubitType::get(context.get()); // Static result type dim = 3, but size operand is dynamic → error + auto qubitType = qco::QubitType::get(context.get()); auto staticType = RankedTensorType::get({3}, qubitType); + auto dynSizeVal = block.getArgument(0); AllocOp::create(b, staticType, dynSizeVal); EXPECT_TRUE(verify(module).failed()); From 2e50202fb21cc8f60426b59b3095223807b24f8f Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:44:49 +0200 Subject: [PATCH 56/71] Fix linter errors --- mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp | 2 +- .../unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp index accd4330cd..48ce84a42c 100644 --- a/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp +++ b/mlir/lib/Dialect/QC/Transforms/ShrinkQubitRegisters.cpp @@ -69,7 +69,7 @@ struct ShrinkQubitRegister final : OpRewritePattern { if (!memRefType.getLayout().isIdentity()) { return failure(); } - if (memRefType.getMemorySpace() != 0) { + if (memRefType.getMemorySpace() != nullptr) { return failure(); } diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index 7b1b08d2a1..23081ab805 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include From fdced7d308fed1ef5024db9be8fd4e04e737900d Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:26:39 +0200 Subject: [PATCH 57/71] Raise error if qc.dealloc is called manually --- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 914b4ee29f..33e7459aab 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -446,12 +446,14 @@ QCProgramBuilder::inv(const llvm::function_ref& body) { QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { checkFinalized(); + if (llvm::isa_and_nonnull(qubit.getDefiningOp())) { + llvm::reportFatalUsageError( + "Register-backed qubits cannot be deallocated manually"); + } + // Check if the qubit is in the tracking set if (!allocatedQubits.erase(qubit)) { - // Qubit was not found in the set - either never allocated or already - // deallocated - llvm::reportFatalUsageError( - "Double deallocation or invalid qubit deallocation"); + llvm::reportFatalUsageError("Invalid qubit deallocation"); } // Create the DeallocOp From 82510ba80a2c65d8d9d859225c6d3fc3544d78a8 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:34:21 +0200 Subject: [PATCH 58/71] Make areEquivalentIndices() more strict --- .../mlir/Dialect/QTensor/IR/QTensorUtils.h | 7 ++++- mlir/lib/Support/IRVerification.cpp | 26 ++++++++++++------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h index b46f853cea..24061e3b5d 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -22,7 +22,12 @@ namespace mlir::qtensor { * @brief Checks whether two index values are equivalent for matching. */ inline bool areEquivalentIndices(Value lhs, Value rhs) { - return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs); + auto lhsValue = getConstantIntValue(lhs); + auto rhsValue = getConstantIntValue(rhs); + if (!lhsValue || !rhsValue) { + return lhs == rhs; + } + return *lhsValue == *rhsValue; } /** diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index de8e1134a3..611f8c2399 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -42,6 +42,7 @@ using namespace mlir; namespace { + /// Compute a structural hash for an operation (excluding SSA value identities). /// This hash is based on operation name, types, and attributes only. struct OperationStructuralHash { @@ -130,26 +131,33 @@ struct InsertChainSummary { Value finalTensor; llvm::SmallVector writes; }; + } // namespace -static bool areValuesEquivalent(Value lhsValue, Value rhsValue, + +static bool areValuesEquivalent(Value lhs, Value rhs, ValueEquivalenceMap& valueMap) { - if (auto it = valueMap.find(lhsValue); it != valueMap.end()) { - return it->second == rhsValue; + if (auto it = valueMap.find(lhs); it != valueMap.end()) { + return it->second == rhs; } - valueMap[lhsValue] = rhsValue; + valueMap[lhs] = rhs; return true; } -static bool areEquivalentIndices(Value lhsValue, Value rhsValue) { - return getAsOpFoldResult(lhsValue) == getAsOpFoldResult(rhsValue); +static bool areEquivalentIndices(Value lhs, Value rhs) { + auto lhsValue = getConstantIntValue(lhs); + auto rhsValue = getConstantIntValue(rhs); + if (!lhsValue || !rhsValue) { + return lhs == rhs; + } + return *lhsValue == *rhsValue; } -static bool areIndexValuesEquivalent(Value lhsValue, Value rhsValue, +static bool areIndexValuesEquivalent(Value lhs, Value rhs, ValueEquivalenceMap& valueMap) { - if (areEquivalentIndices(lhsValue, rhsValue)) { + if (areEquivalentIndices(lhs, rhs)) { return true; } - return areValuesEquivalent(lhsValue, rhsValue, valueMap); + return areValuesEquivalent(lhs, rhs, valueMap); } static bool isQTensorInsertOp(Operation* op) { From 18c7747501e96d4c36506242d3072f209f32d816 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:43:41 +0200 Subject: [PATCH 59/71] Return result of cleanup pipelines --- mlir/include/mlir/Support/Passes.h | 8 +++-- mlir/lib/Support/Passes.cpp | 23 +++++++------- .../Compiler/test_compiler_pipeline.cpp | 4 +-- .../JeffRoundTrip/test_jeff_round_trip.cpp | 6 ++-- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 6 ++-- .../Conversion/QCToQCO/test_qc_to_qco.cpp | 6 ++-- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 6 ++-- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 4 +-- .../test_quantum_computation_translation.cpp | 4 +-- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 8 ++--- mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp | 4 +-- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 30 +++++++++---------- 12 files changed, 57 insertions(+), 52 deletions(-) diff --git a/mlir/include/mlir/Support/Passes.h b/mlir/include/mlir/Support/Passes.h index de9d076296..f780be87f9 100644 --- a/mlir/include/mlir/Support/Passes.h +++ b/mlir/include/mlir/Support/Passes.h @@ -10,6 +10,8 @@ #pragma once +#include "mlir/Support/LogicalResult.h" + namespace mlir { class ModuleOp; class PassManager; @@ -36,14 +38,14 @@ void populateQIRCleanupPipeline(mlir::PassManager& pm); /** * @brief Run the QC-oriented cleanup pipeline on a module. */ -void runQCCleanupPipeline(mlir::ModuleOp module); +[[nodiscard]] mlir::LogicalResult runQCCleanupPipeline(mlir::ModuleOp module); /** * @brief Run the QCO-oriented cleanup pipeline on a module. */ -void runQCOCleanupPipeline(mlir::ModuleOp module); +[[nodiscard]] mlir::LogicalResult runQCOCleanupPipeline(mlir::ModuleOp module); /** * @brief Run the QIR-oriented cleanup pipeline on a module. */ -void runQIRCleanupPipeline(mlir::ModuleOp module); +[[nodiscard]] mlir::LogicalResult runQIRCleanupPipeline(mlir::ModuleOp module); diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index 5f1f2f73f3..0369056533 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include using namespace mlir; @@ -28,7 +29,7 @@ static void addSimplificationPasses(PassManager& pm) { pm.addPass(createCSEPass()); } -static void +static LogicalResult runWithPassManager(ModuleOp module, const llvm::function_ref populatePasses, const llvm::StringRef errorMessage) { @@ -36,7 +37,9 @@ runWithPassManager(ModuleOp module, populatePasses(pm); if (pm.run(module).failed()) { llvm::errs() << errorMessage << "\n"; + return failure(); } + return success(); } void populateQCCleanupPipeline(PassManager& pm) { @@ -57,17 +60,17 @@ void populateQIRCleanupPipeline(PassManager& pm) { pm.addPass(createRemoveDeadValuesPass()); } -void runQCCleanupPipeline(ModuleOp module) { - runWithPassManager(module, populateQCCleanupPipeline, - "Failed to run QC cleanup pipeline."); +[[nodiscard]] LogicalResult runQCCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQCCleanupPipeline, + "Failed to run QC cleanup pipeline."); } -void runQCOCleanupPipeline(ModuleOp module) { - runWithPassManager(module, populateQCOCleanupPipeline, - "Failed to run QCO cleanup pipeline."); +[[nodiscard]] LogicalResult runQCOCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQCOCleanupPipeline, + "Failed to run QCO cleanup pipeline."); } -void runQIRCleanupPipeline(ModuleOp module) { - runWithPassManager(module, populateQIRCleanupPipeline, - "Failed to run QIR cleanup pipeline."); +[[nodiscard]] LogicalResult runQIRCleanupPipeline(ModuleOp module) { + return runWithPassManager(module, populateQIRCleanupPipeline, + "Failed to run QIR cleanup pipeline."); } diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 169ec794a1..ba6f4a18e1 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -99,7 +99,7 @@ class CompilerPipelineTest [[nodiscard]] mlir::OwningOpRef buildQCReference(const QCProgramBuilderFn builder) const { auto module = mlir::qc::QCProgramBuilder::build(context.get(), builder.fn); - runQCCleanupPipeline(module.get()); + EXPECT_TRUE(runQCCleanupPipeline(module.get()).succeeded()); return module; } @@ -107,7 +107,7 @@ class CompilerPipelineTest buildQIRReference(const QIRProgramBuilderFn builder) const { auto module = mlir::qir::QIRProgramBuilder::build(context.get(), builder.fn); - runQIRCleanupPipeline(module.get()); + EXPECT_TRUE(runQIRCleanupPipeline(module.get()).succeeded()); return module; } diff --git a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp index 23081ab805..1afb831cdc 100644 --- a/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp +++ b/mlir/unittests/Conversion/JeffRoundTrip/test_jeff_round_trip.cpp @@ -94,7 +94,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -115,7 +115,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -125,7 +125,7 @@ TEST_P(JeffRoundTripTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 4d3afbe03f..c1e5d81d43 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -90,7 +90,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -98,7 +98,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(program.get(), "Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCCleanupPipeline(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -108,7 +108,7 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 6c595e9b5b..9454f8309a 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -89,7 +89,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCCleanupPipeline(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -97,7 +97,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(program.get(), "Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -107,7 +107,7 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 994637e348..65982d3673 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -87,7 +87,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCCleanupPipeline(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -95,7 +95,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(program.get(), "Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQIRCleanupPipeline(program.get()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -105,7 +105,7 @@ TEST_P(QCToQIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQIRCleanupPipeline(reference.get()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index cb10225730..9c3afe4ecf 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -78,7 +78,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(program.get(), "Original QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCCleanupPipeline(program.get()); + EXPECT_TRUE(runQCCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QC IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -87,7 +87,7 @@ TEST_P(QCTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index d2a544b1cc..79bcd296c0 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -81,7 +81,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(translated.get(), "Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); - runQCCleanupPipeline(translated.get()); + EXPECT_TRUE(runQCCleanupPipeline(translated.get()).succeeded()); printer.record(translated.get(), "Canonicalized Translated QC IR" + name); EXPECT_TRUE(mlir::verify(*translated).succeeded()); @@ -91,7 +91,7 @@ TEST_P(QuantumComputationTranslationTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); - runQCCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QC IR" + name); EXPECT_TRUE(mlir::verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 4b62f96449..5a7c1d8add 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -78,7 +78,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(program.get(), "Original QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QCO IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -87,7 +87,7 @@ TEST_P(QCOTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QCO IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); @@ -118,14 +118,14 @@ TEST_F(QCOTest, DirectIfBuilder) { auto directBuilder = builder.finalize(); ASSERT_TRUE(directBuilder); EXPECT_TRUE(verify(*directBuilder).succeeded()); - runQCOCleanupPipeline(directBuilder.get()); + EXPECT_TRUE(runQCOCleanupPipeline(directBuilder.get()).succeeded()); EXPECT_TRUE(verify(*directBuilder).succeeded()); auto refBuilder = QCOProgramBuilder::build(context.get(), MQT_NAMED_BUILDER(simpleIf).fn); ASSERT_TRUE(refBuilder); EXPECT_TRUE(verify(*refBuilder).succeeded()); - runQCOCleanupPipeline(refBuilder.get()); + EXPECT_TRUE(runQCOCleanupPipeline(refBuilder.get()).succeeded()); EXPECT_TRUE(verify(*refBuilder).succeeded()); EXPECT_TRUE(areModulesEquivalentWithPermutations(directBuilder.get(), diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index 6d6c66d7f9..225a234b84 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -71,7 +71,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQIRCleanupPipeline(program.get()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -80,7 +80,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQIRCleanupPipeline(reference.get()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 194b85b9e5..8a80faf120 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -360,7 +360,7 @@ TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertSameIndex) { }); ASSERT_TRUE(module); EXPECT_TRUE(verify(*module).succeeded()); - runQCOCleanupPipeline(module.get()); + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); EXPECT_TRUE(verify(*module).succeeded()); // The extra extract at the same index should fold away. EXPECT_EQ(countOps(*module), 1U); // original extract @@ -626,7 +626,7 @@ TEST_P(QTensorIntegrationTest, ProgramEquivalence) { printer.record(program.get(), "Original QTensor IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); printer.record(program.get(), "Canonicalized QTensor IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -635,7 +635,7 @@ TEST_P(QTensorIntegrationTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QTensor IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); printer.record(reference.get(), "Canonicalized Reference QTensor IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); @@ -810,13 +810,13 @@ TEST_F(QTensorTest, InsertChainPermutationEquivalence) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); EXPECT_TRUE(verify(*program).succeeded()); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildTwoQubitInsertChainProgram(context.get(), true, false); ASSERT_TRUE(reference); EXPECT_TRUE(verify(*reference).succeeded()); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -826,12 +826,12 @@ TEST_F(QTensorTest, InsertChainPermutationEquivalence) { TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); ASSERT_TRUE(reference); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -841,12 +841,12 @@ TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { TEST_F(QTensorTest, MixedExtractInsertPermutationEquivalence) { auto program = buildMixedExtractInsertProgram(context.get(), false, false); ASSERT_TRUE(program); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedExtractInsertProgram(context.get(), true, false); ASSERT_TRUE(reference); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -856,12 +856,12 @@ TEST_F(QTensorTest, MixedExtractInsertPermutationEquivalence) { TEST_F(QTensorTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { auto program = buildMixedExtractInsertProgram(context.get(), false, false); ASSERT_TRUE(program); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildMixedExtractInsertProgram(context.get(), true, true); ASSERT_TRUE(reference); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( @@ -871,12 +871,12 @@ TEST_F(QTensorTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { auto program = buildResetWithCommutingInsertProgram(context.get(), true); ASSERT_TRUE(program); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildResetWithCommutingInsertProgram(context.get(), false); ASSERT_TRUE(reference); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE( @@ -886,12 +886,12 @@ TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { TEST_F(QTensorTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { auto program = buildResetWithSameIndexInsertProgram(context.get(), true); ASSERT_TRUE(program); - runQCOCleanupPipeline(program.get()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); ASSERT_TRUE(reference); - runQCOCleanupPipeline(reference.get()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_FALSE( From 5f719f9943493d8e92b428adbe81e599ba39f051 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 16:15:01 +0200 Subject: [PATCH 60/71] Address the Rabbit's comments --- mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h | 2 +- mlir/lib/Support/IRVerification.cpp | 2 +- mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp | 11 ++++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h index 24061e3b5d..6cfd1842da 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -25,7 +25,7 @@ inline bool areEquivalentIndices(Value lhs, Value rhs) { auto lhsValue = getConstantIntValue(lhs); auto rhsValue = getConstantIntValue(rhs); if (!lhsValue || !rhsValue) { - return lhs == rhs; + return false; } return *lhsValue == *rhsValue; } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 611f8c2399..6d6e419ba9 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -147,7 +147,7 @@ static bool areEquivalentIndices(Value lhs, Value rhs) { auto lhsValue = getConstantIntValue(lhs); auto rhsValue = getConstantIntValue(rhs); if (!lhsValue || !rhsValue) { - return lhs == rhs; + return false; } return *lhsValue == *rhsValue; } diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 8a80faf120..d38dca5cec 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -232,18 +232,15 @@ TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { auto funcType = FunctionType::get(context.get(), {IndexType::get(context.get())}, {}); auto func = func::FuncOp::create(b, "test", funcType); - - // Add a block with an index argument - auto& block = func.getBody().emplaceBlock(); - block.addArgument(IndexType::get(context.get()), loc); - - b.setInsertionPointToStart(&block); + auto* block = func.addEntryBlock(); + b.setInsertionPointToStart(block); // Static result type dim = 3, but size operand is dynamic → error auto qubitType = qco::QubitType::get(context.get()); auto staticType = RankedTensorType::get({3}, qubitType); - auto dynSizeVal = block.getArgument(0); + auto dynSizeVal = block->getArgument(0); AllocOp::create(b, staticType, dynSizeVal); + func::ReturnOp::create(b); EXPECT_TRUE(verify(module).failed()); } From b149a673f348717c6baf01aa2e3dd8250ec817f4 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 16:56:01 +0200 Subject: [PATCH 61/71] Make index checks even more strict --- .../mlir/Dialect/QTensor/IR/QTensorUtils.h | 8 ++++- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 21 +++++++++---- .../QTensor/IR/Operations/ExtractOp.cpp | 30 +++++++++++++------ .../QTensor/IR/Operations/InsertOp.cpp | 19 ++++++++++-- mlir/lib/Support/IRVerification.cpp | 9 ++++++ 5 files changed, 68 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h index 6cfd1842da..594be34481 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorUtils.h @@ -19,7 +19,13 @@ namespace mlir::qtensor { /** - * @brief Checks whether two index values are equivalent for matching. + * @brief Checks whether two index values are equivalent. + * + * @details This is a conservative check that returns true if both indices are + * constant integers with the same value. It returns false if either index is + * non-constant or if they have different constant values. Note that this means + * that some equivalent indices may be considered non-equivalent by this + * function, but no non-equivalent indices will be considered equivalent. */ inline bool areEquivalentIndices(Value lhs, Value rhs) { auto lhsValue = getConstantIntValue(lhs); diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 73927eb887..c5f3521d9e 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include +#include #include #include #include @@ -31,24 +32,32 @@ using namespace mlir::qco; * tracing provenance. A write to the same index invalidates the proof. */ static bool originatesFromQTensorAlloc(qtensor::ExtractOp extractOp) { - auto currentTensor = extractOp.getTensor(); + auto current = extractOp.getTensor(); - while (auto* definingOp = currentTensor.getDefiningOp()) { + auto extractIndex = extractOp.getIndex(); + if (!getConstantIntValue(extractIndex)) { + return false; + } + + while (auto* definingOp = current.getDefiningOp()) { if (llvm::isa(definingOp)) { return true; } if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { - currentTensor = nestedExtractOp.getTensor(); + current = nestedExtractOp.getTensor(); continue; } if (auto insertOp = llvm::dyn_cast(definingOp)) { - if (qtensor::areEquivalentIndices(extractOp.getIndex(), - insertOp.getIndex())) { + auto insertIndex = insertOp.getIndex(); + if (!getConstantIntValue(insertIndex)) { + return false; + } + if (qtensor::areEquivalentIndices(extractIndex, insertIndex)) { return false; } - currentTensor = insertOp.getDest(); + current = insertOp.getDest(); continue; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 0d30f26781..5e575537a2 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -40,8 +40,8 @@ LogicalResult ExtractOp::verify() { } /** - * @brief If an ExtractOp consumes an InsertOp with the same index, - * return the scalar and the destTensor from the InsertOp directly. + * @brief Check if a `qtensor.extract` operation reads from a `qtensor.insert` + * operation. */ static InsertOp foldExtractAfterInsert(ExtractOp extractOp) { auto insertOp = extractOp.getTensor().getDefiningOp(); @@ -79,23 +79,35 @@ struct RemoveInsertExtractPair final : OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter& rewriter) const override { llvm::SmallVector traversedOps; - Value currentTensor = extractOp.getTensor(); + Value current = extractOp.getTensor(); InsertOp matchedInsertOp = nullptr; - while (auto* definingOp = currentTensor.getDefiningOp()) { + auto extractIndex = extractOp.getIndex(); + if (!getConstantIntValue(extractIndex)) { + return failure(); + } + + while (auto* definingOp = current.getDefiningOp()) { if (!isTensorChainOp(definingOp)) { break; } if (auto insertOp = llvm::dyn_cast(definingOp)) { - if (areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex())) { + auto insertIndex = insertOp.getIndex(); + if (!getConstantIntValue(insertIndex)) { + return failure(); + } + if (areEquivalentIndices(insertIndex, extractIndex)) { matchedInsertOp = insertOp; break; } } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { - if (areEquivalentIndices(extractOp.getIndex(), - nestedExtractOp.getIndex())) { - // Do not reorder reads from the same index. + auto nestedExtractIndex = nestedExtractOp.getIndex(); + if (!getConstantIntValue(nestedExtractIndex)) { + return failure(); + } + // Do not reorder reads from the same index + if (areEquivalentIndices(extractIndex, nestedExtractIndex)) { return failure(); } } else { @@ -103,7 +115,7 @@ struct RemoveInsertExtractPair final : OpRewritePattern { } traversedOps.push_back(definingOp); - currentTensor = getTensorChainInput(definingOp); + current = getTensorChainInput(definingOp); } if (!matchedInsertOp) { diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 12c9b85581..2e807bf487 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -67,17 +67,30 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { */ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { auto current = tensor; + + if (!getConstantIntValue(index)) { + return nullptr; + } + while (auto* definingOp = current.getDefiningOp()) { if (auto nestedInsertOp = llvm::dyn_cast(definingOp)) { - // A more recent write to the same index shadows all older extracts. - if (areEquivalentIndices(nestedInsertOp.getIndex(), index)) { + auto nestedInsertIndex = nestedInsertOp.getIndex(); + if (!getConstantIntValue(nestedInsertIndex)) { + return nullptr; + } + // A more recent write to the same index shadows all older extracts + if (areEquivalentIndices(nestedInsertIndex, index)) { return nullptr; } current = nestedInsertOp.getDest(); continue; } if (auto extractOp = llvm::dyn_cast(definingOp)) { - if (areEquivalentIndices(extractOp.getIndex(), index)) { + auto extractIndex = extractOp.getIndex(); + if (!getConstantIntValue(extractIndex)) { + return nullptr; + } + if (areEquivalentIndices(extractIndex, index)) { return extractOp; } current = extractOp.getTensor(); diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 6d6e419ba9..3f184d7396 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -143,6 +143,15 @@ static bool areValuesEquivalent(Value lhs, Value rhs, return true; } +/** + * @brief Checks whether two index values are equivalent. + * + * @details This is a conservative check that returns true if both indices are + * constant integers with the same value. It returns false if either index is + * non-constant or if they have different constant values. Note that this means + * that some equivalent indices may be considered non-equivalent by this + * function, but no non-equivalent indices will be considered equivalent. + */ static bool areEquivalentIndices(Value lhs, Value rhs) { auto lhsValue = getConstantIntValue(lhs); auto rhsValue = getConstantIntValue(rhs); From 8f18b6dcc25c1332cb30a3819ecc4ebd21ed0e77 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 16:58:20 +0200 Subject: [PATCH 62/71] Add test --- mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index d38dca5cec..69c68abe0c 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -115,6 +115,14 @@ TEST_F(QTensorTest, AreEquivalentIndicesDifferentConstantsAreNotEquivalent) { EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); } +TEST_F(QTensorTest, AreEquivalentIndicesSameConstantDifferentSSAAreEquivalent) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto lhs = arith::ConstantIndexOp::create(builder, 2); + auto rhs = arith::ConstantIndexOp::create(builder, 2); + EXPECT_TRUE(areEquivalentIndices(lhs.getResult(), rhs.getResult())); +} + TEST_F(QTensorTest, TensorChainHelpersInsertAndExtractAreRecognized) { QCOProgramBuilder builder(context.get()); builder.initialize(); From e982b150895b0051a03726749e1b54e10cb28799 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:22:18 +0200 Subject: [PATCH 63/71] Address the Rabbit's comments --- .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 7 ++++++ mlir/lib/Support/IRVerification.cpp | 24 ++++--------------- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 13 ++-------- 3 files changed, 13 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index c5f3521d9e..3db78fbd8b 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -45,6 +45,13 @@ static bool originatesFromQTensorAlloc(qtensor::ExtractOp extractOp) { } if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { + auto nestedExtractIndex = nestedExtractOp.getIndex(); + if (!getConstantIntValue(nestedExtractIndex)) { + return false; + } + if (qtensor::areEquivalentIndices(extractIndex, nestedExtractIndex)) { + return false; + } current = nestedExtractOp.getTensor(); continue; } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 3f184d7396..6c3318df8c 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -10,6 +10,8 @@ #include "mlir/Support/IRVerification.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" + #include #include #include @@ -143,27 +145,9 @@ static bool areValuesEquivalent(Value lhs, Value rhs, return true; } -/** - * @brief Checks whether two index values are equivalent. - * - * @details This is a conservative check that returns true if both indices are - * constant integers with the same value. It returns false if either index is - * non-constant or if they have different constant values. Note that this means - * that some equivalent indices may be considered non-equivalent by this - * function, but no non-equivalent indices will be considered equivalent. - */ -static bool areEquivalentIndices(Value lhs, Value rhs) { - auto lhsValue = getConstantIntValue(lhs); - auto rhsValue = getConstantIntValue(rhs); - if (!lhsValue || !rhsValue) { - return false; - } - return *lhsValue == *rhsValue; -} - static bool areIndexValuesEquivalent(Value lhs, Value rhs, ValueEquivalenceMap& valueMap) { - if (areEquivalentIndices(lhs, rhs)) { + if (qtensor::areEquivalentIndices(lhs, rhs)) { return true; } return areValuesEquivalent(lhs, rhs, valueMap); @@ -252,7 +236,7 @@ summarizeInsertGroup(llvm::ArrayRef ops, llvm::SmallVector seenIndices; for (const auto& write : chain.writes) { if (llvm::any_of(seenIndices, [&](Value seenIndex) { - return areEquivalentIndices(seenIndex, write.index); + return qtensor::areEquivalentIndices(seenIndex, write.index); })) { return false; } diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 69c68abe0c..7988d3c9ac 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -67,23 +67,14 @@ class QTensorTest : public ::testing::Test { context->loadAllAvailableDialects(); } - /// Build a module using the QCOProgramBuilder and run a lightweight cleanup - /// pipeline (canonicalizer + CSE + symbol DCE + canonicalizer). + /// Build a module using the QCOProgramBuilder and run the cleanup pipeline. [[nodiscard]] OwningOpRef buildAndCanonicalize(void (*buildFn)(QCOProgramBuilder&)) const { auto module = QCOProgramBuilder::build(context.get(), buildFn); if (!module) { return {}; } - - PassManager pm(context.get()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createSymbolDCEPass()); - pm.addPass(createCanonicalizerPass()); - if (pm.run(*module).failed()) { - return {}; - } + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); return module; } From 88def0f65a30e18e1df1df4abdbed7064385c2a3 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:47:20 +0200 Subject: [PATCH 64/71] Fix linter errors --- mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 7988d3c9ac..5c6bf90996 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -34,8 +34,6 @@ #include #include #include -#include -#include #include #include From 51f719fa67399011078121525c03e1f33fc3f903 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:52:41 +0200 Subject: [PATCH 65/71] Address the Rabbit's final comment --- mlir/lib/Support/IRVerification.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 6c3318df8c..db2799a533 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -164,7 +164,16 @@ static bool isCommutableQTensorInsertDependency(Operation* dependent, if (!dependentInsert || !dependencyInsert) { return false; } - return dependentInsert.getDest() == dependencyInsert.getResult(); + if (dependentInsert.getDest() != dependencyInsert.getResult()) { + return false; + } + auto dependentIndex = dependentInsert.getIndex(); + auto dependencyIndex = dependencyInsert.getIndex(); + if (!getConstantIntValue(dependentIndex) || + !getConstantIntValue(dependencyIndex)) { + return false; + } + return !qtensor::areEquivalentIndices(dependentIndex, dependencyIndex); } static Value getInsertChainBaseTensor(Value tensor, const OperationSet& group) { From bf520ed675404ea38490cf88fb560b1061fff686 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:14:59 +0200 Subject: [PATCH 66/71] Address the Rabbit's final comment --- mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 5c6bf90996..e5d1a9ac8f 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -72,7 +72,9 @@ class QTensorTest : public ::testing::Test { if (!module) { return {}; } - EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); + if (runQCOCleanupPipeline(module.get()).failed()) { + return {}; + } return module; } From 543a27be82d459e8aa6eca942ea5a81a768d09eb Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:44:05 +0200 Subject: [PATCH 67/71] Streamline QTensor tests --- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 734 +++++------------- 1 file changed, 192 insertions(+), 542 deletions(-) diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index e5d1a9ac8f..562b4d322e 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -47,10 +47,6 @@ using namespace mlir::qtensor; using namespace mlir::qco; namespace { -// ============================================================================ -// Shared fixture — sets up an MLIR context with QTensor/QCO/Arith dialects -// and provides a QCOProgramBuilder for creating test programs. -// ============================================================================ class QTensorTest : public ::testing::Test { protected: @@ -88,7 +84,7 @@ class QTensorTest : public ::testing::Test { }; // ============================================================================ -// 1. QTensorUtils — direct tests of scalar chain helpers +// QTensorUtils // ============================================================================ TEST_F(QTensorTest, AreEquivalentIndicesSameValueIsEquivalent) { @@ -98,15 +94,7 @@ TEST_F(QTensorTest, AreEquivalentIndicesSameValueIsEquivalent) { EXPECT_TRUE(areEquivalentIndices(c2.getResult(), c2.getResult())); } -TEST_F(QTensorTest, AreEquivalentIndicesDifferentConstantsAreNotEquivalent) { - QCOProgramBuilder builder(context.get()); - builder.initialize(); - auto c0 = arith::ConstantIndexOp::create(builder, 0); - auto c1 = arith::ConstantIndexOp::create(builder, 1); - EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); -} - -TEST_F(QTensorTest, AreEquivalentIndicesSameConstantDifferentSSAAreEquivalent) { +TEST_F(QTensorTest, AreEquivalentIndicesSameConstantsAreEquivalent) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto lhs = arith::ConstantIndexOp::create(builder, 2); @@ -114,75 +102,30 @@ TEST_F(QTensorTest, AreEquivalentIndicesSameConstantDifferentSSAAreEquivalent) { EXPECT_TRUE(areEquivalentIndices(lhs.getResult(), rhs.getResult())); } -TEST_F(QTensorTest, TensorChainHelpersInsertAndExtractAreRecognized) { +TEST_F(QTensorTest, AreEquivalentIndicesDifferentConstantsAreNotEquivalent) { QCOProgramBuilder builder(context.get()); builder.initialize(); - auto tensor = builder.qtensorAlloc(3); - auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); - auto* insert = builder.qtensorInsert(q0, outTensor, 0).getDefiningOp(); - auto* extract = outTensor.getDefiningOp(); - - ASSERT_NE(insert, nullptr); - ASSERT_NE(extract, nullptr); - EXPECT_TRUE(isTensorChainOp(insert)); - EXPECT_TRUE(isTensorChainOp(extract)); - EXPECT_EQ(getTensorChainOutput(insert), insert->getResult(0)); - EXPECT_EQ(getTensorChainInput(extract), tensor); -} - -TEST_F(QTensorTest, TensorChainHelpersSetTensorChainInputRewiresOperand) { - auto module = - QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { - auto t1 = b.qtensorAlloc(3); - auto [out1, q1] = b.qtensorExtract(t1, 1); - auto t0 = b.qtensorAlloc(3); - auto [out0, q0] = b.qtensorExtract(t0, 0); - auto insert = InsertOp::create( - b, q0, out0, arith::ConstantIndexOp::create(b, 0).getResult()); - setTensorChainInput(insert.getOperation(), out1); - EXPECT_EQ(getTensorChainInput(insert.getOperation()), out1); - (void)InsertOp::create( - b, q0, out1, arith::ConstantIndexOp::create(b, 1).getResult()); - (void)q1; - (void)out0; - }); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); + auto c0 = arith::ConstantIndexOp::create(builder, 0); + auto c1 = arith::ConstantIndexOp::create(builder, 1); + EXPECT_FALSE(areEquivalentIndices(c0.getResult(), c1.getResult())); } // ============================================================================ -// 2. AllocOp — verify() tests +// AllocOp // ============================================================================ -/// A valid static alloc should pass verification. -TEST_F(QTensorTest, AllocOpValidStaticAllocVerifies) { - auto module = QCOProgramBuilder::build( - context.get(), [](QCOProgramBuilder& b) { b.qtensorAlloc(3); }); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); -} - /// AllocOp with a constant size ≤ 0 must fail verification. -/// Note: The builder asserts on zero/negative, so we verify the verifier -/// by constructing the op manually bypassing the builder assertion. TEST_F(QTensorTest, AllocOpZeroSizeFailsVerification) { - // Build a module manually to bypass builder-level assertion. auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); ImplicitLocOpBuilder b(loc, context.get()); b.setInsertionPointToStart(module.getBody()); - // Create a constant 0 for the size operand. - auto c0 = arith::ConstantIndexOp::create(b, 0); - // Construct the result type that would match a size-0 tensor (which is - // invalid per the verifier). We use kDynamic so the type-level constraint - // won't block construction, but the constant operand (0) triggers the - // verifier. auto qubitType = qco::QubitType::get(context.get()); - auto dynType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); - AllocOp::create(b, dynType, c0.getResult()); + auto tensorType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + auto c0 = arith::ConstantIndexOp::create(b, 0); + AllocOp::create(b, tensorType, c0.getResult()); - // The verifier should catch `sizeValue <= 0`. EXPECT_TRUE(verify(module).failed()); } @@ -193,11 +136,10 @@ TEST_F(QTensorTest, AllocOpStaticTypeMismatchFailsVerification) { ImplicitLocOpBuilder b(loc, context.get()); b.setInsertionPointToStart(module.getBody()); - auto c2 = arith::ConstantIndexOp::create(b, 2); // size operand = 2 auto qubitType = qco::QubitType::get(context.get()); - // result type says dimension = 3, but size operand = 2 → mismatch - auto staticType = RankedTensorType::get({3}, qubitType); - AllocOp::create(b, staticType, c2.getResult()); + auto tensorType = RankedTensorType::get({3}, qubitType); + auto c2 = arith::ConstantIndexOp::create(b, 2); + AllocOp::create(b, tensorType, c2.getResult()); EXPECT_TRUE(verify(module).failed()); } @@ -209,17 +151,15 @@ TEST_F(QTensorTest, AllocOpDynamicTypeWithConstantSizeVerifies) { ImplicitLocOpBuilder b(loc, context.get()); b.setInsertionPointToStart(module.getBody()); - auto c3 = arith::ConstantIndexOp::create(b, 3); auto qubitType = qco::QubitType::get(context.get()); - auto dynType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); - AllocOp::create(b, dynType, c3.getResult()); + auto tensorType = RankedTensorType::get({ShapedType::kDynamic}, qubitType); + auto c3 = arith::ConstantIndexOp::create(b, 3); + AllocOp::create(b, tensorType, c3.getResult()); - // Dynamic result dim with constant positive size → valid. EXPECT_TRUE(verify(module).succeeded()); } -/// AllocOp with a static result type but a non-constant (dynamic) size -/// operand must fail verification. +/// AllocOp with a static result type but a dynamic size fails verification. TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { auto loc = UnknownLoc::get(context.get()); auto module = ModuleOp::create(loc); @@ -234,18 +174,17 @@ TEST_F(QTensorTest, AllocOpStaticTypeWithDynamicSizeOperandFailsVerification) { auto* block = func.addEntryBlock(); b.setInsertionPointToStart(block); - // Static result type dim = 3, but size operand is dynamic → error auto qubitType = qco::QubitType::get(context.get()); - auto staticType = RankedTensorType::get({3}, qubitType); - auto dynSizeVal = block->getArgument(0); - AllocOp::create(b, staticType, dynSizeVal); + auto tensorType = RankedTensorType::get({3}, qubitType); + auto size = block->getArgument(0); + AllocOp::create(b, tensorType, size); func::ReturnOp::create(b); EXPECT_TRUE(verify(module).failed()); } // ============================================================================ -// 3. DeallocOp — canonicalization (RemoveAllocDeallocPair) +// DeallocOp // ============================================================================ /// An alloc immediately followed by dealloc should be eliminated entirely. @@ -261,495 +200,130 @@ TEST_F(QTensorTest, DeallocOpAllocDeallocPairIsRemoved) { EXPECT_EQ(countOps(*canonicalized), 0U); } -/// A dealloc whose operand is not directly an AllocOp should not be removed. -TEST_F(QTensorTest, DeallocOpDeallocOfNonAllocIsNotRemoved) { - // Extract then insert to create a different tensor SSA value before dealloc. - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); - auto q1 = b.h(q0); - auto afterInsert = b.qtensorInsert(q1, outTensor, 0); - b.qtensorDealloc(afterInsert); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - // After canonicalization the extract-insert pair simplifies, but there - // should still be either an alloc-dealloc pair or both get eliminated - // through further folding — just check the module is valid. - // The important invariant: DeallocOp count is not negative, i.e., the - // transform did not crash. -} - // ============================================================================ -// 4. ExtractOp — verify(), fold, and canonicalization +// ExtractOp // ============================================================================ -/// A valid extract at index 0 from a size-1 tensor must pass verification. -TEST_F(QTensorTest, ExtractOpValidIndexVerifies) { - auto module = - QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { - auto t = b.qtensorAlloc(1); - b.qtensorExtract(t, 0); - }); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); -} - -/// An extract at a negative constant index must fail verification. +/// An extract at a negative constant index fails verification. TEST_F(QTensorTest, ExtractOpNegativeIndexFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto tensor = builder.qtensorAlloc(3); - auto negIdx = arith::ConstantIndexOp::create(builder, -1); - ExtractOp::create(builder, tensor, negIdx.getResult()); + auto index = arith::ConstantIndexOp::create(builder, -1); + ExtractOp::create(builder, tensor, index.getResult()); auto module = builder.finalize(); + ASSERT_TRUE(module); EXPECT_TRUE(verify(*module).failed()); } -/// An extract at an index equal to the tensor dimension must fail (out of -/// bounds). +/// An extract at an index equal to the tensor dimension fails verification. TEST_F(QTensorTest, ExtractOpIndexAtDimFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); auto tensor = builder.qtensorAlloc(3); - // index = 3, tensor has dim 3 → out of bounds - auto idx3 = arith::ConstantIndexOp::create(builder, 3); - ExtractOp::create(builder, tensor, idx3.getResult()); + auto index = arith::ConstantIndexOp::create(builder, 3); + ExtractOp::create(builder, tensor, index.getResult()); auto module = builder.finalize(); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).failed()); -} -/// An extract at an index one less than the dimension must pass. -TEST_F(QTensorTest, ExtractOpIndexAtDimMinusOneVerifies) { - // Build inside a proper func.func body via QCOProgramBuilder. - QCOProgramBuilder builder(context.get()); - builder.initialize(); - // qtensorAlloc(3) creates tensor<3x!qco.qubit> and tracks it. - auto tensor = builder.qtensorAlloc(3); - auto idx2 = arith::ConstantIndexOp::create(builder, 2); - // Create extract at index 2 — last valid index for dim 3. - // (Use the raw op creator to bypass builder tracking.) - ExtractOp::create(builder, tensor, idx2.getResult()); - // finalize() will dealloc the still-tracked tensor (two uses of %tensor is - // valid in MLIR SSA). Dead extract results (%tOut, %q) are fine. - auto module = builder.finalize(); ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); + EXPECT_TRUE(verify(*module).failed()); } -/// foldExtractAfterInsert: extract(insert(t, q, i), i) → (t, q) -/// The fold must eliminate the round-trip at the same index. +// foldExtractAfterInsert: Fold if index is equivalent TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertSameIndex) { - // Use the full QCO pipeline so that both the fold and subsequent DCE of the - // dead InsertOp run to convergence (single canonicalizer pass may leave - // unreachable Pure ops if DCE and folding don't interleave). - auto module = - QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); - auto q1 = b.h(q0); - auto afterInsert = b.qtensorInsert(q1, outTensor, 0); - // Immediately extract the same qubit back — should fold away. - b.qtensorExtract(afterInsert, 0); - }); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); - EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); - EXPECT_TRUE(verify(*module).succeeded()); - // The extra extract at the same index should fold away. - EXPECT_EQ(countOps(*module), 1U); // original extract -} - -/// foldExtractAfterInsert: extract at a different index must NOT fold. -TEST_F(QTensorTest, ExtractOpNoFoldExtractAfterInsertDifferentIndex) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); + auto tensor0 = b.qtensorAlloc(3); + auto [tensor1, q0] = b.qtensorExtract(tensor0, 0); auto q1 = b.h(q0); - auto afterInsert = b.qtensorInsert(q1, outTensor, 0); - // Extract at index 1 — different from the insert's index 0 - b.qtensorExtract(afterInsert, 1); + auto tensor2 = b.qtensorInsert(q1, tensor1, 0); + b.qtensorExtract(tensor2, 0); }); ASSERT_TRUE(canonicalized); EXPECT_TRUE(verify(*canonicalized).succeeded()); - // The insert should still be present (not folded). - EXPECT_GE(countOps(*canonicalized), 1U); + EXPECT_EQ(countOps(*canonicalized), 1U); + EXPECT_EQ(countOps(*canonicalized), 1U); } -/// RemoveInsertExtractPair: extract through a disjoint InsertOp should find -/// the original extract and eliminate both. -TEST_F(QTensorTest, ExtractOpRemoveInsertExtractPairThroughDisjointInsert) { +// foldExtractAfterInsert: Do not fold if index is different +TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertDifferentIndex) { auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - // Extract qubit 0. - auto [t1, q0] = b.qtensorExtract(tensor, 0); - // Extract qubit 1 (disjoint from qubit 0). - auto [t2, q1] = b.qtensorExtract(t1, 1); - // Insert qubit 1 back at index 1, then extract it again — same index. - // The canonicalizer should eliminate both the insert and the re-extract. - auto afterInsert = b.qtensorInsert(q1, t2, 1); - b.qtensorExtract(afterInsert, 1); - // Use q0 so it isn't dead. - b.h(q0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); -} - -/// RemoveInsertExtractPair: a nested ExtractOp at the same index must block -/// re-ordering (linearity guard). -TEST_F(QTensorTest, - ExtractOpRemoveInsertExtractPairBlockedByNestedExtractAtSameIndex) { - // Pattern: insert q0 at 0, then extract-at-0 twice (would violate linearity - // if the first extraction were skipped). - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [t1, q0] = b.qtensorExtract(tensor, 0); - // Re-insert q0 at index 0, producing a new tensor. - auto q0h = b.h(q0); - auto afterInsert = b.qtensorInsert(q0h, t1, 0); - // Attempt to extract index 0 again — the chain already has an - // extract-at-0 in it, blocking the RemoveInsertExtractPair pattern. - auto [t3, q0again] = b.qtensorExtract(afterInsert, 0); - b.h(q0again); - (void)t3; + auto tensor0 = b.qtensorAlloc(3); + auto [tensor1, q0] = b.qtensorExtract(tensor0, 0); + auto q1 = b.h(q0); + auto tensor2 = b.qtensorInsert(q1, tensor1, 0); + b.qtensorExtract(tensor2, 1); }); ASSERT_TRUE(canonicalized); EXPECT_TRUE(verify(*canonicalized).succeeded()); + EXPECT_EQ(countOps(*canonicalized), 1U); + EXPECT_EQ(countOps(*canonicalized), 1U); } // ============================================================================ -// 5. InsertOp — verify(), fold, and canonicalization +// InsertOp // ============================================================================ -/// A valid insert at index 0 into a size-3 tensor must pass verification. -TEST_F(QTensorTest, InsertOpValidIndexVerifies) { - auto module = - QCOProgramBuilder::build(context.get(), [](QCOProgramBuilder& b) { - auto t = b.qtensorAlloc(3); - auto [out, q] = b.qtensorExtract(t, 0); - b.qtensorInsert(q, out, 0); - }); - ASSERT_TRUE(module); - EXPECT_TRUE(verify(*module).succeeded()); -} - -/// An insert at a negative constant index must fail verification. +/// An insert at a negative constant index fails verification. TEST_F(QTensorTest, InsertOpNegativeIndexFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); - // Extract qubit 0 to get both a tracked tensor and a qubit. - auto tensor = builder.qtensorAlloc(3); - auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); - // Insert at index -1 — raw op creation bypasses builder tracking. - auto negIdx = arith::ConstantIndexOp::create(builder, -1); - InsertOp::create(builder, q0, outTensor, negIdx.getResult()); - // finalize() will dealloc outTensor and sink q0 (both still tracked, both - // reused — valid in SSA). + auto tensor0 = builder.qtensorAlloc(3); + auto [tensor1, q0] = builder.qtensorExtract(tensor0, 0); + auto index = arith::ConstantIndexOp::create(builder, -1); + InsertOp::create(builder, q0, tensor1, index.getResult()); auto module = builder.finalize(); + ASSERT_TRUE(module); EXPECT_TRUE(verify(*module).failed()); } -/// An insert at an index equal to the destination dimension must fail. +/// An insert at an index equal to the destination dimension fails verification. TEST_F(QTensorTest, InsertOpIndexAtDimFailsVerification) { QCOProgramBuilder builder(context.get()); builder.initialize(); - auto tensor = builder.qtensorAlloc(3); - auto [outTensor, q0] = builder.qtensorExtract(tensor, 0); - auto idx3 = arith::ConstantIndexOp::create(builder, 3); // == dim - InsertOp::create(builder, q0, outTensor, idx3.getResult()); + auto tensor0 = builder.qtensorAlloc(3); + auto [tensor1, q0] = builder.qtensorExtract(tensor0, 0); + auto index = arith::ConstantIndexOp::create(builder, 3); + InsertOp::create(builder, q0, tensor1, index.getResult()); auto module = builder.finalize(); + ASSERT_TRUE(module); EXPECT_TRUE(verify(*module).failed()); } -/// foldInsertAfterExtract: insert(extract(t, i).qubit, extract(t, i).out, i) -/// should fold to `t`. -TEST_F(QTensorTest, InsertOpFoldInsertAfterExtractSameIndex) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); - // Insert the extracted qubit back at the same index without modification. - b.qtensorInsert(q0, outTensor, 0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - // The extract-insert pair should have been eliminated entirely. - EXPECT_EQ(countOps(*canonicalized), 0U); - EXPECT_EQ(countOps(*canonicalized), 0U); -} - -/// foldInsertAfterExtract: inserting the qubit at a different index must NOT -/// fold. -TEST_F(QTensorTest, InsertOpNoFoldInsertAfterExtractDifferentIndex) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); - // Insert at index 1 instead of 0 - b.qtensorInsert(q0, outTensor, 1); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - EXPECT_GE(countOps(*canonicalized), 1U); -} - -/// foldInsertAfterExtract: inserting into a different tensor (not the extract's -/// out_tensor) must NOT fold. -TEST_F(QTensorTest, InsertOpNoFoldInsertAfterExtractDifferentDest) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto t1 = b.qtensorAlloc(3); - auto t2 = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(t1, 0); - // q0 came from t1, but we insert into t2's out-tensor - auto [t2out, q1] = b.qtensorExtract(t2, 1); - b.qtensorInsert(q0, t2out, 0); - b.h(q1); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - EXPECT_GE(countOps(*canonicalized), 1U); -} - -/// RemoveExtractInsertPair: an insert-after-extract that has been modified -/// (qubit passed through an H gate) must NOT be eliminated. -TEST_F(QTensorTest, InsertOpNoRemoveExtractInsertPairAfterMutation) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [outTensor, q0] = b.qtensorExtract(tensor, 0); - auto q1 = b.h(q0); // mutation — scalar ≠ extract.getResult() - b.qtensorInsert(q1, outTensor, 0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - // The HOp mutates the qubit, so the pair cannot be collapsed. - EXPECT_GE(countOps(*canonicalized), 1U); -} - -/// RemoveExtractInsertPair: insert shadowed by an earlier InsertOp at the same -/// index must not be eliminated. -TEST_F(QTensorTest, InsertOpRemoveExtractInsertPairBlockedByShadowingInsert) { - // Pattern: - // t1, q0 = extract(alloc, 0) - // t2 = insert(q0, t1, 0) ← overwrites index 0 - // t3 = insert(q0_h, t2, 0) ← another write to index 0 - // Trying to find the matching extract for the second insert should be - // blocked by the first insert at the same index. - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [t1, q0] = b.qtensorExtract(tensor, 0); - // First insert q0 at 0. - auto t2 = b.qtensorInsert(q0, t1, 0); - // Second insert at 0 (different qubit — from another extract). - auto [t2out, q1] = b.qtensorExtract(t2, 0); - auto q1h = b.h(q1); - b.qtensorInsert(q1h, t2out, 0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); -} - -/// RemoveExtractInsertPair: insert blocked by a disjoint insert (different -/// index) should still succeed in finding the original extract. -TEST_F(QTensorTest, InsertOpRemoveExtractInsertPairThroughDisjointInsert) { - // Pattern: - // t1, q0 = extract(alloc, 0) - // t2, q1 = extract(t1, 1) ← disjoint from index 0 - // t3 = insert(q1, t2, 1) ← insert at index 1 (disjoint) - // t4 = insert(q0, t3, 0) ← insert matches the extract at 0 - // Both the q0 extract-insert and q1 extract-insert should collapse. - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor = b.qtensorAlloc(3); - auto [t1, q0] = b.qtensorExtract(tensor, 0); - auto [t2, q1] = b.qtensorExtract(t1, 1); - auto t3 = b.qtensorInsert(q1, t2, 1); - b.qtensorInsert(q0, t3, 0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - // Both pairs should collapse. - EXPECT_EQ(countOps(*canonicalized), 0U); - EXPECT_EQ(countOps(*canonicalized), 0U); -} - -// ============================================================================ -// 6. Integration -// -// These tests use the full QCO cleanup pipeline and compare canonicalized -// modules for structural equivalence with permutations. -// ============================================================================ - -struct QTensorIntegrationTestCase { - std::string name; - mqt::test::NamedBuilder programBuilder; - mqt::test::NamedBuilder referenceBuilder; - - // NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) - friend std::ostream& operator<<(std::ostream& os, - const QTensorIntegrationTestCase& info); -}; - -// NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) -std::ostream& operator<<(std::ostream& os, - const QTensorIntegrationTestCase& info) { - return os << "QTensor{" << info.name << "}"; -} - -class QTensorIntegrationTest - : public testing::TestWithParam { -protected: - std::unique_ptr context; - - void SetUp() override { - DialectRegistry registry; - registry.insert(); - context = std::make_unique(); - context->appendDialectRegistry(registry); - context->loadAllAvailableDialects(); - } -}; - -TEST_P(QTensorIntegrationTest, ProgramEquivalence) { - const auto& [_, programBuilder, referenceBuilder] = GetParam(); - const auto name = " (" + GetParam().name + ")"; - mqt::test::DeferredPrinter printer; - - auto program = QCOProgramBuilder::build(context.get(), programBuilder.fn); - ASSERT_TRUE(program); - printer.record(program.get(), "Original QTensor IR" + name); - EXPECT_TRUE(verify(*program).succeeded()); - - EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); - printer.record(program.get(), "Canonicalized QTensor IR" + name); - EXPECT_TRUE(verify(*program).succeeded()); - - auto reference = QCOProgramBuilder::build(context.get(), referenceBuilder.fn); - ASSERT_TRUE(reference); - printer.record(reference.get(), "Reference QTensor IR" + name); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); - printer.record(reference.get(), "Canonicalized Reference QTensor IR" + name); - EXPECT_TRUE(verify(*reference).succeeded()); - - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); -} - -/// @name QTensor/QTensor.cpp (relocated from QCO test suite) -/// @{ -INSTANTIATE_TEST_SUITE_P( - QTensorOpsTest, QTensorIntegrationTest, - testing::Values( - QTensorIntegrationTestCase{"QTensorAlloc", - MQT_NAMED_BUILDER(qtensorAlloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QTensorIntegrationTestCase{"QTensorAllocDealloc", - MQT_NAMED_BUILDER(qtensorDealloc), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QTensorIntegrationTestCase{"QTensorFromElements", - MQT_NAMED_BUILDER(qtensorFromElements), - MQT_NAMED_BUILDER(qtensorFromElements)}, - QTensorIntegrationTestCase{"QTensorExtract", - MQT_NAMED_BUILDER(qtensorExtract), - MQT_NAMED_BUILDER(qtensorExtract)}, - QTensorIntegrationTestCase{"QTensorInsert", - MQT_NAMED_BUILDER(qtensorInsert), - MQT_NAMED_BUILDER(qtensorInsert)}, - QTensorIntegrationTestCase{ - "QTensorExtractInsertSameIndex", - MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), - MQT_NAMED_BUILDER(qtensorAlloc)}, - QTensorIntegrationTestCase{ - "QTensorExtractInsertIndexMismatch", - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), - MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, - QTensorIntegrationTestCase{ - "QTensorInsertExtractSameIndex", - MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), - MQT_NAMED_BUILDER(qtensorInsert)}, - QTensorIntegrationTestCase{ - "QTensorInsertExtractIndexMismatch", - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)})); -/// @} } // namespace // ============================================================================ -// 7. Integration — multi-qubit permutation equivalence tests +// Canonicalization // ============================================================================ static OwningOpRef buildTwoQubitInsertChainProgram(MLIRContext* context, const bool reverseInsertOrder, const bool swapInsertTargets) { - QCOProgramBuilder builder(context); - builder.initialize(); - - auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterFirstExtract, qubit0] = builder.qtensorExtract(tensor, 0); - auto [baseTensor, qubit1] = - builder.qtensorExtract(tensorAfterFirstExtract, 1); - - const int64_t qubit0Target = swapInsertTargets ? 1 : 0; - const int64_t qubit1Target = swapInsertTargets ? 0 : 1; + const int64_t q0Target = swapInsertTargets ? 1 : 0; + const int64_t q1Target = swapInsertTargets ? 0 : 1; - auto currentTensor = baseTensor; - if (reverseInsertOrder) { - currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); - currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); - } else { - currentTensor = builder.qtensorInsert(qubit0, currentTensor, qubit0Target); - currentTensor = builder.qtensorInsert(qubit1, currentTensor, qubit1Target); - } - - builder.qtensorDealloc(currentTensor); - return builder.finalize(); -} - -static OwningOpRef -buildMixedExtractInsertProgram(MLIRContext* context, const bool reverseOrder, - const bool swapInsertTargets) { QCOProgramBuilder builder(context); builder.initialize(); - auto tensor = builder.qtensorAlloc(3); - auto tensorAfterReads = tensor; - Value qubit0 = nullptr; - Value qubit1 = nullptr; - - if (reverseOrder) { - std::tie(tensorAfterReads, qubit1) = - builder.qtensorExtract(tensorAfterReads, 1); - std::tie(tensorAfterReads, qubit0) = - builder.qtensorExtract(tensorAfterReads, 0); - } else { - std::tie(tensorAfterReads, qubit0) = - builder.qtensorExtract(tensorAfterReads, 0); - std::tie(tensorAfterReads, qubit1) = - builder.qtensorExtract(tensorAfterReads, 1); - } + Value q0 = nullptr; + Value q1 = nullptr; - const int64_t q0Target = 0; - const int64_t q1Target = swapInsertTargets ? 2 : 1; + auto tensor = builder.qtensorAlloc(2); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + std::tie(tensor, q1) = builder.qtensorExtract(tensor, 1); - auto tensorAfterWrites = tensorAfterReads; - if (reverseOrder) { - tensorAfterWrites = - builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); - tensorAfterWrites = - builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); + if (reverseInsertOrder) { + tensor = builder.qtensorInsert(q1, tensor, q1Target); + tensor = builder.qtensorInsert(q0, tensor, q0Target); } else { - tensorAfterWrites = - builder.qtensorInsert(qubit1, tensorAfterWrites, q1Target); - tensorAfterWrites = - builder.qtensorInsert(qubit0, tensorAfterWrites, q0Target); + tensor = builder.qtensorInsert(q0, tensor, q0Target); + tensor = builder.qtensorInsert(q1, tensor, q1Target); } - builder.qtensorDealloc(tensorAfterWrites); + builder.qtensorDealloc(tensor); return builder.finalize(); } @@ -759,18 +333,19 @@ buildResetWithCommutingInsertProgram(MLIRContext* context, QCOProgramBuilder builder(context); builder.initialize(); + Value q0 = nullptr; + Value q1 = nullptr; + auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); - auto tensorAfterInsert0 = - builder.qtensorInsert(qubit0, tensorAfterExtract0, 0); - auto [tensorAfterExtract1, qubit1] = - builder.qtensorExtract(tensorAfterInsert0, 1); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + tensor = builder.qtensorInsert(q0, tensor, 0); + std::tie(tensor, q1) = builder.qtensorExtract(tensor, 1); if (withReset) { - qubit1 = builder.reset(qubit1); + q1 = builder.reset(q1); } - auto tensorFinal = builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); - builder.qtensorDealloc(tensorFinal); + tensor = builder.qtensorInsert(q1, tensor, 1); + builder.qtensorDealloc(tensor); return builder.finalize(); } @@ -780,28 +355,28 @@ buildResetWithSameIndexInsertProgram(MLIRContext* context, QCOProgramBuilder builder(context); builder.initialize(); + Value q0 = nullptr; + Value q10 = nullptr; + Value q11 = nullptr; + auto tensor = builder.qtensorAlloc(2); - auto [tensorAfterExtract0, qubit0] = builder.qtensorExtract(tensor, 0); - auto [tensorAfterExtract1, qubit1] = - builder.qtensorExtract(tensorAfterExtract0, 1); - qubit1 = builder.h(qubit1); - auto tensorAfterInsert1 = - builder.qtensorInsert(qubit1, tensorAfterExtract1, 1); - auto [tensorAfterReadBack1, qubit1ReadBack] = - builder.qtensorExtract(tensorAfterInsert1, 1); + std::tie(tensor, q0) = builder.qtensorExtract(tensor, 0); + std::tie(tensor, q10) = builder.qtensorExtract(tensor, 1); + q10 = builder.h(q10); + tensor = builder.qtensorInsert(q10, tensor, 1); + std::tie(tensor, q11) = builder.qtensorExtract(tensor, 1); if (withReset) { - qubit1ReadBack = builder.reset(qubit1ReadBack); + q11 = builder.reset(q11); } - auto tensorAfterInsert1ReadBack = - builder.qtensorInsert(qubit1ReadBack, tensorAfterReadBack1, 1); - auto tensorFinal = - builder.qtensorInsert(qubit0, tensorAfterInsert1ReadBack, 0); - builder.qtensorDealloc(tensorFinal); + tensor = builder.qtensorInsert(q11, tensor, 1); + tensor = builder.qtensorInsert(q0, tensor, 0); + builder.qtensorDealloc(tensor); return builder.finalize(); } namespace { + TEST_F(QTensorTest, InsertChainPermutationEquivalence) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); @@ -822,11 +397,13 @@ TEST_F(QTensorTest, InsertChainPermutationEquivalence) { TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { auto program = buildTwoQubitInsertChainProgram(context.get(), false, false); ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); auto reference = buildTwoQubitInsertChainProgram(context.get(), true, true); ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); @@ -834,14 +411,16 @@ TEST_F(QTensorTest, InsertChainDifferentAssignmentsNotEquivalent) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } -TEST_F(QTensorTest, MixedExtractInsertPermutationEquivalence) { - auto program = buildMixedExtractInsertProgram(context.get(), false, false); +TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { + auto program = buildResetWithCommutingInsertProgram(context.get(), true); ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); - auto reference = buildMixedExtractInsertProgram(context.get(), true, false); + auto reference = buildResetWithCommutingInsertProgram(context.get(), false); ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); @@ -849,14 +428,16 @@ TEST_F(QTensorTest, MixedExtractInsertPermutationEquivalence) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } -TEST_F(QTensorTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { - auto program = buildMixedExtractInsertProgram(context.get(), false, false); +TEST_F(QTensorTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { + auto program = buildResetWithSameIndexInsertProgram(context.get(), true); ASSERT_TRUE(program); + EXPECT_TRUE(verify(*program).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); EXPECT_TRUE(verify(*program).succeeded()); - auto reference = buildMixedExtractInsertProgram(context.get(), true, true); + auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); ASSERT_TRUE(reference); + EXPECT_TRUE(verify(*reference).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); EXPECT_TRUE(verify(*reference).succeeded()); @@ -864,34 +445,103 @@ TEST_F(QTensorTest, MixedExtractInsertDifferentAssignmentsNotEquivalent) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } -TEST_F(QTensorTest, ResetAfterExtractThroughCommutingInsertIsEliminated) { - auto program = buildResetWithCommutingInsertProgram(context.get(), true); - ASSERT_TRUE(program); - EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); - EXPECT_TRUE(verify(*program).succeeded()); +// ============================================================================ +// Integration +// ============================================================================ - auto reference = buildResetWithCommutingInsertProgram(context.get(), false); - ASSERT_TRUE(reference); - EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); - EXPECT_TRUE(verify(*reference).succeeded()); +struct QTensorIntegrationTestCase { + std::string name; + mqt::test::NamedBuilder programBuilder; + mqt::test::NamedBuilder referenceBuilder; - EXPECT_TRUE( - areModulesEquivalentWithPermutations(program.get(), reference.get())); + friend std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info); +}; + +// NOLINTNEXTLINE(llvm-prefer-static-over-anonymous-namespace) +std::ostream& operator<<(std::ostream& os, + const QTensorIntegrationTestCase& info) { + return os << "QTensor{" << info.name << "}"; } -TEST_F(QTensorTest, ResetAfterExtractThroughSameIndexInsertIsNotEliminated) { - auto program = buildResetWithSameIndexInsertProgram(context.get(), true); +class QTensorIntegrationTest + : public testing::TestWithParam { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } +}; + +TEST_P(QTensorIntegrationTest, ProgramEquivalence) { + const auto& [_, programBuilder, referenceBuilder] = GetParam(); + const auto name = " (" + GetParam().name + ")"; + mqt::test::DeferredPrinter printer; + + auto program = QCOProgramBuilder::build(context.get(), programBuilder.fn); ASSERT_TRUE(program); + printer.record(program.get(), "Original QTensor IR" + name); + EXPECT_TRUE(verify(*program).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(program.get()).succeeded()); + printer.record(program.get(), "Canonicalized QTensor IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - auto reference = buildResetWithSameIndexInsertProgram(context.get(), false); + auto reference = QCOProgramBuilder::build(context.get(), referenceBuilder.fn); ASSERT_TRUE(reference); + printer.record(reference.get(), "Reference QTensor IR" + name); + EXPECT_TRUE(verify(*reference).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(reference.get()).succeeded()); + printer.record(reference.get(), "Canonicalized Reference QTensor IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - EXPECT_FALSE( + EXPECT_TRUE( areModulesEquivalentWithPermutations(program.get(), reference.get())); } +/// @name QTensor/QTensor.cpp (relocated from QCO test suite) +/// @{ +INSTANTIATE_TEST_SUITE_P( + QTensorOpsTest, QTensorIntegrationTest, + testing::Values( + QTensorIntegrationTestCase{"QTensorAlloc", + MQT_NAMED_BUILDER(qtensorAlloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorAllocDealloc", + MQT_NAMED_BUILDER(qtensorDealloc), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{"QTensorFromElements", + MQT_NAMED_BUILDER(qtensorFromElements), + MQT_NAMED_BUILDER(qtensorFromElements)}, + QTensorIntegrationTestCase{"QTensorExtract", + MQT_NAMED_BUILDER(qtensorExtract), + MQT_NAMED_BUILDER(qtensorExtract)}, + QTensorIntegrationTestCase{"QTensorInsert", + MQT_NAMED_BUILDER(qtensorInsert), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertSameIndex", + MQT_NAMED_BUILDER(qtensorExtractInsertSameIndex), + MQT_NAMED_BUILDER(qtensorAlloc)}, + QTensorIntegrationTestCase{ + "QTensorExtractInsertIndexMismatch", + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch), + MQT_NAMED_BUILDER(qtensorExtractInsertIndexMismatch)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractSameIndex", + MQT_NAMED_BUILDER(qtensorInsertExtractSameIndex), + MQT_NAMED_BUILDER(qtensorInsert)}, + QTensorIntegrationTestCase{ + "QTensorInsertExtractIndexMismatch", + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)})); +/// @} + } // namespace From c16ce6e734b1b1e4f6426cb2aeb1211bf73801f9 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:23:32 +0200 Subject: [PATCH 68/71] Remove redundant folds --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 2 -- .../QTensor/IR/Operations/ExtractOp.cpp | 28 ----------------- .../QTensor/IR/Operations/InsertOp.cpp | 28 ----------------- .../Dialect/QTensor/IR/test_qtensor_ir.cpp | 30 ------------------- 4 files changed, 88 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 1a0b98d6ea..78840d4b0f 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -135,7 +135,6 @@ def ExtractOp let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)"; let hasCanonicalizer = 1; - let hasFolder = 1; let hasVerifier = 1; } @@ -168,7 +167,6 @@ def InsertOp }]; let hasCanonicalizer = 1; - let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 5e575537a2..c731a718b2 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -39,34 +39,6 @@ LogicalResult ExtractOp::verify() { return success(); } -/** - * @brief Check if a `qtensor.extract` operation reads from a `qtensor.insert` - * operation. - */ -static InsertOp foldExtractAfterInsert(ExtractOp extractOp) { - auto insertOp = extractOp.getTensor().getDefiningOp(); - if (!insertOp) { - return nullptr; - } - - if (!areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex())) { - return nullptr; - } - - return insertOp; -} - -LogicalResult ExtractOp::fold(FoldAdaptor /*adaptor*/, - SmallVectorImpl& results) { - if (auto insertOp = foldExtractAfterInsert(*this)) { - results.emplace_back(insertOp.getDest()); - results.emplace_back(insertOp.getScalar()); - return success(); - } - - return failure(); -} - namespace { /** diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 2e807bf487..28849add4e 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -33,34 +33,6 @@ static bool isRemovableExtractInsertPair(InsertOp insertOp, areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex()); } -/** - * @brief Fold the direct pattern - * `insert(extract(tensor, idx).qubit, extract(tensor, idx).out, idx)`. - */ -static Value foldInsertAfterExtract(InsertOp insertOp) { - auto extractOp = insertOp.getScalar().getDefiningOp(); - if (!extractOp) { - return nullptr; - } - - if (insertOp.getDest() != extractOp.getOutTensor()) { - return nullptr; - } - - if (!isRemovableExtractInsertPair(insertOp, extractOp)) { - return nullptr; - } - - return extractOp.getTensor(); -} - -OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { - if (auto result = foldInsertAfterExtract(*this)) { - return result; - } - return {}; -} - /** * @brief Find a matching `qtensor.extract` for an insert index in a tensor * chain by traversing nested scalar tensor ops. diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 562b4d322e..793f29f726 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -230,36 +230,6 @@ TEST_F(QTensorTest, ExtractOpIndexAtDimFailsVerification) { EXPECT_TRUE(verify(*module).failed()); } -// foldExtractAfterInsert: Fold if index is equivalent -TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertSameIndex) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor0 = b.qtensorAlloc(3); - auto [tensor1, q0] = b.qtensorExtract(tensor0, 0); - auto q1 = b.h(q0); - auto tensor2 = b.qtensorInsert(q1, tensor1, 0); - b.qtensorExtract(tensor2, 0); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - EXPECT_EQ(countOps(*canonicalized), 1U); - EXPECT_EQ(countOps(*canonicalized), 1U); -} - -// foldExtractAfterInsert: Do not fold if index is different -TEST_F(QTensorTest, ExtractOpFoldExtractAfterInsertDifferentIndex) { - auto canonicalized = buildAndCanonicalize([](QCOProgramBuilder& b) { - auto tensor0 = b.qtensorAlloc(3); - auto [tensor1, q0] = b.qtensorExtract(tensor0, 0); - auto q1 = b.h(q0); - auto tensor2 = b.qtensorInsert(q1, tensor1, 0); - b.qtensorExtract(tensor2, 1); - }); - ASSERT_TRUE(canonicalized); - EXPECT_TRUE(verify(*canonicalized).succeeded()); - EXPECT_EQ(countOps(*canonicalized), 1U); - EXPECT_EQ(countOps(*canonicalized), 1U); -} - // ============================================================================ // InsertOp // ============================================================================ From 90978a09f53d859ee3e4118347dff0de6bc015de Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:36:58 +0200 Subject: [PATCH 69/71] Remove redundant canonicalization pattern --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 1 - .../QC/IR/QubitManagement/DeallocOp.cpp | 3 +- .../QTensor/IR/Operations/ExtractOp.cpp | 79 ------------------- .../QTensor/IR/Operations/InsertOp.cpp | 23 +++--- 4 files changed, 14 insertions(+), 92 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 78840d4b0f..8398fe3268 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -134,7 +134,6 @@ def ExtractOp let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, QubitType:$result); let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)"; - let hasCanonicalizer = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp b/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp index b842ab042c..db17a9ab8b 100644 --- a/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp +++ b/mlir/lib/Dialect/QC/IR/QubitManagement/DeallocOp.cpp @@ -21,8 +21,7 @@ using namespace mlir::qc; namespace { /** - * @brief Remove matching allocation and deallocation pairs without operations - * between them. + * @brief Remove matching allocation-deallocation pairs. */ struct RemoveAllocDeallocPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index c731a718b2..3a3f302121 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -38,82 +38,3 @@ LogicalResult ExtractOp::verify() { } return success(); } - -namespace { - -/** - * @brief Remove matching insert-extract pairs through commuting disjoint - * tensor-chain operations. - */ -struct RemoveInsertExtractPair final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractOp extractOp, - PatternRewriter& rewriter) const override { - llvm::SmallVector traversedOps; - Value current = extractOp.getTensor(); - InsertOp matchedInsertOp = nullptr; - - auto extractIndex = extractOp.getIndex(); - if (!getConstantIntValue(extractIndex)) { - return failure(); - } - - while (auto* definingOp = current.getDefiningOp()) { - if (!isTensorChainOp(definingOp)) { - break; - } - - if (auto insertOp = llvm::dyn_cast(definingOp)) { - auto insertIndex = insertOp.getIndex(); - if (!getConstantIntValue(insertIndex)) { - return failure(); - } - if (areEquivalentIndices(insertIndex, extractIndex)) { - matchedInsertOp = insertOp; - break; - } - } else if (auto nestedExtractOp = llvm::dyn_cast(definingOp)) { - auto nestedExtractIndex = nestedExtractOp.getIndex(); - if (!getConstantIntValue(nestedExtractIndex)) { - return failure(); - } - // Do not reorder reads from the same index - if (areEquivalentIndices(extractIndex, nestedExtractIndex)) { - return failure(); - } - } else { - return failure(); - } - - traversedOps.push_back(definingOp); - current = getTensorChainInput(definingOp); - } - - if (!matchedInsertOp) { - return failure(); - } - - Value outTensor = matchedInsertOp.getDest(); - if (!traversedOps.empty()) { - Operation* oldestCommutedOp = traversedOps.back(); - rewriter.modifyOpInPlace(oldestCommutedOp, [&]() { - setTensorChainInput(oldestCommutedOp, matchedInsertOp.getDest()); - }); - outTensor = getTensorChainOutput(traversedOps.front()); - if (!outTensor) { - return failure(); - } - } - - rewriter.replaceOp(extractOp, {outTensor, matchedInsertOp.getScalar()}); - return success(); - } -}; - -} // namespace - -void ExtractOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* context) { - results.add(context); -} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 28849add4e..aa3caf731c 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -34,13 +34,17 @@ static bool isRemovableExtractInsertPair(InsertOp insertOp, } /** - * @brief Find a matching `qtensor.extract` for an insert index in a tensor - * chain by traversing nested scalar tensor ops. + * @brief Finds the `qtensor.extract` operation corresponding to a given + * `qtensor.insert` operation. + * + * @details The function traverses the tensor chain of the `qtensor.insert` + * operation until it finds the matching `qtensor.extract` operation. */ -static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { - auto current = tensor; +static ExtractOp findMatchingExtractInTensorChain(InsertOp insertOp) { + auto current = insertOp.getDest(); + auto insertIndex = insertOp.getIndex(); - if (!getConstantIntValue(index)) { + if (!getConstantIntValue(insertIndex)) { return nullptr; } @@ -51,7 +55,7 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { return nullptr; } // A more recent write to the same index shadows all older extracts - if (areEquivalentIndices(nestedInsertIndex, index)) { + if (areEquivalentIndices(nestedInsertIndex, insertIndex)) { return nullptr; } current = nestedInsertOp.getDest(); @@ -62,7 +66,7 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { if (!getConstantIntValue(extractIndex)) { return nullptr; } - if (areEquivalentIndices(extractIndex, index)) { + if (areEquivalentIndices(extractIndex, insertIndex)) { return extractOp; } current = extractOp.getTensor(); @@ -76,15 +80,14 @@ static ExtractOp findMatchingExtractInTensorChain(Value tensor, Value index) { namespace { /** - * @brief Remove matching `qtensor.insert`-`qtensor.extract` pairs. + * @brief Remove matching extract-insert pairs. */ struct RemoveExtractInsertPair final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter& rewriter) const override { - auto extractOp = - findMatchingExtractInTensorChain(op.getDest(), op.getIndex()); + auto extractOp = findMatchingExtractInTensorChain(op); if (!extractOp) { return failure(); } From 928a590cec63ffa3ecfc3c1d0a82b2a59d8246b2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:02:37 +0200 Subject: [PATCH 70/71] Fix linter errors --- mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 3a3f302121..a899d3ae41 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -9,15 +9,10 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" -#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" -#include #include #include -#include #include -#include -#include #include #include From f587454d6580a9871d5b7abc9526103225c36d3b Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:08:04 +0200 Subject: [PATCH 71/71] Put back InsertOp::fold() --- .../mlir/Dialect/QTensor/IR/QTensorOps.td | 1 + .../QTensor/IR/Operations/InsertOp.cpp | 35 ++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 8398fe3268..1ac3aa1885 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -166,6 +166,7 @@ def InsertOp }]; let hasCanonicalizer = 1; + let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index aa3caf731c..adeac1cb8d 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -34,11 +34,31 @@ static bool isRemovableExtractInsertPair(InsertOp insertOp, } /** - * @brief Finds the `qtensor.extract` operation corresponding to a given - * `qtensor.insert` operation. + * @brief Folds an insert operation after a matching extract operation into the + * original tensor. + */ +static Value foldInsertAfterExtract(InsertOp insertOp) { + auto extractOp = insertOp.getScalar().getDefiningOp(); + if (!extractOp) { + return nullptr; + } + + if (insertOp.getDest() != extractOp.getOutTensor()) { + return nullptr; + } + + if (!isRemovableExtractInsertPair(insertOp, extractOp)) { + return nullptr; + } + + return extractOp.getTensor(); +} + +/** + * @brief Finds the extract operation corresponding to a given insert operation. * - * @details The function traverses the tensor chain of the `qtensor.insert` - * operation until it finds the matching `qtensor.extract` operation. + * @details The function traverses the tensor chain of the insert operation + * until it finds the matching extract operation. */ static ExtractOp findMatchingExtractInTensorChain(InsertOp insertOp) { auto current = insertOp.getDest(); @@ -121,6 +141,13 @@ LogicalResult InsertOp::verify() { return success(); } +OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { + if (auto result = foldInsertAfterExtract(*this)) { + return result; + } + return {}; +} + void InsertOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context);