From 9682f414529d379b2dc13a45be8522dcfc39096e Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 26 May 2026 18:37:39 +0200 Subject: [PATCH 01/41] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Relax=20condition=20?= =?UTF-8?q?on=20modifiers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Dialect/QC/Builder/QCProgramBuilder.h | 6 +- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 77 ++-- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 22 +- mlir/include/mlir/Dialect/Utils/Utils.h | 107 +++++ mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 55 ++- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 66 ++-- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 61 ++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 6 +- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 42 +- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 164 +++++--- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 331 ++++++++++------ mlir/lib/Dialect/QC/IR/QCOps.cpp | 12 + .../TranslateQuantumComputationToQC.cpp | 10 +- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 172 ++++---- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 369 +++++++++++------- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 80 +--- .../Optimizations/HadamardLifting.cpp | 3 +- mlir/unittests/programs/qc_programs.cpp | 347 +++++++++++----- 18 files changed, 1210 insertions(+), 720 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index ab9532cfb4..c2483d9f06 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -917,7 +917,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } : !qc.qubit * ``` */ - QCProgramBuilder& ctrl(ValueRange controls, const function_ref& body); + QCProgramBuilder& ctrl(ValueRange controls, ValueRange targets, + const function_ref& body); /** * @brief Apply an inverse (i.e., adjoint) operation. @@ -936,7 +937,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } * ``` */ - QCProgramBuilder& inv(const function_ref& body); + QCProgramBuilder& inv(ValueRange qubits, + const function_ref& body); //===--------------------------------------------------------------------===// // Deallocation diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 8e76c9c3ba..cce7265131 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -916,7 +916,7 @@ def YieldOp : QCOp<"yield", traits = [Terminator]> { def CtrlOp : QCOp<"ctrl", - traits = [UnitaryOpInterface, + traits = [UnitaryOpInterface, AttrSizedOperandSegments, SingleBlockImplicitTerminator<"::mlir::qc::YieldOp">, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; @@ -937,30 +937,36 @@ def CtrlOp ``` }]; - let arguments = - (ins Arg, - "the control qubits", [MemRead, MemWrite]>:$controls); + let arguments = (ins Arg, + "the control qubits", [MemRead, MemWrite]>:$controls, + Arg, + "the target qubits", [MemRead, MemWrite]>:$targets); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = - "`(` $controls `)` $region attr-dict `:` type($controls)"; + let assemblyFormat = [{ + `(` $controls `)` + `targets` + custom($region, $targets) + attr-dict `:` + `{` type($controls) `}` ( `,` `{` type($targets)^ `}` )? + }]; let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets() + getNumControls(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } + size_t getNumTargets() { return getTargets().size(); } size_t getNumControls() { return getControls().size(); } Value getQubit(size_t i); - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } + Value getTarget(size_t i) { return getTargets()[i]; } Value getControl(size_t i); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } static StringRef getBaseSymbol() { return "ctrl"; } }]; - let builders = [OpBuilder<(ins "ValueRange":$controls, - "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -983,26 +989,35 @@ def InvOp : QCOp<"inv", ``` }]; + let arguments = (ins Arg< + Variadic, + "the qubits involved in the operation", [MemRead, MemWrite]>:$qubits); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = "$region attr-dict"; - - let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); - size_t getNumQubits() { return getBodyUnitary().getNumQubits(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } - size_t getNumControls() { return getBodyUnitary().getNumControls(); } - Value getQubit(size_t i) { return getBodyUnitary().getQubit(i); } - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } - Value getControl(size_t i) { return getBodyUnitary().getControl(i); } - ValueRange getControls() { return getBodyUnitary().getControls(); } - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + let assemblyFormat = [{ + custom($region, $qubits) + attr-dict `:` + type($qubits) + }]; + + let extraClassDeclaration = [{ + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); + size_t getNumQubits() { return getNumTargets(); } + size_t getNumTargets() { return getQubits().size(); } + size_t getNumControls() { return 0; } + Value getQubit(size_t i) { return getTarget(i); } + Value getTarget(size_t i) { return getQubits()[i]; } + ValueRange getTargets() { return getQubits(); } + Value getControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } + ValueRange getControls() { return {nullptr, 0}; } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "inv"; } }]; - let builders = [OpBuilder<(ins "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$qubits, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..78e15ecc86 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1102,7 +1102,8 @@ def CtrlOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumControls() + getNumTargets(); } size_t getNumTargets() { return getTargetsIn().size(); } size_t getNumControls() { return getControlsIn().size(); } @@ -1120,9 +1121,9 @@ def CtrlOp ResultRange getOutputControls() { return getControlsOut(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; @@ -1173,7 +1174,8 @@ def InvOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubitsIn().size(); } static size_t getNumControls() { return 0; } @@ -1184,16 +1186,16 @@ def InvOp ResultRange getOutputQubits() { return getQubitsOut(); } Value getInputTarget(size_t i) { return getInputQubit(i); } Value getOutputTarget(size_t i) { return getOutputQubit(i); } - static Value getInputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getInputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static OperandRange getInputControls() { return {nullptr, 0}; } - static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static ResultRange getOutputControls() { return {nullptr, 0}; } ResultRange getOutputTargets() { return getOutputQubits(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "inv"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 3d976a5a63..546ecc479c 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -78,4 +79,110 @@ template return std::nullopt; } +template +[[nodiscard]] +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + // 1. Parse the opening parenthesis + if (parser.parseLParen()) { + return failure(); + } + + // Temporary storage for block arguments we are about to create + SmallVector blockArgs; + + // 2. Prepare to parse the list + if (failed(parser.parseOptionalRParen())) { + do { + OpAsmParser::Argument newArg; // The "new" variable name + OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable + + // Parse "%new" + if (parser.parseArgument(newArg)) { + return failure(); + } + + // Parse "=" + if (parser.parseEqual()) { + return failure(); + } + + // Parse "%old" + if (parser.parseOperand(oldOperand)) { + return failure(); + } + operands.push_back(oldOperand); + + // Hard-code QubitType since targets in qco.ctrl are always qubits. + // This avoids double-binding type($targets_in) in the assembly format + // while keeping the parser simple and the assembly format clean. + newArg.type = QubitType::get(parser.getBuilder().getContext()); + blockArgs.push_back(newArg); + + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) { + return failure(); + } + } + + // 4. Parse the Region + // We explicitly pass the blockArgs we just parsed so they become the entry + // block! + if (parser.parseRegion(region, blockArgs)) { + return failure(); + } + + return success(); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Region& region, + OperandRange targetsIn) { + printer << "("; + if (region.empty()) { + printer << ") "; + printer.printRegion(region, false); + return; + } + Block& entryBlock = region.front(); + + const auto numTargets = targetsIn.size(); + for (unsigned i = 0; i < numTargets; ++i) { + if (i > 0) { + printer << ", "; + } + printer.printOperand(entryBlock.getArgument(i)); + printer << " = "; + printer.printOperand(targetsIn[i]); + } + printer << ") "; + + printer.printRegion(region, false); +} + +// TODO: Document +static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { + if (auto blockArg = dyn_cast(qubit)) { + return qubits[blockArg.getArgNumber()]; + } + return qubit; +} + +// TODO: Rename and document +static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, + ValueRange outerQubits, ValueRange newQubits, + ValueRange qubitArgs) { + for (auto arg : block.getArguments()) { + auto innerQubit = innerQubits[arg.getArgNumber()]; + auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); + if (auto it = llvm::find(newQubits, outerQubit); it != newQubits.end()) { + auto index = std::distance(newQubits.begin(), it); + mapping.map(arg, qubitArgs[index]); + } else { + llvm::reportFatalInternalError("TODO"); + } + } +} + } // namespace mlir::utils diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index a486e82c5d..001844e1e6 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -155,19 +155,30 @@ static void handleResult(Operation* op, ConversionPatternRewriter& rewriter, * @brief Target operands: `adaptor.getOperands()` at the matched op, or * `state.targetsIn` while lowering inside `qco.ctrl` / `qco.inv`. * - * @param state Lowering state. - * @param adaptor Operand adaptor for the matched op. + * @param op The operation being converted. + * @param adaptor The operation adaptor of the operation. + * @param state The lowering state. * @tparam NumParams Number of parameters to drop from the end of the operand * list. - * @tparam OpAdaptor Adaptor with `getOperands()`. - * @return ValueRange The target operands. + * @tparam OpType The type of the operation. + * @tparam OpAdaptorType The type of the operation adaptor. + * @return The target operands. */ -template -[[nodiscard]] static ValueRange getEffectiveTargetOperands(LoweringState& state, - OpAdaptor adaptor) { - return state.inModifier() - ? ValueRange(state.targetsIn) - : ValueRange(adaptor.getOperands().drop_back(NumParams)); +template +[[nodiscard]] static SmallVector +getEffectiveTargetOperands(OpType op, OpAdaptorType adaptor, + LoweringState& state) { + if (!state.inModifier()) { + return adaptor.getOperands().drop_back(NumParams); + } + + SmallVector targets; + for (auto targetArg : op->getOperands().drop_back(NumParams)) { + auto target = + state.targetsIn[cast(targetArg).getArgNumber()]; + targets.push_back(target); + } + return targets; } /** @@ -190,10 +201,10 @@ convertJeffGate(QCOOpType op, typename QCOOpType::Adaptor adaptor, std::index_sequence /*targetIndices*/, std::index_sequence /*paramIndices*/) { constexpr std::size_t numParams = sizeof...(ParamIndices); - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= sizeof...(TargetIndices) && "Not enough operands available for conversion"); - ValueRange params = op.getParameters(); + auto params = op.getParameters(); auto jeffOp = JeffOpType::create( rewriter, op.getLoc(), targets[TargetIndices]..., params[ParamIndices]..., @@ -336,7 +347,7 @@ static LogicalResult moveRegion(Region& source, Region& dest, ConversionPatternRewriter& rewriter, const TypeConverter* typeConverter) { rewriter.inlineRegionBefore(source, dest, dest.end()); - Block* block = &dest.front(); + auto* block = &dest.front(); TypeConverter::SignatureConversion sc(block->getNumArguments()); if (failed( typeConverter->convertSignatureArgs(block->getArgumentTypes(), sc))) { @@ -728,7 +739,7 @@ struct ConvertQCOCustomGateToJeff final } } - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= NumTargets && "Not enough operands available for conversion"); @@ -764,7 +775,7 @@ struct ConvertQCOPPRGateToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - ValueRange targets = getEffectiveTargetOperands<1>(state, adaptor); + auto targets = getEffectiveTargetOperands<1>(op, adaptor, state); assert(targets.size() >= 2 && "Not enough operands available for conversion"); createPPROp(op, rewriter, state, targets, {p0_, p1_}); @@ -798,7 +809,7 @@ struct ConvertQCOU2OpToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - ValueRange targets = getEffectiveTargetOperands<2>(state, adaptor); + auto targets = getEffectiveTargetOperands<2>(op, adaptor, state); assert(!targets.empty() && "Not enough operands available for conversion"); auto target = targets.front(); @@ -840,11 +851,8 @@ struct ConvertQCOBarrierOpToJeff final matchAndRewrite(BarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - - ValueRange targets = getEffectiveTargetOperands<0>(state, adaptor); - + auto targets = getEffectiveTargetOperands<0>(op, adaptor, state); createCustomOp(op, rewriter, state, targets, {}, false, "barrier"); - return success(); } }; @@ -934,6 +942,13 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { state.invOp = op; if (state.targetsIn.empty()) { state.targetsIn = llvm::to_vector(adaptor.getQubitsIn()); + } else { + auto outerQubits = state.targetsIn; + SmallVector innerQubits; + for (auto arg : op.getBody()->getArguments()) { + innerQubits.push_back(outerQubits[arg.getArgNumber()]); + } + state.targetsIn = std::move(innerQubits); } // Inline region diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 97d7071f69..77758bd3ba 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -120,17 +120,12 @@ class StatefulOpConversionPattern : public OpConversionPattern { * @param sourceRegion Source region where the operations are moved from * @param targetRegion Target region where the operations are moved to * @param offset Offset to the arguments that are dropped - * @param numArgs Number of arguments that are dropped * @param replacementValues Values to replace the uses of the arguments * @param rewriter PatternRewriter of the current conversion pass */ static void inlineRegion(Region& sourceRegion, Region& targetRegion, - unsigned int offset, unsigned int numArgs, - ValueRange replacementValues, + unsigned int offset, ValueRange replacementValues, ConversionPatternRewriter& rewriter) { - assert(replacementValues.size() == numArgs && - "replacementValues size must match numArgs"); - rewriter.inlineRegionBefore(sourceRegion, targetRegion, targetRegion.end()); auto& block = targetRegion.front(); @@ -138,7 +133,7 @@ static void inlineRegion(Region& sourceRegion, Region& targetRegion, block.getArguments().drop_front(offset), replacementValues)) { arg.replaceAllUsesWith(replacementVal); } - block.eraseArguments(offset, numArgs); + block.eraseArguments(offset, replacementValues.size()); } #define GEN_PASS_DEF_QCOTOQC @@ -645,16 +640,19 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { LogicalResult matchAndRewrite(qco::CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Get QC controls - auto qcControls = adaptor.getControlsIn(); - // Create qc.ctrl operation - auto qcOp = qc::CtrlOp::create(rewriter, op.getLoc(), qcControls); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getTargetsIn().size(), adaptor.getTargetsIn(), - rewriter); + auto qcOp = qc::CtrlOp::create( + rewriter, op.getLoc(), adaptor.getControlsIn(), adaptor.getTargetsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -687,11 +685,17 @@ struct ConvertQCOInvOp final : OpConversionPattern { matchAndRewrite(qco::InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // Create qc.inv operation - auto qcOp = qc::InvOp::create(rewriter, op.getLoc()); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getOperands().size(), adaptor.getQubitsIn(), rewriter); + auto qcOp = qc::InvOp::create(rewriter, op.getLoc(), adaptor.getQubitsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -764,9 +768,9 @@ struct ConvertQCOSCFForOp final : OpConversionPattern { // Erase default block rewriter.eraseBlock(&newFor.getRegion().front()); - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), newFor.getRegion(), 1, - adaptor.getInitArgs().size(), adaptor.getInitArgs(), rewriter); + // Inline the region and replace the block arguments + inlineRegion(op.getRegion(), newFor.getRegion(), 1, adaptor.getInitArgs(), + rewriter); rewriter.replaceOp(op, adaptor.getInitArgs()); @@ -810,11 +814,11 @@ struct ConvertQCOSCFWhileOp final : OpConversionPattern { auto newWhileOp = scf::WhileOp::create(rewriter, op->getLoc(), TypeRange{}, ValueRange{}); - // Inline the regions and replace the blockarguments - inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); - inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); + // Inline the regions and replace the block arguments + inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, adaptor.getInits(), + rewriter); + inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, adaptor.getInits(), + rewriter); rewriter.replaceOp(op, adaptor.getInits()); @@ -855,15 +859,13 @@ struct ConvertQCOIfOp final : OpConversionPattern { // Erase the default empty then block rewriter.eraseBlock(&newThenRegion.front()); - // Inline the region and replace the blockarguments + // Inline the region and replace the block arguments inlineRegion(op.getThenRegion(), newThenRegion, 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); // Inline the else block if it has more than just the yield operation if (oldElseRegion.front().getOperations().size() > 1) { inlineRegion(oldElseRegion, newIf.getElseRegion(), 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); } diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index b4a4b0a151..9c32fc302d 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -391,22 +391,6 @@ static void popModifierFrame(LoweringState& state) { state.modifierFrames.pop_back(); } -/** @brief Adds entry block aliases for modifier target values. */ -template -[[nodiscard]] static ValueRange addModifierAliases(OpType op, - const size_t numTargets, - PatternRewriter& rewriter) { - auto& entryBlock = op.getRegion().front(); - const auto opLoc = op.getLoc(); - const auto qubitType = qco::QubitType::get(op.getContext()); - rewriter.modifyOpInPlace(op, [&] { - for (size_t i = 0; i < numTargets; ++i) { - entryBlock.addArgument(qubitType, opLoc); - } - }); - return entryBlock.getArguments().take_back(numTargets); -} - /** * @brief Inserts extracted qubits that are not required by @p target back into * their tensors. @@ -525,7 +509,8 @@ collectQubitValuesInsideSCFOps(Operation* op, LoweringState* state) { // Iterate through all operations of the current region for (auto& operation : region.front().getOperations()) { // Recursively walk through nested regions - if (operation.getNumRegions() > 0) { + if (operation.getNumRegions() > 0 && + !isa(operation)) { auto [qubits, registers] = collectQubitValuesInsideSCFOps(&operation, state); auto& regionQubitMap = state->regionQubitMap[op]; @@ -1124,16 +1109,20 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcControls, qcoOp.getControlsOut()); assignMappedQubits(state, operation, qcTargets, qcoOp.getTargetsOut()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); - // Create block arguments for QCO targets - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC ctrl region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); @@ -1174,16 +1163,20 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcTargets, qcoOp.getOutputTargets()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - // Create block arguments for target qubits and seed the nested frame. - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC inv region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); + + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 98942d998c..6c432f57d2 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -870,9 +870,9 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { adaptor.getControls().end()); state.controls[state.inCtrlOp] = controls; - // Inline region and remove operation - rewriter.inlineBlockBefore(&op.getRegion().front(), op->getBlock(), - op->getIterator()); + // Inline block and remove operation + rewriter.inlineBlockBefore(&op.getRegion().front(), op, + adaptor.getTargets()); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 5eb022c03a..4bde38d301 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -223,7 +223,8 @@ QCProgramBuilder& QCProgramBuilder::reset(Value qubit) { const std::variant&(PARAM), ValueRange controls) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, param); }); \ + ctrl(controls, ValueRange{}, \ + [&](ValueRange /*targets*/) { OP_NAME(param); }); \ return *this; \ } @@ -247,7 +248,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME(ValueRange controls, \ Value target) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, target); }); \ + ctrl(controls, target, [&](ValueRange targets) { OP_NAME(targets[0]); }); \ return *this; \ } @@ -285,8 +286,8 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value target) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param, targets[0]); }); \ return *this; \ } @@ -321,8 +322,8 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param1, param2); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param1, param2, targets[0]); }); \ return *this; \ } @@ -360,8 +361,8 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ auto param3 = variantToValue(*this, getLoc(), PARAM3); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, target, param1, param2, param3); \ + ctrl(controls, target, [&](ValueRange targets) { \ + OP_NAME(param1, param2, param3, targets[0]); \ }); \ return *this; \ } @@ -386,8 +387,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ ValueRange controls, Value qubit0, Value qubit1) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(targets[0], targets[1]); }); \ return *this; \ } @@ -418,8 +419,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) Value qubit0, Value qubit1) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1, param); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(param, targets[0], targets[1]); }); \ return *this; \ } @@ -455,8 +456,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, qubit0, qubit1, param1, param2); \ + ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ + OP_NAME(param1, param2, targets[0], targets[1]); \ }); \ return *this; \ } @@ -478,16 +479,19 @@ QCProgramBuilder& QCProgramBuilder::barrier(ValueRange qubits) { // Modifiers //===----------------------------------------------------------------------===// -QCProgramBuilder& QCProgramBuilder::ctrl(ValueRange controls, - const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::ctrl(ValueRange controls, ValueRange targets, + const function_ref& body) { checkFinalized(); - CtrlOp::create(*this, controls, body); + CtrlOp::create(*this, controls, targets, body); return *this; } -QCProgramBuilder& QCProgramBuilder::inv(const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::inv(ValueRange qubits, + const function_ref& body) { checkFinalized(); - InvOp::create(*this, body); + InvOp::create(*this, qubits, body); return *this; } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index d457fa9a35..cf943b1f99 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -9,10 +9,12 @@ */ #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include #include +#include #include #include #include @@ -33,22 +35,47 @@ struct MergeNestedCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); - auto bodyCtrlOp = dyn_cast(bodyUnitary); - if (!bodyCtrlOp) { + // Require at least one control + // Trivial case is handled by ReduceCtrl + if (op.getNumControls() == 0) { return failure(); } - // add the inner controls as operands to the outer one - op->insertOperands(op.getNumOperands(), bodyCtrlOp.getControls()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } - // Move the inner unitary op into the outer one's body region and replace - // the outer one with the inner one's results - const OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(bodyUnitary); - auto* innerUnitaryOp = bodyCtrlOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerUnitaryOp, bodyUnitary); - rewriter.replaceOp(bodyUnitary, innerUnitaryOp->getResults()); + auto outerControls = op.getControls(); + auto outerTargets = op.getTargets(); + auto innerTargets = innerCtrlOp.getTargets(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, + targets, targetArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); return success(); } @@ -63,16 +90,30 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - rewriter.replaceOp(op, bodyUnitary->getResults()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargets(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); + rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -84,16 +125,23 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } - // Remove the last control and replace with a single POp with the removed - // control as target - auto controls = op.getControls(); - auto target = controls.back(); - controls = controls.drop_back(); - op->setOperands(controls); + // Adjust the segment sizes of the control and target operands + const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); + auto segmentsAttr = + op->getAttrOfType(opSegmentsAttrName); + auto newSegments = DenseI32ArrayAttr::get( + rewriter.getContext(), {segmentsAttr[0] - 1, segmentsAttr[1] + 1}); + op->setAttr(opSegmentsAttrName, newSegments); + + // Add a block argument for the target qubit + auto arg = op.getBody()->addArgument(QubitType::get(rewriter.getContext()), + op.getLoc()); + // Replace the current GPhaseOp with a PhaseOp const OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(gPhaseOp); - rewriter.replaceOpWithNewOp(gPhaseOp, target, gPhaseOp.getTheta()); + POp::create(rewriter, gPhaseOp.getLoc(), arg, gPhaseOp.getTheta()); + rewriter.eraseOp(gPhaseOp); return success(); } @@ -101,13 +149,27 @@ struct ReduceCtrl final : OpRewritePattern { } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value CtrlOp::getQubit(const size_t i) { @@ -116,9 +178,9 @@ Value CtrlOp::getQubit(const size_t i) { return getControls()[i]; } if (numControls <= i && i < getNumQubits()) { - return getBodyUnitary().getQubit(i - numControls); + return getTarget(i - numControls); } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getControl(const size_t i) { @@ -129,15 +191,19 @@ Value CtrlOp::getControl(const size_t i) { } void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, - ValueRange controls, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - odsState.addOperands(controls); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange controls, ValueRange targets, + const function_ref& body) { + build(odsBuilder, odsState, controls, targets); + auto& block = odsState.regions.front()->emplaceBlock(); + + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < targets.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } @@ -150,16 +216,6 @@ LogicalResult CtrlOp::verify() { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubits; for (const auto& control : getControls()) { @@ -167,11 +223,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } - auto bodyUnitary = getBodyUnitary(); - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubits.insert(bodyUnitary.getQubit(i)).second) { - return emitOpError("duplicate qubit found"); + for (const auto& target : getTargets()) { + if (!uniqueQubits.insert(target).second) { + return emitOpError("duplicate target qubit found"); } } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 065fe431be..b935e5c823 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -9,11 +9,13 @@ */ #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include #include #include +#include #include #include #include @@ -33,20 +35,36 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } - auto controls = innerCtrlOp.getControls(); - rewriter.replaceOpWithNewOp(invOp, controls, [&] { - InvOp::create(rewriter, invOp.getLoc(), [&] { - rewriter.clone(*innerCtrlOp.getBodyUnitary().getOperation()); - }); - }); + const auto numControls = innerCtrlOp.getNumControls(); + const auto numTargets = innerCtrlOp.getNumTargets(); + auto outerQubits = op.getQubits(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + InvOp::create( + rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargets(), + outerQubits, targets, qubitArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); + }); return success(); } @@ -62,13 +80,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -85,143 +114,181 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(t.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(tdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(s.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sx.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sxdg.getTarget(0), outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(p.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta, - r.getPhi()); + auto negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(r.getTarget(0), outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(u.getTarget(0), outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { - auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newPhi, - newLambda); + .Case([&](auto u2) { + Value pi = arith::ConstantOp::create( + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getTarget(0), outerQubits), + newPhi, newLambda); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(1), - op.getTarget(0)); + .Case([&](auto dcx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(dcx.getTarget(1), outerQubits), + utils::getValueFromBlockArgument(dcx.getTarget(0), outerQubits)); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rxx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(ryy.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzz.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxminusyy.getBeta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxminusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -233,41 +300,79 @@ struct ReplaceWithKnownGates final : OpRewritePattern { */ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto innerUnitary = invOp.getBodyUnitary(); - auto innerInvOp = dyn_cast(innerUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, invOp); - rewriter.replaceOp(invOp, innerInnerUnitary->getResults()); + // TODO: Relax this condition? + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + auto innerQubits = innerInvOp.getQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Invalid unitary index"); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange qubits, + const function_ref& body) { + build(odsBuilder, odsState, qubits); + auto& block = odsState.regions.front()->emplaceBlock(); + + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < qubits.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } @@ -280,16 +385,6 @@ LogicalResult InvOp::verify() { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } return success(); } diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index 5b93c2ebaa..bf6551f924 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QC/IR/QCDialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep @@ -21,6 +22,17 @@ using namespace mlir; using namespace mlir::qc; +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + return utils::parseTargetAliasing(parser, region, operands); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); +} + //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index ac70cd4269..66d562ae82 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -452,10 +452,14 @@ static void addISWAPdgOp(QCProgramBuilder& builder, auto target0 = qubits[operation.getTargets()[0]]; auto target1 = qubits[operation.getTargets()[1]]; if (const auto& controls = getControls(operation, qubits); controls.empty()) { - builder.inv([&] { builder.iswap(target0, target1); }); + builder.inv({target0, target1}, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); } else { - builder.ctrl(controls, [&] { - builder.inv([&] { builder.iswap(target0, target1); }); + builder.ctrl(controls, {target0, target1}, [&](ValueRange targets) { + builder.inv(targets, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); }); } } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 25fc88d084..e86f3f7dc3 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -42,38 +43,52 @@ struct MergeNestedCtrl final : OpRewritePattern { LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // Require at least one positive control + // Require at least one control // Trivial case is handled by ReduceCtrl - const auto numOuterControls = op.getNumControls(); - if (numOuterControls == 0) { + if (op.getNumControls() == 0) { return failure(); } - auto bodyCtrlOp = dyn_cast(op.getBodyUnitary().getOperation()); - if (!bodyCtrlOp) { + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { return failure(); } - const auto numInnerControls = bodyCtrlOp.getNumControls(); - auto outerControls = op.getControlsIn(); + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } + auto outerTargets = op.getTargetsIn(); - auto newAdditionalControls = outerTargets.take_front(numInnerControls); - auto newTargets = outerTargets.drop_front(numInnerControls); - auto newControls = llvm::to_vector( - llvm::concat(outerControls, newAdditionalControls)); + auto outerControls = op.getControlsIn(); + auto innerTargets = innerCtrlOp.getTargetsIn(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } rewriter.replaceOpWithNewOp( - op, newControls, newTargets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = bodyCtrlOp.getBody(); - for (size_t i = 0; i < bodyCtrlOp.getNumTargets(); ++i) { - mapping.map(innerBody->getArgument(i), newTargetArgs[i]); + utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, + targets, targetArgs); + SmallVector yields; + for (auto& op : innerCtrlBody->without_terminator()) { + auto results = rewriter.clone(op, mapping)->getResults(); + llvm::append_range(yields, results); } - - return rewriter - .clone(*bodyCtrlOp.getBodyUnitary().getOperation(), mapping) - ->getResults(); + return yields; }); + return success(); } }; @@ -87,20 +102,32 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - bodyUnitary->setOperands(0, op.getNumTargets(), op.getTargetsIn()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargetsIn(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); rewriter.replaceAllUsesWith(op.getControlsOut(), op.getControlsIn()); - rewriter.replaceAllUsesWith(op.getTargetsOut(), - bodyUnitary->getResults()); + rewriter.replaceAllUsesWith(op.getTargetsOut(), innerOp->getResults()); rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -136,7 +163,7 @@ struct ReduceCtrl final : OpRewritePattern { auto yieldOp = cast(op.getBody()->back()); yieldOp->setOperands(pOp->getResults()); - // erase the GPhaseOp + // Erase the GPhaseOp rewriter.eraseOp(gPhaseOp); return success(); @@ -145,13 +172,27 @@ struct ReduceCtrl final : OpRewritePattern { } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value CtrlOp::getInputQubit(const size_t i) { @@ -162,7 +203,7 @@ Value CtrlOp::getInputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsIn()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getOutputQubit(const size_t i) { @@ -173,7 +214,7 @@ Value CtrlOp::getOutputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsOut()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getInputTarget(const size_t i) { @@ -238,7 +279,7 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < targets.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -275,18 +316,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubitsIn; + SmallPtrSet uniqueTargetsIn; for (const auto& control : getControlsIn()) { if (!uniqueQubitsIn.insert(control).second) { return emitOpError("duplicate control qubit found"); @@ -296,29 +328,20 @@ LogicalResult CtrlOp::verify() { if (!uniqueQubitsIn.insert(target).second) { return emitOpError("duplicate target qubit found"); } - } - - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; + if (!uniqueTargetsIn.insert(target).second) { + return emitOpError("duplicate target qubit found"); } } - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; - } - } + // TODO: Re-enable + // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { + // auto bodyUnitary = getBodyUnitary(i); + // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { + // if (!uniqueTargetsIn.contains(bodyUnitary.getInputQubit(j))) { + // return emitOpError("unitary is using an unknown input qubit"); + // } + // } + // } SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { @@ -326,8 +349,8 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubitsOut.insert(bodyUnitary.getOutputQubit(i)).second) { + for (size_t i = 0; i < numTargets; i++) { + if (!uniqueQubitsOut.insert(block.back().getOperand(i)).second) { return emitOpError("duplicate qubit found"); } } @@ -341,11 +364,16 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional CtrlOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + // TODO: Relax this condition + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d82a64f819..1b6a98c07e 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -40,36 +41,41 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } const auto numControls = innerCtrlOp.getNumControls(); const auto numTargets = innerCtrlOp.getNumTargets(); - auto invTargets = invOp.getInputQubits(); - auto controls = invTargets.take_front(numControls); - auto targets = invTargets.take_back(numTargets); + auto outerQubits = op.getQubitsIn(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); rewriter.replaceOpWithNewOp( - invOp, controls, targets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { return InvOp::create( - rewriter, invOp.getLoc(), newTargetArgs, - [&](ValueRange invArgs) -> SmallVector { + rewriter, op.getLoc(), targetArgs, + [&](ValueRange qubitArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = innerCtrlOp.getBody(); - for (size_t i = 0; i < innerCtrlOp.getNumTargets(); - ++i) { - mapping.map(innerBody->getArgument(i), invArgs[i]); + utils::prova(*innerCtrlBody, mapping, + innerCtrlOp.getTargetsIn(), outerQubits, + targets, qubitArgs); + SmallVector yields; + for (auto& op : innerCtrlBody->without_terminator()) { + auto results = + rewriter.clone(op, mapping)->getResults(); + llvm::append_range(yields, results); } - auto* cloned = rewriter.clone( - *innerCtrlOp.getBodyUnitary().getOperation(), - mapping); - return cloned->getResults(); + return yields; }) .getResults(); }); @@ -88,14 +94,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, op.getNumQubits(), op.getInputQubits()); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -112,138 +128,192 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getInputQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(t.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(tdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(s.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sx.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sxdg.getInputTarget(0), + outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(p.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta, - r.getPhi()); + Value negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(r.getInputTarget(0), + outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u.getInputTarget(0), + outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { + .Case([&](auto u2) { auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newPhi, - newLambda); + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getInputTarget(0), + outerQubits), + newPhi, newLambda); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rxx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(ryy.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzz.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); rewriter.replaceOpWithNewOp( - op, op.getInputTarget(0), op.getInputTarget(1), negTheta, - xxminusyy.getBeta()); + op, + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), - negTheta, xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -258,30 +328,61 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerUnitary = op.getBodyUnitary().getOperation(); - auto innerInvOp = dyn_cast(innerUnitary); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, op); - innerInnerUnitary->setOperands(0, op.getNumQubits(), op.getInputQubits()); - rewriter.replaceOp(op, innerInnerUnitary->getResults()); + // TODO: Relax this condition? + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + auto innerQubits = innerInvOp.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value InvOp::getInputQubit(const size_t i) { @@ -322,7 +423,7 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubits); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < qubits.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -359,38 +460,23 @@ LogicalResult InvOp::verify() { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; + SmallPtrSet uniqueQubitsIn; + for (const auto& target : getQubitsIn()) { + if (!uniqueQubitsIn.insert(target).second) { + return emitOpError("duplicate qubit found"); } } - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; - } - } + // TODO: Re-enable + // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { + // auto bodyUnitary = getBodyUnitary(i); + // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { + // if (!uniqueQubitsIn.contains(bodyUnitary.getInputQubit(j))) { + // return emitOpError("unitary is using an unknown qubit"); + // } + // } + // } return success(); } @@ -402,11 +488,16 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional InvOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + // TODO: Relax this condition + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index a3ce816081..f1cb23a849 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -37,57 +38,12 @@ using namespace mlir::qco; static ParseResult parseTargetAliasing(OpAsmParser& parser, Region& region, SmallVectorImpl& operands) { - // 1. Parse the opening parenthesis - if (parser.parseLParen()) { - return failure(); - } - - // Temporary storage for block arguments we are about to create - SmallVector blockArgs; - - // 2. Prepare to parse the list - if (failed(parser.parseOptionalRParen())) { - do { - OpAsmParser::Argument newArg; // The "new" variable name - OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable - - // Parse "%new" - if (parser.parseArgument(newArg)) { - return failure(); - } - - // Parse "=" - if (parser.parseEqual()) { - return failure(); - } - - // Parse "%old" - if (parser.parseOperand(oldOperand)) { - return failure(); - } - operands.push_back(oldOperand); - - // Hard-code QubitType since targets in qco.ctrl are always qubits. - // This avoids double-binding type($targets_in) in the assembly format - // while keeping the parser simple and the assembly format clean. - newArg.type = QubitType::get(parser.getBuilder().getContext()); - blockArgs.push_back(newArg); - - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen()) { - return failure(); - } - } - - // 4. Parse the Region - // We explicitly pass the blockArgs we just parsed so they become the entry - // block! - if (parser.parseRegion(region, blockArgs)) { - return failure(); - } + return utils::parseTargetAliasing(parser, region, operands); +} - return success(); +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); } ParseResult IfOp::parse(::mlir::OpAsmParser& parser, @@ -213,30 +169,6 @@ void IfOp::print(OpAsmPrinter& p) { p.printOptionalAttrDict((*this)->getAttrs()); } -static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, - Region& region, OperandRange targetsIn) { - printer << "("; - if (region.empty()) { - printer << ") "; - printer.printRegion(region, false); - return; - } - Block& entryBlock = region.front(); - - const auto numTargets = targetsIn.size(); - for (unsigned i = 0; i < numTargets; ++i) { - if (i > 0) { - printer << ", "; - } - printer.printOperand(entryBlock.getArgument(i)); - printer << " = "; - printer.printOperand(targetsIn[i]); - } - printer << ") "; - - printer.printRegion(region, false); -} - //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp index 0ca22726a1..3d874533b6 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp @@ -162,7 +162,8 @@ struct LiftHadamardAboveCNOTPattern final : OpRewritePattern { if (!cnotGate) { return failure(); } - if (!isa(cnotGate.getBodyUnitary()) || + if (cnotGate.getNumBodyUnitaries() != 1 || + !isa(cnotGate.getBodyUnitary(0)) || cnotGate.getOutputTarget(0) != inQubitHadamard) { return failure(); } diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 373452252a..234c70cc9a 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -62,7 +62,7 @@ void staticQubitsWithCtrl(QCProgramBuilder& b) { void staticQubitsWithInv(QCProgramBuilder& b) { auto q0 = b.staticQubit(0); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsWithDuplicates(QCProgramBuilder& b) { @@ -75,7 +75,7 @@ void staticQubitsWithDuplicates(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1a); b.rzz(0.123, q0b, q1b); b.cx(q0b, q1b); - b.inv([&]() { b.t(q0a); }); + b.inv({q0a}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsCanonical(QCProgramBuilder& b) { @@ -86,7 +86,7 @@ void staticQubitsCanonical(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1); b.rzz(0.123, q0, q1); b.cx(q0, q1); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void allocDeallocPair(QCProgramBuilder& b) { @@ -194,7 +194,8 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b) { void nestedControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); + b.ctrl(q[0], {q[1]}, + [&](ValueRange targets) { b.cgphase(0.123, targets[0]); }); } void trivialControlledGlobalPhase(QCProgramBuilder& b) { @@ -203,12 +204,13 @@ void trivialControlledGlobalPhase(QCProgramBuilder& b) { } void inverseGlobalPhase(QCProgramBuilder& b) { - b.inv([&]() { b.gphase(-0.123); }); + b.inv({}, [&](ValueRange qubits) { b.gphase(-0.123); }); } void inverseMultipleControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcgphase(-0.123, {q[0], q[1], q[2]}); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcgphase(-0.123, qubits); }); } void identity(QCProgramBuilder& b) { @@ -228,7 +230,8 @@ void multipleControlledIdentity(QCProgramBuilder& b) { void nestedControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); + b.ctrl(q[2], {q[0], q[1]}, + [&](ValueRange targets) { b.cid(targets[1], targets[0]); }); } void trivialControlledIdentity(QCProgramBuilder& b) { @@ -238,12 +241,13 @@ void trivialControlledIdentity(QCProgramBuilder& b) { void inverseIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.id(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.id(qubits[0]); }); } void inverseMultipleControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcid({q[2], q[1]}, q[0]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcid({qubits[0], qubits[1]}, qubits[2]); }); } void x(QCProgramBuilder& b) { @@ -263,7 +267,8 @@ void multipleControlledX(QCProgramBuilder& b) { void nestedControlledX(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cx(targets[0], targets[1]); }); } void trivialControlledX(QCProgramBuilder& b) { @@ -282,12 +287,13 @@ void repeatedControlledX(QCProgramBuilder& b) { void inverseX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.x(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.x(qubits[0]); }); } void inverseMultipleControlledX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcx({qubits[0], qubits[1]}, qubits[2]); }); } void y(QCProgramBuilder& b) { @@ -307,7 +313,8 @@ void multipleControlledY(QCProgramBuilder& b) { void nestedControlledY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cy(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cy(targets[0], targets[1]); }); } void trivialControlledY(QCProgramBuilder& b) { @@ -317,12 +324,13 @@ void trivialControlledY(QCProgramBuilder& b) { void inverseY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.y(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.y(qubits[0]); }); } void inverseMultipleControlledY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcy({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcy({qubits[0], qubits[1]}, qubits[2]); }); } void z(QCProgramBuilder& b) { @@ -342,7 +350,8 @@ void multipleControlledZ(QCProgramBuilder& b) { void nestedControlledZ(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cz(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cz(targets[0], targets[1]); }); } void trivialControlledZ(QCProgramBuilder& b) { @@ -352,12 +361,13 @@ void trivialControlledZ(QCProgramBuilder& b) { void inverseZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.z(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.z(qubits[0]); }); } void inverseMultipleControlledZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcz({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcz({qubits[0], qubits[1]}, qubits[2]); }); } void h(QCProgramBuilder& b) { @@ -377,7 +387,8 @@ void multipleControlledH(QCProgramBuilder& b) { void nestedControlledH(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ch(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ch(targets[0], targets[1]); }); } void trivialControlledH(QCProgramBuilder& b) { @@ -387,12 +398,13 @@ void trivialControlledH(QCProgramBuilder& b) { void inverseH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.h(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.h(qubits[0]); }); } void inverseMultipleControlledH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mch({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mch({qubits[0], qubits[1]}, qubits[2]); }); } void hWithoutRegister(QCProgramBuilder& b) { @@ -417,7 +429,8 @@ void multipleControlledS(QCProgramBuilder& b) { void nestedControlledS(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cs(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cs(targets[0], targets[1]); }); } void trivialControlledS(QCProgramBuilder& b) { @@ -427,12 +440,13 @@ void trivialControlledS(QCProgramBuilder& b) { void inverseS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.s(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.s(qubits[0]); }); } void inverseMultipleControlledS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcs({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcs({qubits[0], qubits[1]}, qubits[2]); }); } void sdg(QCProgramBuilder& b) { @@ -452,7 +466,8 @@ void multipleControlledSdg(QCProgramBuilder& b) { void nestedControlledSdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csdg(targets[0], targets[1]); }); } void trivialControlledSdg(QCProgramBuilder& b) { @@ -462,12 +477,13 @@ void trivialControlledSdg(QCProgramBuilder& b) { void inverseSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sdg(qubits[0]); }); } void inverseMultipleControlledSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsdg({qubits[0], qubits[1]}, qubits[2]); }); } void t_(QCProgramBuilder& b) { @@ -487,7 +503,8 @@ void multipleControlledT(QCProgramBuilder& b) { void nestedControlledT(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ct(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ct(targets[0], targets[1]); }); } void trivialControlledT(QCProgramBuilder& b) { @@ -497,12 +514,13 @@ void trivialControlledT(QCProgramBuilder& b) { void inverseT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.t(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.t(qubits[0]); }); } void inverseMultipleControlledT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mct({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mct({qubits[0], qubits[1]}, qubits[2]); }); } void tdg(QCProgramBuilder& b) { @@ -522,7 +540,8 @@ void multipleControlledTdg(QCProgramBuilder& b) { void nestedControlledTdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ctdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ctdg(targets[0], targets[1]); }); } void trivialControlledTdg(QCProgramBuilder& b) { @@ -532,12 +551,13 @@ void trivialControlledTdg(QCProgramBuilder& b) { void inverseTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.tdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.tdg(qubits[0]); }); } void inverseMultipleControlledTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mctdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mctdg({qubits[0], qubits[1]}, qubits[2]); }); } void sx(QCProgramBuilder& b) { @@ -557,7 +577,8 @@ void multipleControlledSx(QCProgramBuilder& b) { void nestedControlledSx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csx(targets[0], targets[1]); }); } void trivialControlledSx(QCProgramBuilder& b) { @@ -567,12 +588,13 @@ void trivialControlledSx(QCProgramBuilder& b) { void inverseSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sx(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sx(qubits[0]); }); } void inverseMultipleControlledSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsx({qubits[0], qubits[1]}, qubits[2]); }); } void sxdg(QCProgramBuilder& b) { @@ -592,7 +614,8 @@ void multipleControlledSxdg(QCProgramBuilder& b) { void nestedControlledSxdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csxdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csxdg(targets[0], targets[1]); }); } void trivialControlledSxdg(QCProgramBuilder& b) { @@ -602,12 +625,14 @@ void trivialControlledSxdg(QCProgramBuilder& b) { void inverseSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sxdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sxdg(qubits[0]); }); } void inverseMultipleControlledSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsxdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcsxdg({qubits[0], qubits[1]}, qubits[2]); + }); } void rx(QCProgramBuilder& b) { @@ -627,7 +652,8 @@ void multipleControlledRx(QCProgramBuilder& b) { void nestedControlledRx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crx(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crx(0.123, targets[0], targets[1]); }); } void trivialControlledRx(QCProgramBuilder& b) { @@ -637,12 +663,14 @@ void trivialControlledRx(QCProgramBuilder& b) { void inverseRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rx(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rx(-0.123, qubits[0]); }); } void inverseMultipleControlledRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrx(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrx(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void ry(QCProgramBuilder& b) { @@ -662,7 +690,8 @@ void multipleControlledRy(QCProgramBuilder& b) { void nestedControlledRy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cry(0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cry(0.456, targets[0], targets[1]); }); } void trivialControlledRy(QCProgramBuilder& b) { @@ -672,12 +701,14 @@ void trivialControlledRy(QCProgramBuilder& b) { void inverseRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.ry(-0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.ry(-0.456, qubits[0]); }); } void inverseMultipleControlledRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcry(-0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcry(-0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void rz(QCProgramBuilder& b) { @@ -697,7 +728,8 @@ void multipleControlledRz(QCProgramBuilder& b) { void nestedControlledRz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crz(0.789, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crz(0.789, targets[0], targets[1]); }); } void trivialControlledRz(QCProgramBuilder& b) { @@ -707,12 +739,14 @@ void trivialControlledRz(QCProgramBuilder& b) { void inverseRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rz(-0.789, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rz(-0.789, qubits[0]); }); } void inverseMultipleControlledRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrz(-0.789, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrz(-0.789, {qubits[0], qubits[1]}, qubits[2]); + }); } void p(QCProgramBuilder& b) { @@ -732,7 +766,8 @@ void multipleControlledP(QCProgramBuilder& b) { void nestedControlledP(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cp(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cp(0.123, targets[0], targets[1]); }); } void trivialControlledP(QCProgramBuilder& b) { @@ -742,12 +777,14 @@ void trivialControlledP(QCProgramBuilder& b) { void inverseP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.p(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.p(-0.123, qubits[0]); }); } void inverseMultipleControlledP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcp(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcp(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void r(QCProgramBuilder& b) { @@ -767,7 +804,9 @@ void multipleControlledR(QCProgramBuilder& b) { void nestedControlledR(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cr(0.123, 0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cr(0.123, 0.456, targets[0], targets[1]); + }); } void trivialControlledR(QCProgramBuilder& b) { @@ -777,12 +816,14 @@ void trivialControlledR(QCProgramBuilder& b) { void inverseR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.r(-0.123, 0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.r(-0.123, 0.456, qubits[0]); }); } void inverseMultipleControlledR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcr(-0.123, 0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcr(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void u2(QCProgramBuilder& b) { @@ -802,7 +843,9 @@ void multipleControlledU2(QCProgramBuilder& b) { void nestedControlledU2(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu2(0.234, 0.567, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu2(0.234, 0.567, targets[0], targets[1]); + }); } void trivialControlledU2(QCProgramBuilder& b) { @@ -813,13 +856,16 @@ void trivialControlledU2(QCProgramBuilder& b) { void inverseU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u2(-0.567 + pi, -0.234 - pi, q[0]); }); + b.inv(q[0], + [&](ValueRange qubits) { b.u2(-0.567 + pi, -0.234 - pi, qubits[0]); }); } void inverseMultipleControlledU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu2(-0.567 + pi, -0.234 - pi, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu2(-0.567 + pi, -0.234 - pi, {qubits[0], qubits[1]}, qubits[2]); + }); } void u(QCProgramBuilder& b) { @@ -839,7 +885,9 @@ void multipleControlledU(QCProgramBuilder& b) { void nestedControlledU(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu(0.1, 0.2, 0.3, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu(0.1, 0.2, 0.3, targets[0], targets[1]); + }); } void trivialControlledU(QCProgramBuilder& b) { @@ -849,12 +897,14 @@ void trivialControlledU(QCProgramBuilder& b) { void inverseU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u(-0.1, -0.3, -0.2, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.u(-0.1, -0.3, -0.2, qubits[0]); }); } void inverseMultipleControlledU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu(-0.1, -0.3, -0.2, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu(-0.1, -0.3, -0.2, {qubits[0], qubits[1]}, qubits[2]); + }); } void swap(QCProgramBuilder& b) { @@ -874,7 +924,9 @@ void multipleControlledSwap(QCProgramBuilder& b) { void nestedControlledSwap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledSwap(QCProgramBuilder& b) { @@ -884,12 +936,14 @@ void trivialControlledSwap(QCProgramBuilder& b) { void inverseSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.swap(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.swap(qubits[0], qubits[1]); }); } void inverseMultipleControlledSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void iswap(QCProgramBuilder& b) { @@ -909,7 +963,9 @@ void multipleControlledIswap(QCProgramBuilder& b) { void nestedControlledIswap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.ciswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.ciswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledIswap(QCProgramBuilder& b) { @@ -919,12 +975,15 @@ void trivialControlledIswap(QCProgramBuilder& b) { void inverseIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.iswap(q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.iswap(qubits[0], qubits[1]); }); } void inverseMultipleControlledIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mciswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mciswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void dcx(QCProgramBuilder& b) { @@ -944,7 +1003,9 @@ void multipleControlledDcx(QCProgramBuilder& b) { void nestedControlledDcx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cdcx(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cdcx(targets[0], targets[1], targets[2]); + }); } void trivialControlledDcx(QCProgramBuilder& b) { @@ -954,12 +1015,14 @@ void trivialControlledDcx(QCProgramBuilder& b) { void inverseDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.dcx(q[1], q[0]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.dcx(qubits[1], qubits[0]); }); } void inverseMultipleControlledDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcdcx({q[0], q[1]}, q[3], q[2]); }); + b.inv({q[0], q[1], q[3], q[2]}, [&](ValueRange qubits) { + b.mcdcx({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void ecr(QCProgramBuilder& b) { @@ -979,7 +1042,9 @@ void multipleControlledEcr(QCProgramBuilder& b) { void nestedControlledEcr(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cecr(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cecr(targets[0], targets[1], targets[2]); + }); } void trivialControlledEcr(QCProgramBuilder& b) { @@ -989,12 +1054,14 @@ void trivialControlledEcr(QCProgramBuilder& b) { void inverseEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ecr(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.ecr(qubits[0], qubits[1]); }); } void inverseMultipleControlledEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcecr({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcecr({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rxx(QCProgramBuilder& b) { @@ -1014,7 +1081,9 @@ void multipleControlledRxx(QCProgramBuilder& b) { void nestedControlledRxx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crxx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crxx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRxx(QCProgramBuilder& b) { @@ -1024,18 +1093,22 @@ void trivialControlledRxx(QCProgramBuilder& b) { void inverseRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rxx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrxx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrxx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void tripleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); b.mcrxx(0.123, {q[0], q[1], q[2]}, q[3], q[4]); } + void fourControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); b.mcrxx(0.123, {q[0], q[1], q[2], q[3]}, q[4], q[5]); @@ -1058,7 +1131,9 @@ void multipleControlledRyy(QCProgramBuilder& b) { void nestedControlledRyy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cryy(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cryy(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRyy(QCProgramBuilder& b) { @@ -1068,12 +1143,15 @@ void trivialControlledRyy(QCProgramBuilder& b) { void inverseRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ryy(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.ryy(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcryy(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcryy(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzx(QCProgramBuilder& b) { @@ -1093,7 +1171,9 @@ void multipleControlledRzx(QCProgramBuilder& b) { void nestedControlledRzx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzx(QCProgramBuilder& b) { @@ -1103,12 +1183,15 @@ void trivialControlledRzx(QCProgramBuilder& b) { void inverseRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzz(QCProgramBuilder& b) { @@ -1128,7 +1211,9 @@ void multipleControlledRzz(QCProgramBuilder& b) { void nestedControlledRzz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzz(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzz(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzz(QCProgramBuilder& b) { @@ -1138,12 +1223,15 @@ void trivialControlledRzz(QCProgramBuilder& b) { void inverseRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzz(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzz(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzz(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzz(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxPlusYY(QCProgramBuilder& b) { @@ -1163,7 +1251,9 @@ void multipleControlledXxPlusYY(QCProgramBuilder& b) { void nestedControlledXxPlusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_plus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_plus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxPlusYY(QCProgramBuilder& b) { @@ -1173,12 +1263,16 @@ void trivialControlledXxPlusYY(QCProgramBuilder& b) { void inverseXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_plus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_plus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_plus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_plus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxMinusYY(QCProgramBuilder& b) { @@ -1198,7 +1292,9 @@ void multipleControlledXxMinusYY(QCProgramBuilder& b) { void nestedControlledXxMinusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_minus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_minus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxMinusYY(QCProgramBuilder& b) { @@ -1208,12 +1304,17 @@ void trivialControlledXxMinusYY(QCProgramBuilder& b) { void inverseXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_minus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_minus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_minus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_minus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], + qubits[3]); + }); } void barrier(QCProgramBuilder& b) { @@ -1233,59 +1334,91 @@ void barrierMultipleQubits(QCProgramBuilder& b) { void singleControlledBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl(q[1], [&] { b.barrier(q[0]); }); + b.ctrl(q[1], q[0], [&](ValueRange targets) { b.barrier(targets[0]); }); } void inverseBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.barrier(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.barrier(qubits[0]); }); } void trivialCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl({}, [&]() { b.rxx(0.123, q[0], q[1]); }); + b.ctrl({}, {q[0], q[1]}, + [&](ValueRange targets) { b.rxx(0.123, targets[0], targets[1]); }); } void nestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { b.ctrl(q[1], [&]() { b.rxx(0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void tripleNestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); - b.ctrl(q[0], [&]() { - b.ctrl(q[1], [&]() { b.ctrl(q[2], [&]() { b.rxx(0.123, q[3], q[4]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3], q[4]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.ctrl(innerTargets[0], {innerTargets[1], innerTargets[2]}, + [&](ValueRange innerInnerTargets) { + b.rxx(0.123, innerInnerTargets[0], innerInnerTargets[1]); + }); + }); }); } void doubleNestedCtrlTwoQubits(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); - b.ctrl({q[0], q[1]}, - [&]() { b.ctrl({q[2], q[3]}, [&]() { b.rxx(0.123, q[4], q[5]); }); }); + b.ctrl({q[0], q[1]}, {q[2], q[3], q[4], q[5]}, [&](ValueRange targets) { + b.ctrl({targets[0], targets[1]}, {targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void ctrlInvSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { - b.inv([&]() { b.ctrl(q[1], [&]() { b.rxx(-0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange innerTargets) { + b.rxx(-0.123, innerTargets[0], innerTargets[1]); + }); + }); }); } void nestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.inv([&]() { b.rxx(0.123, q[0], q[1]); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); } void tripleNestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv( - [&]() { b.inv([&]() { b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.inv(innerQubits, [&](ValueRange innerInnerQubits) { + b.rxx(-0.123, innerInnerQubits[0], innerInnerQubits[1]); + }); + }); + }); } void invCtrlSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { - b.ctrl(q[0], [&]() { b.inv([&]() { b.rxx(0.123, q[1], q[2]); }); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.inv({targets[0], targets[1]}, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); }); } @@ -1395,7 +1528,7 @@ void nestedForLoopCtrlOpWithSeparateQubit(QCProgramBuilder& b) { b.scfFor(0, 3, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(control, [&] { b.x(q0); }); + b.ctrl(control, q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } @@ -1405,7 +1538,7 @@ void nestedForLoopCtrlOpWithExtractedQubit(QCProgramBuilder& b) { b.scfFor(1, 4, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(reg[0], [&] { b.x(q0); }); + b.ctrl(reg[0], q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } From fcdfc1af0b03d587a41a3bddc7f4377a510b2c01 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 00:50:38 +0200 Subject: [PATCH 02/41] Fix equivalence checking --- mlir/lib/Support/IRVerification.cpp | 45 ++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index eaac426f0a..8c3b413083 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -10,6 +10,7 @@ #include "mlir/Support/IRVerification.h" +#include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include @@ -469,7 +470,6 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, if (!rhsConst) { return false; } - if (!areConstantAttributesEquivalent(lhsConst.getValue(), rhsConst.getValue())) { return false; @@ -513,17 +513,38 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, return false; } - // Check operands according to value mapping - for (auto [lhsOperand, rhsOperand] : - llvm::zip(lhs->getOperands(), rhs->getOperands())) { - if (auto it = valueMap.find(lhsOperand); it != valueMap.end()) { - // Value already mapped, must match - if (it->second != rhsOperand) { + ValueRange lhsOperands; + ValueRange rhsOperands; + // TODO: Extend this + if (auto lhsCtrl = dyn_cast(lhs)) { + auto rhsCtrl = dyn_cast(rhs); + if (!rhsCtrl) { + return false; + } + if (lhsCtrl.getTargets().size() != rhsCtrl.getTargets().size()) { + return false; + } + for (auto [lhsTarget, lhsArg] : + llvm::zip(lhsCtrl.getTargets(), lhsCtrl.getBody()->getArguments())) { + auto rhsTarget = valueMap[lhsTarget]; + if (!llvm::is_contained(rhsCtrl.getTargets(), rhsTarget)) { return false; } - } else { - // Establish new mapping - valueMap[lhsOperand] = rhsOperand; + auto it = llvm::find(rhsCtrl.getTargets(), rhsTarget); + auto index = std::distance(rhsCtrl.getTargets().begin(), it); + valueMap[lhsArg] = rhsCtrl.getBody()->getArgument(index); + } + lhsOperands = lhsCtrl.getControls(); + rhsOperands = rhsCtrl.getControls(); + } else { + lhsOperands = lhs->getOperands(); + rhsOperands = rhs->getOperands(); + } + + // Check operands according to value mapping + for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOperands, rhsOperands)) { + if (!areValuesEquivalent(lhsOperand, rhsOperand, valueMap)) { + return false; } } @@ -725,7 +746,9 @@ static bool areBlocksEquivalent(Block& lhs, Block& rhs, if (lhsArg.getType() != rhsArg.getType()) { return false; } - valueMap[lhsArg] = rhsArg; + if (!valueMap.contains(lhsArg)) { + valueMap[lhsArg] = rhsArg; + } } // Collect all operations From a69cc622013a78f6305494a329d3c7e79b1062b2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 15:47:00 +0200 Subject: [PATCH 03/41] Fix linter errors --- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 1 + mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 -- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 ++ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 3 +++ mlir/lib/Dialect/QC/IR/QCOps.cpp | 7 +++++++ mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 1 + mlir/lib/Support/IRVerification.cpp | 2 ++ mlir/unittests/programs/qc_programs.cpp | 2 +- 8 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 001844e1e6..09bb1b0949 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 9c32fc302d..33a2df9217 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1096,7 +1096,6 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcControls = op.getControls(); const auto qcTargets = op.getTargets(); auto qcoControls = resolveMappedQubits(state, operation, qcControls); @@ -1154,7 +1153,6 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcTargets = op.getTargets(); auto qcoTargets = resolveMappedQubits(state, operation, qcTargets); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index cf943b1f99..a894891192 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -8,12 +8,14 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b935e5c823..3a60bde33f 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -8,10 +8,12 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include @@ -21,6 +23,7 @@ #include #include +#include #include using namespace mlir; diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index bf6551f924..6a72833861 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -13,6 +13,13 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" // IWYU pragma: associated #include "mlir/Dialect/Utils/Utils.h" +#include +#include +#include +#include +#include +#include + // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep #include diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 1b6a98c07e..bb7a25925f 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 8c3b413083..07d723ec6b 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -34,11 +34,13 @@ #include #include #include +#include #include #include #include #include +#include #include using namespace mlir; diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 234c70cc9a..22646d0af6 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -204,7 +204,7 @@ void trivialControlledGlobalPhase(QCProgramBuilder& b) { } void inverseGlobalPhase(QCProgramBuilder& b) { - b.inv({}, [&](ValueRange qubits) { b.gphase(-0.123); }); + b.inv({}, [&](ValueRange /*qubits*/) { b.gphase(-0.123); }); } void inverseMultipleControlledGlobalPhase(QCProgramBuilder& b) { From 342234381b939de6fa5385e888404ebecc20ae85 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 17:42:29 +0200 Subject: [PATCH 04/41] Add patterns for removing empty modifiers --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 21 +++++++++++++++---- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 6 +++++- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 6 +++++- mlir/unittests/programs/qc_programs.cpp | 12 +++++++++++ mlir/unittests/programs/qc_programs.h | 6 ++++++ mlir/unittests/programs/qco_programs.cpp | 12 +++++++++++ mlir/unittests/programs/qco_programs.h | 6 ++++++ 10 files changed, 114 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index a894891192..b1147a6d81 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -149,6 +149,22 @@ struct ReduceCtrl final : OpRewritePattern { } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace size_t CtrlOp::getNumBodyUnitaries() { @@ -211,9 +227,6 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); @@ -236,5 +249,5 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 3a60bde33f..fa9fdbfca1 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -337,6 +337,22 @@ struct CancelNestedInv final : OpRewritePattern { } }; +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace size_t InvOp::getNumBodyUnitaries() { @@ -381,9 +397,6 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); @@ -394,5 +407,5 @@ LogicalResult InvOp::verify() { void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + ReplaceWithKnownGates, EraseEmptyInv>(context); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index e86f3f7dc3..01a9281a45 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -170,6 +170,22 @@ struct ReduceCtrl final : OpRewritePattern { } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + } // namespace size_t CtrlOp::getNumBodyUnitaries() { @@ -292,9 +308,6 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( @@ -360,7 +373,7 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional CtrlOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index bb7a25925f..9e7c4f5490 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -361,6 +361,22 @@ struct CancelNestedInv final : OpRewritePattern { } }; +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + } // namespace size_t InvOp::getNumBodyUnitaries() { @@ -437,9 +453,6 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( @@ -485,7 +498,7 @@ LogicalResult InvOp::verify() { void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + CancelNestedInv, EraseEmptyInv>(context); } std::optional InvOp::getUnitaryMatrix() { diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 97e4627363..f0c133086a 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -119,6 +119,8 @@ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCTest, testing::Values(QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), MQT_NAMED_BUILDER(multipleControlledRxx)}, QCTestCase{"TripleNestedCtrl", @@ -136,7 +138,9 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCInvOpTest, QCTest, - testing::Values(QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 413f29336d..0be0914f5e 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -223,6 +223,8 @@ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOTest, testing::Values(QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), MQT_NAMED_BUILDER(multipleControlledRxx)}, QCOTestCase{"TripleNestedCtrl", @@ -240,7 +242,9 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOInvOpTest, QCOTest, - testing::Values(QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCOTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 22646d0af6..5287be6150 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1348,6 +1348,12 @@ void trivialCtrl(QCProgramBuilder& b) { [&](ValueRange targets) { b.rxx(0.123, targets[0], targets[1]); }); } +void emptyCtrl(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.ctrl({q[0]}, {q[1]}, [&](ValueRange /*targets*/) {}); +} + void nestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { @@ -1391,6 +1397,12 @@ void ctrlInvSandwich(QCProgramBuilder& b) { }); } +void emptyInv(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange /*targets*/) {}); +} + void nestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.inv({q[0], q[1]}, [&](ValueRange qubits) { diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index e6569f7648..eeafac7cac 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -814,6 +814,9 @@ void inverseBarrier(QCProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCProgramBuilder& b); @@ -828,6 +831,9 @@ void ctrlInvSandwich(QCProgramBuilder& b); // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 0ad96fbb10..a985ecb11a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -1936,6 +1936,12 @@ void trivialCtrl(QCOProgramBuilder& b) { }); } +void emptyCtrl(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.ctrl(q[0], q[1], [&](ValueRange targets) { return targets; }); +} + void nestedCtrl(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl({q[0]}, {q[1], q[2], q[3]}, [&](ValueRange targets) { @@ -2003,6 +2009,12 @@ void ctrlInvSandwich(QCOProgramBuilder& b) { }); } +void emptyInv(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { return qubits; }); +} + void nestedInv(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.inv({q[0], q[1]}, [&](ValueRange qubits) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index b4197c5a7f..a8659701a8 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -960,6 +960,9 @@ void twoBarrier(QCOProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCOProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCOProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCOProgramBuilder& b); @@ -974,6 +977,9 @@ void ctrlInvSandwich(QCOProgramBuilder& b); // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCOProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCOProgramBuilder& b); From c4e28828f90f6848a73e351a60e64ded1eab9e84 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 18:11:08 +0200 Subject: [PATCH 05/41] Add test cases --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 8 +- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 10 ++- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 17 +++- .../Conversion/QCToQCO/test_qc_to_qco.cpp | 22 ++++- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 31 +++---- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 39 +++++---- mlir/unittests/programs/qc_programs.cpp | 46 ++++++++++ mlir/unittests/programs/qc_programs.h | 17 ++++ mlir/unittests/programs/qco_programs.cpp | 84 +++++++++++++++++++ mlir/unittests/programs/qco_programs.h | 26 +++++- 10 files changed, 256 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 01a9281a45..2dba84815d 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -81,10 +81,12 @@ struct MergeNestedCtrl final : OpRewritePattern { IRMapping mapping; utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, targets, targetArgs); - SmallVector yields; for (auto& op : innerCtrlBody->without_terminator()) { - auto results = rewriter.clone(op, mapping)->getResults(); - llvm::append_range(yields, results); + rewriter.clone(op, mapping); + } + SmallVector yields; + for (auto value : innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); } return yields; }); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 9e7c4f5490..872a8a8902 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -70,11 +70,13 @@ struct MoveCtrlOutside final : OpRewritePattern { utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargetsIn(), outerQubits, targets, qubitArgs); - SmallVector yields; for (auto& op : innerCtrlBody->without_terminator()) { - auto results = - rewriter.clone(op, mapping)->getResults(); - llvm::append_range(yields, results); + rewriter.clone(op, mapping); + } + SmallVector yields; + for (auto value : + innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); } return yields; }) diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 4bd2b24615..aa3a428809 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -144,6 +144,17 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair)})); /// @} +/// \name QCOToQC/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCOCtrlOpTest, QCOToQCTest, + testing::Values(QCOToQCTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qco::ctrlTwo), + MQT_NAMED_BUILDER(qc::ctrlTwo)}, + QCOToQCTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qco::ctrlInvTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); +/// @} + /// \name QCOToQC/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -160,7 +171,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::dcx)}, QCOToQCTestCase{"InverseMultipleControlledDCX", MQT_NAMED_BUILDER(qco::inverseMultipleControlledDcx), - MQT_NAMED_BUILDER(qc::multipleControlledDcx)})); + MQT_NAMED_BUILDER(qc::multipleControlledDcx)}, + QCOToQCTestCase{"InvTwo", MQT_NAMED_BUILDER(qco::invTwo), + MQT_NAMED_BUILDER(qc::invTwo)}, + QCOToQCTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qco::invCtrlTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); /// @} /// \name QCOToQC/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 3f0df25542..71f47b0841 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -143,6 +143,17 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qco::allocSinkPair)})); /// @} +/// \name QCToQCO/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCCtrlOpTest, QCToQCOTest, + testing::Values(QCToQCOTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qco::ctrlTwo)}, + QCToQCOTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qc::ctrlInvTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); +/// @} + /// \name QCToQCO/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -151,10 +162,13 @@ INSTANTIATE_TEST_SUITE_P( // iSWAP cannot be inverted with current canonicalization QCToQCOTestCase{"InverseiSWAP", MQT_NAMED_BUILDER(qc::inverseIswap), MQT_NAMED_BUILDER(qco::inverseIswap)}, - QCToQCOTestCase{ - "InverseMultipleControllediSWAP", - MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), - MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)})); + QCToQCOTestCase{"InverseMultipleControllediSWAP", + MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), + MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)}, + QCToQCOTestCase{"InvTwo", MQT_NAMED_BUILDER(qc::invTwo), + MQT_NAMED_BUILDER(qco::invTwo)}, + QCToQCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qc::invCtrlTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); /// @} /// \name QCToQCO/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index f0c133086a..4d0f56912b 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -117,21 +117,22 @@ TEST_F(QCTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { /// @{ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCTest, - testing::Values(QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QC/Modifiers/InvOp.cpp diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 0be0914f5e..4a8fb691ad 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -221,21 +221,22 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOTest, - testing::Values(QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCOTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCOTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCOTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QCO/Modifiers/InvOp.cpp @@ -251,7 +252,9 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"InvControlSandwich", MQT_NAMED_BUILDER(invCtrlSandwich), - MQT_NAMED_BUILDER(singleControlledRxx)})); + MQT_NAMED_BUILDER(singleControlledRxx)}, + QCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(invCtrlTwo), + MQT_NAMED_BUILDER(ctrlInvTwo)})); /// @} /// \name QCO/Operations/StandardGates/BarrierOp.cpp @@ -963,6 +966,10 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(inverseMultipleControlledX), MQT_NAMED_BUILDER(multipleControlledX)}, QCOTestCase{"TwoX", MQT_NAMED_BUILDER(twoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"ControlledTwoX", MQT_NAMED_BUILDER(controlledTwoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"inverseTwoX", MQT_NAMED_BUILDER(twoX), MQT_NAMED_BUILDER(emptyQCO)})); /// @} diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 5287be6150..5b10a98af4 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1397,6 +1397,34 @@ void ctrlInvSandwich(QCProgramBuilder& b) { }); } +void ctrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); + }); +} + +void nestedCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.x(innerTargets[0]); + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); +} + +void ctrlInvTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); + }); +} + void emptyInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.rxx(0.123, q[0], q[1]); @@ -1434,6 +1462,24 @@ void invCtrlSandwich(QCProgramBuilder& b) { }); } +void invTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); +} + +void invCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); + }); + }); +} + void simpleIf(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.h(q[0]); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index eeafac7cac..14114b1ee2 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -829,6 +829,16 @@ void doubleNestedCtrlTwoQubits(QCProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to a inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // /// Creates a circuit with an empty inverse modifier. @@ -843,6 +853,13 @@ void tripleNestedInv(QCProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCProgramBuilder& b); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index a985ecb11a..868e3a2a3b 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -301,6 +301,24 @@ void twoX(QCOProgramBuilder& b) { q[0] = b.x(q[0]); } +void controlledTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.ctrl(q[0], q[1], [&](ValueRange targets) { + auto q = b.x(targets[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + +void inverseTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) { + auto q = b.x(qubits[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + void y(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.y(q[0]); @@ -2009,6 +2027,46 @@ void ctrlInvSandwich(QCOProgramBuilder& b) { }); } +void ctrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void nestedCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + const auto& [controlsOut, targetsOut] = b.ctrl( + targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + auto i0 = innerTargets[0]; + auto i1 = innerTargets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + +void ctrlInvTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + auto inner = b.inv(targets, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(inner); + }); +} + void emptyInv(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(2); std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); @@ -2058,6 +2116,32 @@ void invCtrlSandwich(QCOProgramBuilder& b) { }); } +void invTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void invCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + const auto& [controlsOut, targetsOut] = + b.ctrl({qubits[0]}, {qubits[1], qubits[2]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + void simpleIf(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); auto q0 = b.h(q[0]); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index a8659701a8..1ec606d103 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -167,9 +167,16 @@ void inverseX(QCOProgramBuilder& b); /// Creates a circuit with an inverse modifier applied to a controlled X gate. void inverseMultipleControlledX(QCOProgramBuilder& b); -/// Creates a circuit with two X gates in a row. +/// Creates a circuit with two subsequent X gates. void twoX(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two subsequent X gates. +void controlledTwoX(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to two subsequent X +/// gates. +void inverseTwoX(QCOProgramBuilder& b); + // --- YOp ------------------------------------------------------------------ // /// Creates a circuit with just a Y gate. @@ -975,6 +982,16 @@ void doubleNestedCtrlTwoQubits(QCOProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to an inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCOProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // /// Creates a circuit with an empty inverse modifier. @@ -989,6 +1006,13 @@ void tripleNestedInv(QCOProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCOProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCOProgramBuilder& b); + // --- IfOp ---------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. From 60249614384b009d57ce84da637369db1d88ab0f Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:22:32 +0200 Subject: [PATCH 06/41] Add support for translating CompoundOperations --- .../TranslateQuantumComputationToQC.cpp | 176 +++++++++++++----- .../test_quantum_computation_translation.cpp | 6 + mlir/unittests/programs/qc_programs.cpp | 14 +- mlir/unittests/programs/qc_programs.h | 6 +- .../programs/quantum_computation_programs.cpp | 25 +++ .../programs/quantum_computation_programs.h | 8 + 6 files changed, 178 insertions(+), 57 deletions(-) diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 66d562ae82..727679d223 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -73,6 +73,27 @@ struct TranslationState { /// Whether the translation is currently processing an IfElseOperation bool inIfElse = false; + + /// Whether the translation is currently within a control modifier + bool inCtrlOp = false; + + /// Mapping from physical qubit index to block argument + DenseMap ctrlTargets{}; + + Value getQubit(size_t index) const { + if (!inCtrlOp) { + if (index >= qubits.size()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return qubits[index]; + } else { + auto it = ctrlTargets.find(index); + if (it == ctrlTargets.end()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return it->second; + } + }; }; } // namespace @@ -222,7 +243,7 @@ static void addMeasureOp(QCProgramBuilder& builder, const auto& classics = measureOp.getClassics(); for (size_t i = 0; i < targets.size(); ++i) { - const auto& qubit = state.qubits[targets[i]]; + const auto& qubit = state.getQubit(targets[i]); const auto bitIdx = static_cast(classics[i]); const auto& [mem, localIdx] = state.bitMap[bitIdx]; const auto& bit = mem[static_cast(localIdx)]; @@ -239,13 +260,13 @@ static void addMeasureOp(QCProgramBuilder& builder, * * @param builder The QCProgramBuilder used to create operations * @param operation The reset operation to translate - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state */ static void addResetOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { for (const auto& target : operation.getTargets()) { - auto qubit = qubits[target]; + auto qubit = state.getQubit(target); builder.reset(qubit); } } @@ -258,18 +279,21 @@ static void addResetOp(QCProgramBuilder& builder, * the qubit values corresponding to positive controls. * * @param operation The operation containing controls - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state * @return Vector of qubit values corresponding to positive controls */ static SmallVector getControls(const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { + if (state.inCtrlOp) { + return {}; + } SmallVector controls; for (const auto& [control, type] : operation.getControls()) { if (type == ::qc::Control::Type::Neg) { llvm::reportFatalInternalError( "Negative controls cannot be translated to QC at the moment"); } - controls.push_back(qubits[control]); + controls.push_back(state.getQubit(control)); } return controls; } @@ -286,13 +310,13 @@ static SmallVector getControls(const ::qc::Operation& operation, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target); \ } else { \ @@ -326,14 +350,14 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdg, sxdg) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target); \ } else { \ @@ -358,15 +382,15 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(P, p) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target); \ } else { \ @@ -391,16 +415,16 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2, u2) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ const auto& param3 = operation.getParameter()[2]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, param3, target); \ } else { \ @@ -424,14 +448,14 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(U, u) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target0, target1); \ } else { \ @@ -448,10 +472,10 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECR, ecr) static void addISWAPdgOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { - auto target0 = qubits[operation.getTargets()[0]]; - auto target1 = qubits[operation.getTargets()[1]]; - if (const auto& controls = getControls(operation, qubits); controls.empty()) { + TranslationState& state) { + auto target0 = state.getQubit(operation.getTargets()[0]); + auto target1 = state.getQubit(operation.getTargets()[1]); + if (const auto& controls = getControls(operation, state); controls.empty()) { builder.inv({target0, target1}, [&](ValueRange qubits) { builder.iswap(qubits[0], qubits[1]); }); @@ -476,15 +500,15 @@ static void addISWAPdgOp(QCProgramBuilder& builder, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target0, target1); \ } else { \ @@ -511,16 +535,16 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZ, rzz) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target0, target1); \ } else { \ @@ -537,10 +561,10 @@ DEFINE_TWO_TARGET_TWO_PARAMETER(XXminusYY, xx_minus_yy) static void addBarrierOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { SmallVector targets; for (const auto& targetIdx : operation.getTargets()) { - targets.push_back(qubits[targetIdx]); + targets.push_back(state.getQubit(targetIdx)); } builder.barrier(targets); } @@ -550,6 +574,60 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state); +// CompoundOp + +static LogicalResult addCompoundOp(QCProgramBuilder& builder, + const ::qc::Operation& operation, + TranslationState& state) { + const auto& compoundOp = + dynamic_cast(operation); + if (const auto& controls = getControls(operation, state); controls.empty()) { + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + return failure(); + } + } + } else { + // Collect targets + DenseMap targetMap; + for (const auto& op : compoundOp) { + if (dynamic_cast(op.get()) != nullptr) { + llvm::reportFatalInternalError("Nested CompoundOperations cannot be " + "translated to QC at the moment"); + } + for (const auto& target : op->getTargets()) { + if (!targetMap.contains(target)) { + targetMap[target] = state.getQubit(target); + } + } + } + SmallVector> sortedPairs(targetMap.begin(), + targetMap.end()); + std::sort(sortedPairs.begin(), sortedPairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + SmallVector targets; + for (const auto& pair : sortedPairs) { + targets.push_back(pair.second); + } + // Build control modifier + builder.ctrl(controls, targets, [&](ValueRange targetArgs) { + state.inCtrlOp = true; + for (size_t i = 0; i < sortedPairs.size(); ++i) { + state.ctrlTargets[sortedPairs[i].first] = targetArgs[i]; + } + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + llvm::reportFatalInternalError("Failed to translate operation inside " + "controlled CompoundOperation"); + } + } + state.ctrlTargets.clear(); + state.inCtrlOp = false; + }); + } + return success(); +} + // IfElseOp static LogicalResult addIfElseOp(QCProgramBuilder& builder, @@ -626,7 +704,7 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, #define ADD_OP_CASE(OP_CORE) \ case ::qc::OpType::OP_CORE: \ - add##OP_CORE##Op(builder, operation, qubits); \ + add##OP_CORE##Op(builder, operation, state); \ return success(); /** @@ -640,7 +718,6 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state) { - const auto& qubits = state.qubits; switch (operation.getType()) { case ::qc::OpType::Measure: addMeasureOp(builder, operation, state); @@ -676,7 +753,12 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, ADD_OP_CASE(XXminusYY) ADD_OP_CASE(Barrier) case ::qc::OpType::iSWAPdg: - addISWAPdgOp(builder, operation, qubits); + addISWAPdgOp(builder, operation, state); + return success(); + case ::qc::OpType::Compound: + if (failed(addCompoundOp(builder, operation, state))) { + return failure(); + } return success(); case ::qc::OpType::IfElse: if (failed(addIfElseOp(builder, operation, state))) { diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index 0e5d53783b..893902448b 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -418,9 +418,15 @@ INSTANTIATE_TEST_SUITE_P( "BarrierMultipleQubits", MQT_NAMED_BUILDER(qc::barrierMultipleQubits), MQT_NAMED_BUILDER(mlir::qc::barrierMultipleQubits)}, + QuantumComputationTranslationTestCase{ + "CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwo)}, QuantumComputationTranslationTestCase{ "SimpleIf", MQT_NAMED_BUILDER(qc::simpleIf), MQT_NAMED_BUILDER(mlir::qc::simpleIf)}, + QuantumComputationTranslationTestCase{ + "IfTwoQubits", MQT_NAMED_BUILDER(qc::ifTwoQubits), + MQT_NAMED_BUILDER(mlir::qc::ifTwoQubits)}, QuantumComputationTranslationTestCase{ "IfElse", MQT_NAMED_BUILDER(qc::ifElse), MQT_NAMED_BUILDER(mlir::qc::ifElse)})); diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 5b10a98af4..7c7608963d 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1487,13 +1487,6 @@ void simpleIf(QCProgramBuilder& b) { b.scfIf(cond, [&] { b.x(q[0]); }); } -void ifElse(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(1); - b.h(q[0]); - auto cond = b.measure(q[0]); - b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); -} - void ifTwoQubits(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.h(q[0]); @@ -1504,6 +1497,13 @@ void ifTwoQubits(QCProgramBuilder& b) { }); } +void ifElse(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.h(q[0]); + auto cond = b.measure(q[0]); + b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); +} + void nestedIfOpForLoop(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); auto q0 = b.allocQubit(); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 14114b1ee2..2f08c5236e 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -865,12 +865,12 @@ void invCtrlTwo(QCProgramBuilder& b); /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QCProgramBuilder& b); -/// Creates a circuit with an if operation with an else branch. -void ifElse(QCProgramBuilder& b); - /// Creates a circuit with an if operation with two qubits. void ifTwoQubits(QCProgramBuilder& b); +/// Creates a circuit with an if operation with an else branch. +void ifElse(QCProgramBuilder& b); + /// Creates a circuit with an if operation with a nested for operation with /// a register. void nestedIfOpForLoop(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index 719fd50b17..db12525029 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -15,6 +15,7 @@ #include "ir/operations/StandardOperation.hpp" #include +#include namespace qc { @@ -538,6 +539,17 @@ void barrierMultipleQubits(QuantumComputation& comp) { comp.barrier({0, 1, 2}); } +void ctrlTwo(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + void simpleIf(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); @@ -546,6 +558,19 @@ void simpleIf(QuantumComputation& comp) { comp.if_(X, q[0], c[0]); } +void ifTwoQubits(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(2, "q"); + const auto& c = comp.addClassicalRegister(1, "c"); + comp.h(q[0]); + comp.measure(q[0], c[0]); + CompoundOperation compound; + compound.emplace_back(0, X); + compound.emplace_back(1, X); + IfElseOperation ifElse( + std::make_unique(std::move(compound)), nullptr, c[0]); + comp.emplace_back(std::move(ifElse)); +} + void ifElse(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); diff --git a/mlir/unittests/programs/quantum_computation_programs.h b/mlir/unittests/programs/quantum_computation_programs.h index b21bba30d5..f0e1856d8f 100644 --- a/mlir/unittests/programs/quantum_computation_programs.h +++ b/mlir/unittests/programs/quantum_computation_programs.h @@ -385,11 +385,19 @@ void barrierTwoQubits(QuantumComputation& comp); /// Creates a circuit with a barrier on multiple qubits. void barrierMultipleQubits(QuantumComputation& comp); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QuantumComputation& comp); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QuantumComputation& comp); +/// Creates a circuit with an if operation with two qubits. +void ifTwoQubits(QuantumComputation& comp); + /// Creates a circuit with an if operation with an else branch. void ifElse(QuantumComputation& comp); From ccec22f71905f9612b58fae2da973c1f5a92ffe0 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:57:11 +0200 Subject: [PATCH 07/41] Fix QC-to-QIR conversion --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 24 ++++++++++++------- .../Compiler/test_compiler_pipeline.cpp | 5 +++- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 8 +++++++ mlir/unittests/programs/qir_programs.cpp | 6 +++++ mlir/unittests/programs/qir_programs.h | 5 ++++ 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 6c432f57d2..832820ca65 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -93,7 +93,7 @@ struct LoweringState : QIRMetadata { /// Modifier information int64_t inCtrlOp = 0; - DenseMap> controls; + SmallVector controls; /// Allocator and StringSaver for stable StringRefs llvm::BumpPtrAllocator allocator; @@ -174,7 +174,7 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, // Query state for modifier information const auto inCtrlOp = state.inCtrlOp; const SmallVector controls = - inCtrlOp != 0 ? state.controls[inCtrlOp] : SmallVector{}; + inCtrlOp != 0 ? state.controls : SmallVector{}; const size_t numCtrls = controls.size(); // Define argument types @@ -209,9 +209,9 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); // Clean up modifier information - if (inCtrlOp != 0) { - state.controls.erase(inCtrlOp); - state.inCtrlOp--; + state.inCtrlOp--; + if (inCtrlOp == 0) { + state.controls.clear(); } // Replace operation with CallOp @@ -315,7 +315,7 @@ struct ConvertQCUnitaryOpQIR : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); const auto inCtrlOp = state.inCtrlOp; - const size_t numCtrls = inCtrlOp != 0 ? state.controls[inCtrlOp].size() : 0; + const size_t numCtrls = inCtrlOp != 0 ? state.controls.size() : 0; const auto fnName = GetFnName(numCtrls); return convertUnitaryToCallOp(op, adaptor, rewriter, this->getContext(), state, fnName, NumTargets, NumParams); @@ -863,12 +863,18 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Update modifier information auto& state = getState(); - state.inCtrlOp++; + + if (state.inCtrlOp != 0) { + return rewriter.notifyMatchFailure(op, + "Nested CtrlOps are not supported"); + } + + // Update modifier information + state.inCtrlOp = op.getNumBodyUnitaries(); const SmallVector controls(adaptor.getControls().begin(), adaptor.getControls().end()); - state.controls[state.inCtrlOp] = controls; + state.controls = controls; // Inline block and remove operation rewriter.inlineBlockBefore(&op.getRegion().front(), op, diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 44618eedae..d9072d6630 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -686,6 +686,9 @@ INSTANTIATE_TEST_SUITE_P( "MultipleControlledXXMinusYY", MQT_NAMED_BUILDER(qc::multipleControlledXxMinusYY), nullptr, MQT_NAMED_BUILDER(mlir::qc::multipleControlledXxMinusYY), - MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)})); + MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)}, + CompilerPipelineTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + nullptr, MQT_NAMED_BUILDER(mlir::qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qir::ctrlTwo)})); } // namespace mqt::test::compiler diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index cd8bcf6073..6de8bf8483 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -649,3 +649,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair), MQT_NAMED_BUILDER(qir::emptyQIR)})); /// @} + +/// \name QCToQIR/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P(QCToQIRCtrlOpTest, QCToQIRTest, + testing::Values(QCToQIRTestCase{ + "NestedCtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qir::ctrlTwo)})); +/// @} diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index 6ae5023a09..68f209943f 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -605,4 +605,10 @@ void multipleControlledXxMinusYY(QIRProgramBuilder& b) { b.mcxx_minus_yy(0.123, 0.456, {q[0], q[1]}, q[2], q[3]); } +void ctrlTwo(QIRProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.mcx({q[0], q[1]}, q[2]); + b.mcrxx(0.123, {q[0], q[1]}, q[2], q[3]); +} + } // namespace mlir::qir diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index 92f6c54078..86a7f7c807 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -422,4 +422,9 @@ void singleControlledXxMinusYY(QIRProgramBuilder& b); /// Creates a circuit with a multi-controlled XXMinusYY gate. void multipleControlledXxMinusYY(QIRProgramBuilder& b); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QIRProgramBuilder& b); + } // namespace mlir::qir From e5eb892d884a3ea0c556a718700fa275d538fde8 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:10:01 +0200 Subject: [PATCH 08/41] Resolve TODO comments --- mlir/include/mlir/Dialect/Utils/Utils.h | 20 ++++++++++++++------ mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 4 ++-- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 5 +++-- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 18 ++---------------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 16 +++------------- mlir/lib/Support/IRVerification.cpp | 1 - 6 files changed, 24 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 546ecc479c..072c2c2368 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -161,7 +161,10 @@ static void printTargetAliasing(OpAsmPrinter& printer, Region& region, printer.printRegion(region, false); } -// TODO: Document +/** + * @brief Get the value corresponding to @p qubit from the block arguments @p + * qubits if @p qubit is a block argument, otherwise return @p qubit itself. + */ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { if (auto blockArg = dyn_cast(qubit)) { return qubits[blockArg.getArgNumber()]; @@ -169,10 +172,15 @@ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { return qubit; } -// TODO: Rename and document -static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, - ValueRange outerQubits, ValueRange newQubits, - ValueRange qubitArgs) { +/** + * @brief Create a mapping between block arguments and qubit values. + * + * @details This helper function is used to resolve block arguments for nested + * modifiers. + */ +static void populateMapping(Block& block, IRMapping& mapping, + ValueRange innerQubits, ValueRange outerQubits, + ValueRange newQubits, ValueRange qubitArgs) { for (auto arg : block.getArguments()) { auto innerQubit = innerQubits[arg.getArgNumber()]; auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); @@ -180,7 +188,7 @@ static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, auto index = std::distance(newQubits.begin(), it); mapping.map(arg, qubitArgs[index]); } else { - llvm::reportFatalInternalError("TODO"); + llvm::reportFatalInternalError("Outer qubit not found in new qubits"); } } } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index b1147a6d81..d76aef7ade 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -72,8 +72,8 @@ struct MergeNestedCtrl final : OpRewritePattern { op, controls, targets, [&](ValueRange targetArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, - targets, targetArgs); + utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index fa9fdbfca1..fa1273d503 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -61,8 +61,9 @@ struct MoveCtrlOutside final : OpRewritePattern { rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargets(), - outerQubits, targets, qubitArgs); + utils::populateMapping(*innerCtrlBody, mapping, + innerCtrlOp.getTargets(), outerQubits, + targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 2dba84815d..69febe54e4 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -79,8 +79,8 @@ struct MergeNestedCtrl final : OpRewritePattern { [&](ValueRange targetArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, - targets, targetArgs); + utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } @@ -333,7 +333,6 @@ LogicalResult CtrlOp::verify() { } SmallPtrSet uniqueQubitsIn; - SmallPtrSet uniqueTargetsIn; for (const auto& control : getControlsIn()) { if (!uniqueQubitsIn.insert(control).second) { return emitOpError("duplicate control qubit found"); @@ -343,21 +342,8 @@ LogicalResult CtrlOp::verify() { if (!uniqueQubitsIn.insert(target).second) { return emitOpError("duplicate target qubit found"); } - if (!uniqueTargetsIn.insert(target).second) { - return emitOpError("duplicate target qubit found"); - } } - // TODO: Re-enable - // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { - // auto bodyUnitary = getBodyUnitary(i); - // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { - // if (!uniqueTargetsIn.contains(bodyUnitary.getInputQubit(j))) { - // return emitOpError("unitary is using an unknown input qubit"); - // } - // } - // } - SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { if (!uniqueQubitsOut.insert(control).second) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 872a8a8902..0c8b4bac98 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -67,9 +67,9 @@ struct MoveCtrlOutside final : OpRewritePattern { [&](ValueRange qubitArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, - innerCtrlOp.getTargetsIn(), outerQubits, - targets, qubitArgs); + utils::populateMapping(*innerCtrlBody, mapping, + innerCtrlOp.getTargetsIn(), + outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } @@ -484,16 +484,6 @@ LogicalResult InvOp::verify() { } } - // TODO: Re-enable - // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { - // auto bodyUnitary = getBodyUnitary(i); - // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { - // if (!uniqueQubitsIn.contains(bodyUnitary.getInputQubit(j))) { - // return emitOpError("unitary is using an unknown qubit"); - // } - // } - // } - return success(); } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 07d723ec6b..0221464606 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -517,7 +517,6 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, ValueRange lhsOperands; ValueRange rhsOperands; - // TODO: Extend this if (auto lhsCtrl = dyn_cast(lhs)) { auto rhsCtrl = dyn_cast(rhs); if (!rhsCtrl) { From fed9ed86725efc786395feed55936bbb1fcea524 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:26:22 +0200 Subject: [PATCH 09/41] Remove remaining TODO comments in preparation for the Rabbit --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 -- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 3 --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 3 --- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 4 ---- 4 files changed, 12 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index d76aef7ade..fad9a4b2e6 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -43,7 +43,6 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -92,7 +91,6 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index fa1273d503..b34424cceb 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -40,7 +40,6 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -306,7 +305,6 @@ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -315,7 +313,6 @@ struct CancelNestedInv final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (innerInvOp.getNumBodyUnitaries() != 1) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 69febe54e4..8994661d86 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -49,7 +49,6 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -104,7 +103,6 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -365,7 +363,6 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional CtrlOp::getUnitaryMatrix() { - // TODO: Relax this condition if (getNumBodyUnitaries() != 1) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 0c8b4bac98..bf9da8144e 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -44,7 +44,6 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -331,7 +330,6 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -340,7 +338,6 @@ struct CancelNestedInv final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (innerInvOp.getNumBodyUnitaries() != 1) { return failure(); } @@ -494,7 +491,6 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional InvOp::getUnitaryMatrix() { - // TODO: Relax this condition if (getNumBodyUnitaries() != 1) { return std::nullopt; } From d4c45e950ee74a9cd0dc79a6130ee21bf7e33477 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:31:46 +0200 Subject: [PATCH 10/41] Fix linter errors --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 2 +- .../TranslateQuantumComputationToQC.cpp | 28 +++++++++++-------- .../programs/quantum_computation_programs.cpp | 3 ++ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 832820ca65..3078d15b76 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -92,7 +92,7 @@ struct LoweringState : QIRMetadata { DenseMap resultPtrs; /// Modifier information - int64_t inCtrlOp = 0; + size_t inCtrlOp = 0; SmallVector controls; /// Allocator and StringSaver for stable StringRefs diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 727679d223..930817894d 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -12,6 +12,7 @@ #include "ir/QuantumComputation.hpp" #include "ir/Register.hpp" +#include "ir/operations/CompoundOperation.hpp" #include "ir/operations/Control.hpp" #include "ir/operations/IfElseOperation.hpp" #include "ir/operations/NonUnitaryOperation.hpp" @@ -19,6 +20,7 @@ #include "ir/operations/Operation.hpp" #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include #include #include #include @@ -78,21 +80,21 @@ struct TranslationState { bool inCtrlOp = false; /// Mapping from physical qubit index to block argument - DenseMap ctrlTargets{}; + DenseMap ctrlTargets; - Value getQubit(size_t index) const { - if (!inCtrlOp) { - if (index >= qubits.size()) { - llvm::reportFatalInternalError("Qubit index out of bounds"); - } - return qubits[index]; - } else { + [[nodiscard]] Value getQubit(size_t index) const { + if (inCtrlOp) { auto it = ctrlTargets.find(index); if (it == ctrlTargets.end()) { llvm::reportFatalInternalError("Qubit index out of bounds"); } return it->second; } + + if (index >= qubits.size()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return qubits[index]; }; }; @@ -603,8 +605,8 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, } SmallVector> sortedPairs(targetMap.begin(), targetMap.end()); - std::sort(sortedPairs.begin(), sortedPairs.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); + llvm::sort(sortedPairs.begin(), sortedPairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); SmallVector targets; for (const auto& pair : sortedPairs) { targets.push_back(pair.second); @@ -845,8 +847,10 @@ OwningOpRef translateQuantumComputationToQC( // Allocate result map SmallVector results(quantumComputation.getNcbits(), nullptr); - TranslationState state{ - .qubits = qubits, .bitMap = bitMap, .results = std::move(results)}; + TranslationState state{.qubits = qubits, + .bitMap = bitMap, + .results = std::move(results), + .ctrlTargets = DenseMap{}}; // Translate operations if (translateOperations(builder, quantumComputation, state).failed()) { diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index db12525029..418798c68d 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -11,10 +11,13 @@ #include "quantum_computation_programs.h" #include "ir/QuantumComputation.hpp" +#include "ir/operations/CompoundOperation.hpp" +#include "ir/operations/IfElseOperation.hpp" #include "ir/operations/OpType.hpp" #include "ir/operations/StandardOperation.hpp" #include +#include #include namespace qc { From 925aaa57e98b0f6946f15bf80f8bab8a92601932 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 23:58:41 +0200 Subject: [PATCH 11/41] Address the Rabbit's comments --- mlir/include/mlir/Dialect/Utils/Utils.h | 5 ++++- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 3 ++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 8 +++++--- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 +- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 2 +- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 2 +- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 2 +- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 2 +- 8 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 072c2c2368..a885c2da2b 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -16,6 +16,7 @@ #include #include +#include #include namespace mlir::utils { @@ -178,9 +179,11 @@ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { * @details This helper function is used to resolve block arguments for nested * modifiers. */ -static void populateMapping(Block& block, IRMapping& mapping, +static void populateMapping(IRMapping& mapping, Block& block, ValueRange innerQubits, ValueRange outerQubits, ValueRange newQubits, ValueRange qubitArgs) { + assert(innerQubits.size() == block.getNumArguments() && + "Size of innerQubits must match number of block arguments"); for (auto arg : block.getArguments()) { auto innerQubit = innerQubits[arg.getArgNumber()]; auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 77758bd3ba..8cc099610d 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -128,7 +128,8 @@ static void inlineRegion(Region& sourceRegion, Region& targetRegion, ConversionPatternRewriter& rewriter) { rewriter.inlineRegionBefore(sourceRegion, targetRegion, targetRegion.end()); auto& block = targetRegion.front(); - + assert(block.getNumArguments() == offset + replacementValues.size() && + "Number of replacement values must match number of block arguments"); for (auto [arg, replacementVal] : llvm::zip_equal( block.getArguments().drop_front(offset), replacementValues)) { arg.replaceAllUsesWith(replacementVal); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 3078d15b76..e7a6d2910f 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -209,9 +209,11 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); // Clean up modifier information - state.inCtrlOp--; - if (inCtrlOp == 0) { - state.controls.clear(); + if (inCtrlOp != 0) { + state.inCtrlOp--; + if (state.inCtrlOp == 0) { + state.controls.clear(); + } } // Replace operation with CallOp diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index fad9a4b2e6..82b5dce077 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -71,7 +71,7 @@ struct MergeNestedCtrl final : OpRewritePattern { op, controls, targets, [&](ValueRange targetArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b34424cceb..1417d6de75 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -60,7 +60,7 @@ struct MoveCtrlOutside final : OpRewritePattern { rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, + utils::populateMapping(mapping, *innerCtrlBody, innerCtrlOp.getTargets(), outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 8994661d86..444f554393 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -78,7 +78,7 @@ struct MergeNestedCtrl final : OpRewritePattern { [&](ValueRange targetArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index bf9da8144e..3456e0b11c 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -66,7 +66,7 @@ struct MoveCtrlOutside final : OpRewritePattern { [&](ValueRange qubitArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, + utils::populateMapping(mapping, *innerCtrlBody, innerCtrlOp.getTargetsIn(), outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 4a8fb691ad..883c7d32d2 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -969,7 +969,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(emptyQCO)}, QCOTestCase{"ControlledTwoX", MQT_NAMED_BUILDER(controlledTwoX), MQT_NAMED_BUILDER(emptyQCO)}, - QCOTestCase{"inverseTwoX", MQT_NAMED_BUILDER(twoX), + QCOTestCase{"InverseTwoX", MQT_NAMED_BUILDER(inverseTwoX), MQT_NAMED_BUILDER(emptyQCO)})); /// @} From db79dac6cf29a5d584f61e89274f2cd58e2bb788 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:28:33 +0200 Subject: [PATCH 12/41] Fix inverse cancellation --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 489fceb00e..e73bdd2650 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -35,7 +35,8 @@ removeInversePairOneTargetZeroParameter(OpType op, PatternRewriter& rewriter) { } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResult(0), op.getInputQubit(0)); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } @@ -64,7 +65,8 @@ removeInversePairTwoTargetZeroParameter(OpType op, PatternRewriter& rewriter) { } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResults(), op.getOperands()); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } @@ -95,8 +97,8 @@ removeTwoTargetZeroParameterPairWithSwappedTargets(OpType op, } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResults(), - {op.getInputQubit(1), op.getInputQubit(0)}); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } From e14656327c9e50d1aefe4d3af497d6190ebfc2d7 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:48:54 +0200 Subject: [PATCH 13/41] Improve translation of nested control modifiers --- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 20 ++++++++--- .../TranslateQuantumComputationToQC.cpp | 34 ++++++++++++++----- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 3 ++ .../Conversion/QCToQCO/test_qc_to_qco.cpp | 3 ++ .../test_quantum_computation_translation.cpp | 3 ++ mlir/unittests/programs/qc_programs.cpp | 8 +++++ mlir/unittests/programs/qc_programs.h | 4 +++ mlir/unittests/programs/qco_programs.cpp | 11 ++++++ mlir/unittests/programs/qco_programs.h | 4 +++ .../programs/quantum_computation_programs.cpp | 11 ++++++ .../programs/quantum_computation_programs.h | 4 +++ 11 files changed, 92 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 09bb1b0949..96598b2c82 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -880,18 +880,24 @@ struct ConvertQCOCtrlOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure( + op, + "Control modifiers with multiple body unitaries are not supported."); + } + auto& state = getState(); if (state.inCtrlOp) { return rewriter.notifyMatchFailure( - op, "Nested control operations are not supported. Run the " + op, "Nested control modifiers are not supported. Run the " "canonicalization pass before the conversion"); } if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Control operations inside inversion operations are not " - "supported. Run the canonicalization pass before the conversion"); + op, "Control modifiers inside inversion modifiers are not supported. " + "Run the canonicalization pass before the conversion"); } // Set modifier information @@ -930,11 +936,17 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure(op, + "Inversion modifiers with multiple " + "body unitaries are not supported."); + } + auto& state = getState(); if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Nested inversion operations are not supported. Run the " + op, "Nested inversion modifiers are not supported. Run the " "canonicalization pass before the conversion"); } diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 930817894d..7a2cf23651 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QC/Translation/TranslateQuantumComputationToQC.h" +#include "ir/Definitions.hpp" #include "ir/QuantumComputation.hpp" #include "ir/Register.hpp" #include "ir/operations/CompoundOperation.hpp" @@ -80,12 +81,15 @@ struct TranslationState { bool inCtrlOp = false; /// Mapping from physical qubit index to block argument - DenseMap ctrlTargets; + DenseMap targetArgs; + + /// Control qubits of the current CompoundOperation + DenseSet<::qc::Qubit> compoundControls; [[nodiscard]] Value getQubit(size_t index) const { if (inCtrlOp) { - auto it = ctrlTargets.find(index); - if (it == ctrlTargets.end()) { + auto it = targetArgs.find(index); + if (it == targetArgs.end()) { llvm::reportFatalInternalError("Qubit index out of bounds"); } return it->second; @@ -286,11 +290,11 @@ static void addResetOp(QCProgramBuilder& builder, */ static SmallVector getControls(const ::qc::Operation& operation, TranslationState& state) { - if (state.inCtrlOp) { - return {}; - } SmallVector controls; for (const auto& [control, type] : operation.getControls()) { + if (state.compoundControls.contains(control)) { + continue; + } if (type == ::qc::Control::Type::Neg) { llvm::reportFatalInternalError( "Negative controls cannot be translated to QC at the moment"); @@ -602,6 +606,18 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, targetMap[target] = state.getQubit(target); } } + for (const auto& control : op->getControls()) { + if (compoundOp.getControls().contains(control)) { + continue; + } + const auto& qubit = control.qubit; + if (!targetMap.contains(qubit)) { + targetMap[qubit] = state.getQubit(qubit); + } + } + } + for (const auto& [control, _] : compoundOp.getControls()) { + state.compoundControls.insert(control); } SmallVector> sortedPairs(targetMap.begin(), targetMap.end()); @@ -615,7 +631,7 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, builder.ctrl(controls, targets, [&](ValueRange targetArgs) { state.inCtrlOp = true; for (size_t i = 0; i < sortedPairs.size(); ++i) { - state.ctrlTargets[sortedPairs[i].first] = targetArgs[i]; + state.targetArgs[sortedPairs[i].first] = targetArgs[i]; } for (const auto& op : compoundOp) { if (failed(translateOperation(builder, *op, state))) { @@ -623,7 +639,7 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, "controlled CompoundOperation"); } } - state.ctrlTargets.clear(); + state.targetArgs.clear(); state.inCtrlOp = false; }); } @@ -850,7 +866,7 @@ OwningOpRef translateQuantumComputationToQC( TranslationState state{.qubits = qubits, .bitMap = bitMap, .results = std::move(results), - .ctrlTargets = DenseMap{}}; + .targetArgs = DenseMap{}}; // Translate operations if (translateOperations(builder, quantumComputation, state).failed()) { diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index aa3a428809..7dda9ccfda 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -150,6 +150,9 @@ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOToQCTest, testing::Values(QCOToQCTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qco::ctrlTwo), MQT_NAMED_BUILDER(qc::ctrlTwo)}, + QCOToQCTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qco::ctrlTwoMixed), + MQT_NAMED_BUILDER(qc::ctrlTwoMixed)}, QCOToQCTestCase{"CtrlInvTwo", MQT_NAMED_BUILDER(qco::ctrlInvTwo), MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 71f47b0841..00b2c7fe7b 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -149,6 +149,9 @@ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCToQCOTest, testing::Values(QCToQCOTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), MQT_NAMED_BUILDER(qco::ctrlTwo)}, + QCToQCOTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(qco::ctrlTwoMixed)}, QCToQCOTestCase{"CtrlInvTwo", MQT_NAMED_BUILDER(qc::ctrlInvTwo), MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index 893902448b..b47c9f97a7 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -421,6 +421,9 @@ INSTANTIATE_TEST_SUITE_P( QuantumComputationTranslationTestCase{ "CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), MQT_NAMED_BUILDER(mlir::qc::ctrlTwo)}, + QuantumComputationTranslationTestCase{ + "CtrlTwoMixed", MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwoMixed)}, QuantumComputationTranslationTestCase{ "SimpleIf", MQT_NAMED_BUILDER(qc::simpleIf), MQT_NAMED_BUILDER(mlir::qc::simpleIf)}, diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 7c7608963d..232b8cbdee 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1405,6 +1405,14 @@ void ctrlTwo(QCProgramBuilder& b) { }); } +void ctrlTwoMixed(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.cx(targets[0], targets[1]); + b.rxx(0.123, targets[0], targets[1]); + }); +} + void nestedCtrlTwo(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 2f08c5236e..dbf855b982 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -832,6 +832,10 @@ void ctrlInvSandwich(QCProgramBuilder& b); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QCProgramBuilder& b); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCProgramBuilder& b); + /// Creates a circuit with nested control modifiers applied to two gates. void nestedCtrlTwo(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 868e3a2a3b..523f071f8a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -2038,6 +2038,17 @@ void ctrlTwo(QCOProgramBuilder& b) { }); } +void ctrlTwoMixed(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + std::tie(i0, i1) = b.cx(i0, i1); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + void nestedCtrlTwo(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index 1ec606d103..f562cfff8a 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -985,6 +985,10 @@ void ctrlInvSandwich(QCOProgramBuilder& b); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCOProgramBuilder& b); + /// Creates a circuit with nested control modifiers applied to two gates. void nestedCtrlTwo(QCOProgramBuilder& b); diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index 418798c68d..f0b9b305cd 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -553,6 +553,17 @@ void ctrlTwo(QuantumComputation& comp) { comp.emplace_back(std::move(compound)); } +void ctrlTwoMixed(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, 3, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + void simpleIf(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); diff --git a/mlir/unittests/programs/quantum_computation_programs.h b/mlir/unittests/programs/quantum_computation_programs.h index f0e1856d8f..f6dab6e1c2 100644 --- a/mlir/unittests/programs/quantum_computation_programs.h +++ b/mlir/unittests/programs/quantum_computation_programs.h @@ -390,6 +390,10 @@ void barrierMultipleQubits(QuantumComputation& comp); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QuantumComputation& comp); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QuantumComputation& comp); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. From d84b58b6c309d9e5370fdf966e25f07c1fa6bde2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:10:13 +0200 Subject: [PATCH 14/41] Improve verifiers --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 5 +++++ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 5 +++++ mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 16 +++++++++++----- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 16 +++++++++++----- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 82b5dce077..1fd3d13fc1 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -225,6 +225,11 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 1417d6de75..6215a81c1c 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -395,6 +395,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 444f554393..7f084b9c06 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -308,22 +308,28 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); + } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 3456e0b11c..cf2043b582 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -452,22 +452,28 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); + } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") From 570878fd0f5f17049dbaed0bc712c40c39496850 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:20:16 +0200 Subject: [PATCH 15/41] Improve implementation of getNumBodyUnitaries() --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 9 ++------- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 9 ++------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 9 ++------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 9 ++------- 4 files changed, 8 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 1fd3d13fc1..bc7bd64f9d 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -166,13 +166,8 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 6215a81c1c..5794f08a35 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -354,13 +354,8 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 7f084b9c06..bca38f47a1 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -189,13 +189,8 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index cf2043b582..e892b3e900 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -379,13 +379,8 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { From c199cc9babcae95bc86e69e11023ac5094b83313 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:26:34 +0200 Subject: [PATCH 16/41] Improve implementation of getBodyUnitary() --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 15 ++++++--------- 4 files changed, 24 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index bc7bd64f9d..abff17e8d6 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -171,16 +171,13 @@ size_t CtrlOp::getNumBodyUnitaries() { } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value CtrlOp::getQubit(const size_t i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 5794f08a35..398342f7ad 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -359,16 +359,13 @@ size_t InvOp::getNumBodyUnitaries() { } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Invalid unitary index"); + return cast(*it); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index bca38f47a1..81ea40f784 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -194,16 +194,13 @@ size_t CtrlOp::getNumBodyUnitaries() { } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value CtrlOp::getInputQubit(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index e892b3e900..cb7ad50acd 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -384,16 +384,13 @@ size_t InvOp::getNumBodyUnitaries() { } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value InvOp::getInputQubit(const size_t i) { From 4c4b31ca15ea9e3fd239dd80fb3a8656384719c7 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:50:24 +0200 Subject: [PATCH 17/41] Fix linter errors --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 3 ++- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 4 +++- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 3 ++- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 3 ++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index abff17e8d6..059e7dc736 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -24,6 +24,7 @@ #include #include +#include using namespace mlir; using namespace mlir::qc; @@ -173,7 +174,7 @@ size_t CtrlOp::getNumBodyUnitaries() { UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 398342f7ad..b1cf6e46ae 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" +#include #include #include #include @@ -24,6 +25,7 @@ #include #include +#include #include using namespace mlir; @@ -361,7 +363,7 @@ size_t InvOp::getNumBodyUnitaries() { UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 81ea40f784..d9a67fafca 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include using namespace mlir; @@ -196,7 +197,7 @@ size_t CtrlOp::getNumBodyUnitaries() { UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index cb7ad50acd..1f0e6d556c 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -386,7 +387,7 @@ size_t InvOp::getNumBodyUnitaries() { UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } From e67219c3a3bf3c77d12c5d02fc205e15f9137d28 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 5 Jun 2026 14:04:19 +0200 Subject: [PATCH 18/41] Update documentation --- .../Dialect/QC/Builder/QCProgramBuilder.h | 68 ++++++++++--------- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 27 ++++---- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 23 ++++--- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 4 +- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 4 +- 5 files changed, 69 insertions(+), 57 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index c2483d9f06..ed94af35bc 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -406,8 +406,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(q0, q1); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME %q1 : !qc.qubit \ + * qc.ctrl(%q0) targets(%a0 = %q1) { \ + * qc.OP_NAME %a0 : !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -424,8 +424,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME({q0, q1}, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME %q2 : !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2) { \ + * qc.OP_NAME %a0 : !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -478,8 +478,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(PARAM, q0, q1); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME(%PARAM) %q1 : !qc.qubit \ + * qc.ctrl(%q0) targets(%a0 = %q1) { \ + * qc.OP_NAME(%PARAM) %a0 : !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -498,8 +498,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME(PARAM, {q0, q1}, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME(%PARAM) %q2 : !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2) { \ + * qc.OP_NAME(%PARAM) %a0 : !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -549,8 +549,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(PARAM1, PARAM2, q0, q1); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME(%PARAM1, %PARAM2) %q1 : !qc.qubit \ + * qc.ctrl(%q0) (%a0 = %q1) { \ + * qc.OP_NAME(%PARAM1, %PARAM2) %a0 : !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -571,8 +571,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME(PARAM1, PARAM2, {q0, q1}, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME(%PARAM1, %PARAM2) %q2 : !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2) { \ + * qc.OP_NAME(%PARAM1, %PARAM2) %a0 : !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -625,8 +625,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(PARAM1, PARAM2, PARAM3, q0, q1); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME(%PARAM1, %PARAM2, %PARAM3) %q1 : !qc.qubit \ + * qc.ctrl(%q0) targets(%a0 = %q1) { \ + * qc.OP_NAME(%PARAM1, %PARAM2, %PARAM3) %a0 : !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -649,8 +649,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME(PARAM1, PARAM2, PARAM3, {q0, q1}, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME(%PARAM1, %PARAM2, %PARAM3) %q2 : !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2) { \ + * qc.OP_NAME(%PARAM1, %PARAM2, %PARAM3) %a0 : !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -695,8 +695,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(q0, q1, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME %q1, %q2 : !qc.qubit, !qc.qubit \ + * qc.ctrl(%q0) targets(%a0 = %q1, %a1 = %q2) { \ + * qc.OP_NAME %a0, %a1 : !qc.qubit, !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -714,8 +714,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME({q0, q1}, q2, q3); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME %q2, %q3 : !qc.qubit, !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2, %a1 = %q3) { \ + * qc.OP_NAME %a0, %a1 : !qc.qubit, !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -764,8 +764,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(PARAM, q0, q1, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME(%PARAM) %q1, %q2 : !qc.qubit, !qc.qubit \ + * qc.ctrl(%q0) targets(%a0 = %q1, %a1 = %q2) { \ + * qc.OP_NAME(%PARAM) %a0, %a1 : !qc.qubit, !qc.qubit \ * } : !qc.qubit \ * ``` \ */ \ @@ -785,8 +785,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME(PARAM, {q0, q1}, q2, q3); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME(%PARAM) %q2, %q3 : !qc.qubit, !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2, %a1 = %q3) { \ + * qc.OP_NAME(%PARAM) %a0, %a1 : !qc.qubit, !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -839,8 +839,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.c##OP_NAME(PARAM1, PARAM2, q0, q1, q2); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0) { \ - * qc.OP_NAME(%PARAM1, %PARAM2) %q1, %q2 : !qc.qubit, \ + * qc.ctrl(%q0) targets(%a0 = %q1, %a1 = %q2) { \ + * qc.OP_NAME(%PARAM1, %PARAM2) %a0, %a1 : !qc.qubit, \ * !qc.qubit \ * } : !qc.qubit \ * ``` \ @@ -863,8 +863,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * builder.mc##OP_NAME(PARAM1, PARAM2, {q0, q1}, q2, q3); \ * ``` \ * ```mlir \ - * qc.ctrl(%q0, %q1) { \ - * qc.OP_NAME(%PARAM1, %PARAM2) %q2, %q3 : !qc.qubit, !qc.qubit \ + * qc.ctrl(%q0, %q1) targets(%a0 = %q2, %a1 = %q3) { \ + * qc.OP_NAME(%PARAM1, %PARAM2) %a0, %a1 : !qc.qubit, !qc.qubit \ * } : !qc.qubit, !qc.qubit \ * ``` \ */ \ @@ -909,11 +909,13 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * * @par Example: * ```c++ - * builder.ctrl(q0, [&] { builder.x(q1); }); + * builder.ctrl(q0, q1, [&](ValueRange targets) { + * builder.x(targets[0]); + * }); * ``` * ```mlir - * qc.ctrl(%q0) { - * qc.x %q1 : !qc.qubit + * qc.ctrl(%q0) targets(%a0 = %q1) { + * qc.x %a0 : !qc.qubit * } : !qc.qubit * ``` */ @@ -929,7 +931,9 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * * @par Example: * ```c++ - * builder.inv([&] { builder.s(q0); }); + * builder.inv(q0, [&](ValueRange qubits) { + * builder.h(qubits[0]); + * }); * ``` * ```mlir * qc.inv { diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index cce7265131..96aac45960 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -921,18 +921,19 @@ def CtrlOp RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; let description = [{ - A modifier operation that adds control qubits to the unitary operation - defined in its body region. The controlled operation applies the - underlying unitary only when all control qubits are in the |1⟩ state. + A modifier operation that adds control qubits to the unitary operation defined in its body region. + The controlled operation applies the underlying unitary only when all control qubits are in the $|1\rangle$ state. - Note that control qubits are logically unmodified by this operation in that - their quantum state remains unchanged. However, the `controls` argument - is marked with `MemWrite` to ensure correct dependency tracking in MLIR. + Note that control qubits are logically unmodified by this operation in that their quantum state remains unchanged. + However, the `controls` argument is marked with `MemWrite` to ensure correct dependency tracking in MLIR. + + The body region may contain an arbitrary amount of unitary and classical operations. + Non-unitary operations, such as `AllocOp` and `MeasureOp`, are not allowed. Example: ```mlir - qc.ctrl(%q0) { - qc.x %q1 : !qc.qubit + qc.ctrl(%q0) targets(%a0 = %q1) { + qc.x %a0 : !qc.qubit } : !qc.qubit ``` }]; @@ -978,13 +979,15 @@ def InvOp : QCOp<"inv", RecursiveMemoryEffects]> { let summary = "Invert a unitary operation"; let description = [{ - A modifier operation that inverts the unitary operation defined in its body - region. + A modifier operation that inverts the unitary operation defined in its body region. + + The body region may contain an arbitrary amount of unitary and classical operations. + Non-unitary operations, such as `AllocOp` and `MeasureOp`, are not allowed. Example: ```mlir - qc.inv { - qc.s %q0 : !qc.qubit + qc.inv (%a0 = %q0) { + qc.s %a0 : !qc.qubit } ``` }]; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 78e15ecc86..8d0025d286 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1069,12 +1069,14 @@ def CtrlOp RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; let description = [{ - A modifier operation that adds control qubits to the unitary operation - defined in its body region. The controlled operation applies the - underlying unitary only when all control qubits are in the |1⟩ state. - The operation takes a variadic number of control and target qubits as - inputs and produces corresponding output qubits. Control qubits are not - modified by the operation and simply pass through to the outputs. + A modifier operation that adds control qubits to the unitary operation defined in its body region. + The controlled operation applies the underlying unitary only when all control qubits are in the $|1\rangle$ state. + + The operation takes a variadic number of control and target qubits as inputs and produces corresponding output qubits. + Control qubits are not modified by the operation and simply pass through to the outputs. + + The body region may contain an arbitrary amount of unitary and classical operations. + Non-unitary operations, such as `AllocOp` and `MeasureOp`, are not allowed. Example: ```mlir @@ -1147,9 +1149,12 @@ def InvOp RecursiveMemoryEffects]> { let summary = "Invert a unitary operation"; let description = [{ - A modifier operation that inverts the unitary operation defined in its body - region. The operation takes a variadic number of qubits as inputs and - produces corresponding output qubits. + A modifier operation that inverts the unitary operation defined in its body region. + + The operation takes a variadic number of qubits as inputs and produces corresponding output qubits. + + The body region may contain an arbitrary amount of unitary and classical operations. + Non-unitary operations, such as `AllocOp` and `MeasureOp`, are not allowed. Example: ```mlir diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 8cc099610d..a8ba24090d 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -630,8 +630,8 @@ struct ConvertQCOBarrierOp final : OpConversionPattern { * ``` * is converted to * ```mlir - * qc.ctrl(%q0) { - * qc.x %q1 : !qc.qubit + * qc.ctrl(%q0) targets(%a0 = %q1) { + * qc.x %a0 : !qc.qubit * } : !qc.qubit * ``` */ diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 33a2df9217..65877a3120 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1076,8 +1076,8 @@ struct ConvertQCBarrierOp final : StatefulOpConversionPattern { * * @par Example: * ```mlir - * qc.ctrl(%q0) { - * qc.x %q1 : !qc.qubit + * qc.ctrl(%q0) targets(%a0 = %q1) { + * qc.x %a0 : !qc.qubit * } : !qc.qubit * ``` * is converted to From 0082619bcd7810162eed42add63c188bcc2f7af6 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:17:48 +0200 Subject: [PATCH 19/41] Update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b24ffa3f8c..d4fea3213b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706]) ([**@denialhaag**], [**@burgholzer**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588], [#1600], [#1664], [#1709], [#1716], [#1748]) ([**@MatthiasReumann**], [**@burgholzer**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1728], [#1730], [#1749], [#1762], [#1765], [#1774]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], [#1774]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -405,6 +405,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1774]: https://github.com/munich-quantum-toolkit/core/pull/1774 [#1765]: https://github.com/munich-quantum-toolkit/core/pull/1765 [#1762]: https://github.com/munich-quantum-toolkit/core/pull/1762 +[#1751]: https://github.com/munich-quantum-toolkit/core/pull/1751 [#1749]: https://github.com/munich-quantum-toolkit/core/pull/1749 [#1748]: https://github.com/munich-quantum-toolkit/core/pull/1748 [#1737]: https://github.com/munich-quantum-toolkit/core/pull/1737 From eea4c2d6e63071a3130cb78b5856e6bb904f3215 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 12:19:30 +0200 Subject: [PATCH 20/41] :recycle: Introduce a `getQubits` method for the QC UnitaryOpInterface and use `OperandRange` more consistently The addition of the method is possible now that modifiers in QC have block arguments for the target qubits. The OperandRange change just more naturally reflects what these values truly are. Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QC/IR/QCDialect.h | 7 ++++--- mlir/include/mlir/Dialect/QC/IR/QCInterfaces.td | 9 ++++++--- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 15 ++++++++------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/IR/QCDialect.h b/mlir/include/mlir/Dialect/QC/IR/QCDialect.h index 96fbb7f724..5b1f5a7e92 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCDialect.h +++ b/mlir/include/mlir/Dialect/QC/IR/QCDialect.h @@ -51,7 +51,7 @@ template class TargetAndParameterArityTrait { static size_t getNumQubits() { return T; } static size_t getNumTargets() { return T; } static size_t getNumControls() { return 0; } - static ValueRange getControls() { return {}; } + static OperandRange getControls() { return {nullptr, 0}; } Value getQubit(size_t i) { if constexpr (T == 0) { @@ -71,7 +71,8 @@ template class TargetAndParameterArityTrait { } return this->getOperation()->getOperand(i); } - ValueRange getTargets() { + OperandRange getQubits() { return getTargets(); } + OperandRange getTargets() { return this->getOperation()->getOperands().slice(0, T); } @@ -88,7 +89,7 @@ template class TargetAndParameterArityTrait { return this->getOperation()->getOperand(T + i); } - ValueRange getParameters() { + OperandRange getParameters() { return this->getOperation()->getOperands().slice(T, P); } }; diff --git a/mlir/include/mlir/Dialect/QC/IR/QCInterfaces.td b/mlir/include/mlir/Dialect/QC/IR/QCInterfaces.td index 9b2399bdff..e82752a127 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCInterfaces.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCInterfaces.td @@ -44,9 +44,12 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { (ins "size_t":$i)>, InterfaceMethod<"Returns the i-th control qubit.", "Value", "getControl", (ins "size_t":$i)>, - InterfaceMethod<"Returns a range of all target qubits.", "ValueRange", + InterfaceMethod< + "Returns a range of all qubits (targets + controls combined).", + "OperandRange", "getQubits", (ins)>, + InterfaceMethod<"Returns a range of all target qubits.", "OperandRange", "getTargets", (ins)>, - InterfaceMethod<"Returns a range of all control qubits.", "ValueRange", + InterfaceMethod<"Returns a range of all control qubits.", "OperandRange", "getControls", (ins)>, // Parameter handling @@ -54,7 +57,7 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { "getNumParams", (ins)>, InterfaceMethod<"Returns the i-th parameter.", "Value", "getParameter", (ins "size_t":$i)>, - InterfaceMethod<"Returns a range of all parameters.", "ValueRange", + InterfaceMethod<"Returns a range of all parameters.", "OperandRange", "getParameters", (ins)>, // Convenience methods diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 96aac45960..dd120e52a5 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -889,14 +889,14 @@ def BarrierOp : QCOp<"barrier", traits = [UnitaryOpInterface]> { size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubits().size(); } static size_t getNumControls() { return 0; } - static ValueRange getControls() { return {}; } + static OperandRange getControls() { return {nullptr, 0}; } Value getQubit(size_t i) { return getTarget(i); } Value getTarget(size_t i); - ValueRange getTargets() { return getQubits(); } + OperandRange getTargets() { return getQubits(); } static Value getControl(size_t i) { llvm::reportFatalUsageError("BarrierOp cannot be controlled"); } static size_t getNumParams() { return 0; } static Value getParameter(size_t i) { llvm::reportFatalUsageError("BarrierOp does not have parameters"); } - static ValueRange getParameters() { return {}; } + static OperandRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "barrier"; } }]; } @@ -960,9 +960,10 @@ def CtrlOp Value getQubit(size_t i); Value getTarget(size_t i) { return getTargets()[i]; } Value getControl(size_t i); + OperandRange getQubits() { return getOperands(); } size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } - ValueRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + OperandRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } static StringRef getBaseSymbol() { return "ctrl"; } }]; @@ -1010,12 +1011,12 @@ def InvOp : QCOp<"inv", size_t getNumControls() { return 0; } Value getQubit(size_t i) { return getTarget(i); } Value getTarget(size_t i) { return getQubits()[i]; } - ValueRange getTargets() { return getQubits(); } + OperandRange getTargets() { return getQubits(); } Value getControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } - ValueRange getControls() { return {nullptr, 0}; } + OperandRange getControls() { return {nullptr, 0}; } size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } - ValueRange getParameters() { return {nullptr, 0}; } + OperandRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "inv"; } }]; From 8bbe7222edbb5ef441c3d5a1596b09ee9c417b1e Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 12:20:09 +0200 Subject: [PATCH 21/41] :recycle: Use `OperandRange` for the return type of `getParameters` in the QCO UnitaryOpInterface Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QCO/IR/QCODialect.h | 2 +- mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td | 2 +- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h index 6e0ab28e0a..9ec5b5f384 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h +++ b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h @@ -103,7 +103,7 @@ template class TargetAndParameterArityTrait { } return this->getOperation()->getOperand(T + i); } - ValueRange getParameters() { + OperandRange getParameters() { return this->getOperation()->getOperands().slice(T, P); } diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td index 3384d7a048..854a3af5a8 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td @@ -168,7 +168,7 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { "getNumParams", (ins)>, InterfaceMethod<"Returns the i-th parameter.", "Value", "getParameter", (ins "size_t":$i)>, - InterfaceMethod<"Returns a range of all parameters.", "ValueRange", + InterfaceMethod<"Returns a range of all parameters.", "OperandRange", "getParameters", (ins)>, // Convenience methods diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index b2cee4728a..3e4980b2b0 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1028,7 +1028,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { Value getOutputForInput(Value input); static size_t getNumParams() { return 0; } static Value getParameter(size_t i) { llvm::reportFatalUsageError("BarrierOp has no parameters"); } - static ValueRange getParameters() { return {}; } + static OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "barrier"; } }]; @@ -1124,7 +1124,7 @@ def CtrlOp : QCOOp<"ctrl", Value getOutputForInput(Value input); size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } - ValueRange getParameters() { return {nullptr, 0}; } + OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; @@ -1197,7 +1197,7 @@ def InvOp : QCOOp<"inv", traits = [UnitaryOpInterface, Value getOutputForInput(Value input); size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } - ValueRange getParameters() { return {nullptr, 0}; } + OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "inv"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; From 026ac04fd482ecc486c08344cdf104885858ef9c Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 12:46:26 +0200 Subject: [PATCH 22/41] :art: Avoid repeatedly calling `checkFinalized()` in builder Signed-off-by: burgholzer --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 45 +++++++------------ 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 4bde38d301..03d7acb815 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -216,15 +216,13 @@ QCProgramBuilder& QCProgramBuilder::reset(Value qubit) { } \ QCProgramBuilder& QCProgramBuilder::c##OP_NAME( \ const std::variant&(PARAM), Value control) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM, {control}); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM), ValueRange controls) { \ - checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ ctrl(controls, ValueRange{}, \ - [&](ValueRange /*targets*/) { OP_NAME(param); }); \ + [&](ValueRange /*targets*/) { OP_CLASS::create(*this, param); }); \ return *this; \ } @@ -242,13 +240,12 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) } \ QCProgramBuilder& QCProgramBuilder::c##OP_NAME(Value control, \ Value target) { \ - checkFinalized(); \ return mc##OP_NAME({control}, target); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME(ValueRange controls, \ Value target) { \ - checkFinalized(); \ - ctrl(controls, target, [&](ValueRange targets) { OP_NAME(targets[0]); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_CLASS::create(*this, targets[0]); }); \ return *this; \ } @@ -278,16 +275,15 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) QCProgramBuilder& QCProgramBuilder::c##OP_NAME( \ const std::variant&(PARAM), Value control, \ Value target) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM, {control}, target); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM), ValueRange controls, \ Value target) { \ - checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - ctrl(controls, target, \ - [&](ValueRange targets) { OP_NAME(param, targets[0]); }); \ + ctrl(controls, target, [&](ValueRange targets) { \ + OP_CLASS::create(*this, targets[0], param); \ + }); \ return *this; \ } @@ -312,18 +308,17 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, theta) const std::variant&(PARAM1), \ const std::variant&(PARAM2), Value control, \ Value target) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM1, PARAM2, {control}, target); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM1), \ const std::variant&(PARAM2), ValueRange controls, \ Value target) { \ - checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - ctrl(controls, target, \ - [&](ValueRange targets) { OP_NAME(param1, param2, targets[0]); }); \ + ctrl(controls, target, [&](ValueRange targets) { \ + OP_CLASS::create(*this, targets[0], param1, param2); \ + }); \ return *this; \ } @@ -349,7 +344,6 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) const std::variant&(PARAM2), \ const std::variant&(PARAM3), Value control, \ Value target) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM1, PARAM2, PARAM3, {control}, target); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ @@ -357,12 +351,11 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) const std::variant&(PARAM2), \ const std::variant&(PARAM3), ValueRange controls, \ Value target) { \ - checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ auto param3 = variantToValue(*this, getLoc(), PARAM3); \ ctrl(controls, target, [&](ValueRange targets) { \ - OP_NAME(param1, param2, param3, targets[0]); \ + OP_CLASS::create(*this, targets[0], param1, param2, param3); \ }); \ return *this; \ } @@ -381,14 +374,13 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) } \ QCProgramBuilder& QCProgramBuilder::c##OP_NAME(Value control, Value qubit0, \ Value qubit1) { \ - checkFinalized(); \ return mc##OP_NAME({control}, qubit0, qubit1); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ ValueRange controls, Value qubit0, Value qubit1) { \ - checkFinalized(); \ - ctrl(controls, ValueRange{qubit0, qubit1}, \ - [&](ValueRange targets) { OP_NAME(targets[0], targets[1]); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ + OP_CLASS::create(*this, targets[0], targets[1]); \ + }); \ return *this; \ } @@ -411,16 +403,15 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) QCProgramBuilder& QCProgramBuilder::c##OP_NAME( \ const std::variant&(PARAM), Value control, Value qubit0, \ Value qubit1) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM, {control}, qubit0, qubit1); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM), ValueRange controls, \ Value qubit0, Value qubit1) { \ - checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - ctrl(controls, ValueRange{qubit0, qubit1}, \ - [&](ValueRange targets) { OP_NAME(param, targets[0], targets[1]); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ + OP_CLASS::create(*this, targets[0], targets[1], param); \ + }); \ return *this; \ } @@ -446,18 +437,16 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) const std::variant&(PARAM1), \ const std::variant&(PARAM2), Value control, Value qubit0, \ Value qubit1) { \ - checkFinalized(); \ return mc##OP_NAME(PARAM1, PARAM2, {control}, qubit0, qubit1); \ } \ QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM1), \ const std::variant&(PARAM2), ValueRange controls, \ Value qubit0, Value qubit1) { \ - checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ - OP_NAME(param1, param2, targets[0], targets[1]); \ + OP_CLASS::create(*this, targets[0], targets[1], param1, param2); \ }); \ return *this; \ } From 95585c779f7b6dfe70d3ccb0430ec8d867895fe4 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 14:25:19 +0200 Subject: [PATCH 23/41] :art: Tidy up `Utils.h` helper linkage Give the non-template `Utils.h` helpers `inline` linkage instead of `static` so that they are not duplicated into every translation unit that includes the header, and modernize `valueToDouble` to use if-with-initializer. Assisted-By: Claude Opus 4.8 (1M context) --- mlir/include/mlir/Dialect/Utils/Utils.h | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index a885c2da2b..0e335aedde 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -65,12 +65,10 @@ template if (!constantOp) { return std::nullopt; } - auto floatAttr = dyn_cast(constantOp.getValue()); - if (floatAttr) { + if (auto floatAttr = dyn_cast(constantOp.getValue())) { return floatAttr.getValueAsDouble(); } - auto intAttr = dyn_cast(constantOp.getValue()); - if (intAttr) { + if (auto intAttr = dyn_cast(constantOp.getValue())) { if (intAttr.getType().isUnsignedInteger()) { return static_cast(intAttr.getValue().getZExtValue()); } @@ -82,7 +80,7 @@ template template [[nodiscard]] -static ParseResult +ParseResult parseTargetAliasing(OpAsmParser& parser, Region& region, SmallVectorImpl& operands) { // 1. Parse the opening parenthesis @@ -115,7 +113,7 @@ parseTargetAliasing(OpAsmParser& parser, Region& region, } operands.push_back(oldOperand); - // Hard-code QubitType since targets in qco.ctrl are always qubits. + // Hard-code QubitType since targets in CtrlOp are always qubits. // This avoids double-binding type($targets_in) in the assembly format // while keeping the parser simple and the assembly format clean. newArg.type = QubitType::get(parser.getBuilder().getContext()); @@ -138,7 +136,7 @@ parseTargetAliasing(OpAsmParser& parser, Region& region, return success(); } -static void printTargetAliasing(OpAsmPrinter& printer, Region& region, +inline void printTargetAliasing(OpAsmPrinter& printer, Region& region, OperandRange targetsIn) { printer << "("; if (region.empty()) { @@ -166,7 +164,7 @@ static void printTargetAliasing(OpAsmPrinter& printer, Region& region, * @brief Get the value corresponding to @p qubit from the block arguments @p * qubits if @p qubit is a block argument, otherwise return @p qubit itself. */ -static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { +inline Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { if (auto blockArg = dyn_cast(qubit)) { return qubits[blockArg.getArgNumber()]; } @@ -179,7 +177,7 @@ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { * @details This helper function is used to resolve block arguments for nested * modifiers. */ -static void populateMapping(IRMapping& mapping, Block& block, +inline void populateMapping(IRMapping& mapping, Block& block, ValueRange innerQubits, ValueRange outerQubits, ValueRange newQubits, ValueRange qubitArgs) { assert(innerQubits.size() == block.getNumArguments() && From 805793b8711d155f95a216e2b9af9518933ccd1f Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 14:26:08 +0200 Subject: [PATCH 24/41] :recycle: Share body-unitary helpers across modifier canonicalizations Introduce templated `getNumBodyUnitaries`, `getBodyUnitary` and `getSoleBodyUnitary` helpers in `Utils.h` and route the otherwise identical `CtrlOp`/`InvOp` definitions in both dialects through them. Canonicalization patterns that need the unique body unitary now use `getSoleBodyUnitary`, replacing the repeated `getNumBodyUnitaries() != 1` plus `getBodyUnitary(0)` double scan with a single traversal. Assisted-By: Claude Opus 4.8 (1M context) --- mlir/include/mlir/Dialect/Utils/Utils.h | 49 ++++++++++++++++++++ mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 25 ++++------ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 37 +++++++-------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 30 +++++------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 43 ++++++++--------- 5 files changed, 107 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 0e335aedde..2bd3ebcb9f 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -10,13 +10,18 @@ #pragma once +#include +#include #include +#include #include #include #include #include #include +#include +#include #include namespace mlir::utils { @@ -194,4 +199,48 @@ inline void populateMapping(IRMapping& mapping, Block& block, } } +/** + * @brief Returns the number of operations implementing @p UnitaryInterface in + * @p block. + */ +template +[[nodiscard]] size_t getNumBodyUnitaries(Block& block) { + return static_cast(llvm::count_if( + block, [](Operation& op) { return isa(op); })); +} + +/** + * @brief Returns the @p i-th operation implementing @p UnitaryInterface in + * @p block, reporting a fatal error if @p i is out of bounds. + */ +template +[[nodiscard]] UnitaryInterface getBodyUnitary(Block& block, size_t i) { + auto unitaries = llvm::make_filter_range( + block, [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), static_cast(i)); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); + } + return cast(*it); +} + +/** + * @brief Returns the single operation implementing @p UnitaryInterface in + * @p block, or a null interface if @p block does not contain exactly one. + */ +template +[[nodiscard]] UnitaryInterface getSoleBodyUnitary(Block& block) { + auto unitaries = llvm::make_filter_range( + block, [](Operation& op) { return isa(op); }); + auto it = unitaries.begin(); + if (it == unitaries.end()) { + return {}; + } + auto unitary = cast(*it); + if (++it != unitaries.end()) { + return {}; + } + return unitary; +} + } // namespace mlir::utils diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 059e7dc736..4d74d11a79 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -44,10 +44,11 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerCtrlOp = dyn_cast(inner.getOperation()); if (!innerCtrlOp) { return failure(); } @@ -92,10 +93,11 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); // Inline ops from empty control modifiers, IdOp and BarrierOp if (op.getNumControls() == 0 || isa(innerOp)) { @@ -118,7 +120,6 @@ struct ReduceCtrl final : OpRewritePattern { if (!gPhaseOp) { return failure(); } - // Special case for single control: replace with a single POp if (op.getNumControls() == 1) { rewriter.replaceOpWithNewOp(op, op.getControl(0), @@ -126,7 +127,8 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } - // Adjust the segment sizes of the control and target operands + // Reinterpret the last control as a target qubit and apply a phase gate to + // it inside the (smaller) controlled region. const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); auto segmentsAttr = op->getAttrOfType(opSegmentsAttrName); @@ -167,18 +169,11 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - return llvm::count_if( - *getBody(), [](Operation& op) { return isa(op); }); + return utils::getNumBodyUnitaries(*getBody()); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - auto unitaries = llvm::make_filter_range( - *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), static_cast(i)); - if (it == unitaries.end()) { - llvm::reportFatalUsageError("Unitary index out of bounds"); - } - return cast(*it); + return utils::getBodyUnitary(*getBody(), i); } Value CtrlOp::getQubit(const size_t i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b1cf6e46ae..39d7541b23 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -42,10 +42,11 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerCtrlOp = dyn_cast(inner.getOperation()); if (!innerCtrlOp) { return failure(); } @@ -85,10 +86,11 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); if (!isa(innerOp)) { return failure(); @@ -119,10 +121,11 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); auto loc = op.getLoc(); auto outerQubits = op.getQubits(); @@ -307,18 +310,21 @@ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerInvOp = dyn_cast(inner.getOperation()); if (!innerInvOp) { return failure(); } - if (innerInvOp.getNumBodyUnitaries() != 1) { + auto innerInner = + utils::getSoleBodyUnitary(*innerInvOp.getBody()); + if (!innerInner) { return failure(); } - auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + auto* innerInnerOp = innerInner.getOperation(); const auto numQubits = op.getNumQubits(); auto outerQubits = op.getQubits(); @@ -356,18 +362,11 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - return llvm::count_if( - *getBody(), [](Operation& op) { return isa(op); }); + return utils::getNumBodyUnitaries(*getBody()); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - auto unitaries = llvm::make_filter_range( - *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), static_cast(i)); - if (it == unitaries.end()) { - llvm::reportFatalUsageError("Unitary index out of bounds"); - } - return cast(*it); + return utils::getBodyUnitary(*getBody(), i); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index e06fdb641c..1f99383994 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -51,10 +51,11 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerCtrlOp = dyn_cast(inner.getOperation()); if (!innerCtrlOp) { return failure(); } @@ -105,10 +106,11 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); // Inline ops from empty control modifiers, IdOp and BarrierOp if (op.getNumControls() == 0 || isa(innerOp)) { @@ -141,7 +143,8 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } - // Adjust the segment sizes of the control and target operands + // Reinterpret the last control as a target qubit and apply a phase gate to + // it inside the (smaller) controlled region. const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); auto segmentsAttr = op->getAttrOfType(opSegmentsAttrName); @@ -191,18 +194,11 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - return llvm::count_if( - *getBody(), [](Operation& op) { return isa(op); }); + return utils::getNumBodyUnitaries(*getBody()); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - auto unitaries = llvm::make_filter_range( - *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), static_cast(i)); - if (it == unitaries.end()) { - llvm::reportFatalUsageError("Unitary index out of bounds"); - } - return cast(*it); + return utils::getBodyUnitary(*getBody(), i); } Value CtrlOp::getInputQubit(const size_t i) { @@ -363,11 +359,7 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional CtrlOp::getUnitaryMatrix() { - if (getNumBodyUnitaries() != 1) { - return std::nullopt; - } - - auto bodyUnitary = getBodyUnitary(0); + auto bodyUnitary = utils::getSoleBodyUnitary(*getBody()); if (!bodyUnitary) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index a6eee6d4bc..dd0d01c984 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -45,10 +45,11 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerCtrlOp = dyn_cast(inner.getOperation()); if (!innerCtrlOp) { return failure(); } @@ -97,10 +98,11 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); if (!isa(innerOp)) { return failure(); @@ -131,10 +133,11 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto* innerOp = op.getBodyUnitary(0).getOperation(); + auto* innerOp = inner.getOperation(); auto loc = op.getLoc(); auto outerQubits = op.getInputQubits(); @@ -331,18 +334,21 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - if (op.getNumBodyUnitaries() != 1) { + auto inner = utils::getSoleBodyUnitary(*op.getBody()); + if (!inner) { return failure(); } - auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + auto innerInvOp = dyn_cast(inner.getOperation()); if (!innerInvOp) { return failure(); } - if (innerInvOp.getNumBodyUnitaries() != 1) { + auto innerInner = + utils::getSoleBodyUnitary(*innerInvOp.getBody()); + if (!innerInner) { return failure(); } - auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + auto* innerInnerOp = innerInner.getOperation(); const auto numQubits = op.getNumQubits(); auto outerQubits = op.getInputQubits(); @@ -380,18 +386,11 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - return llvm::count_if( - *getBody(), [](Operation& op) { return isa(op); }); + return utils::getNumBodyUnitaries(*getBody()); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - auto unitaries = llvm::make_filter_range( - *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), static_cast(i)); - if (it == unitaries.end()) { - llvm::reportFatalUsageError("Unitary index out of bounds"); - } - return cast(*it); + return utils::getBodyUnitary(*getBody(), i); } Value InvOp::getInputQubit(const size_t i) { @@ -490,11 +489,7 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional InvOp::getUnitaryMatrix() { - if (getNumBodyUnitaries() != 1) { - return std::nullopt; - } - - auto bodyUnitary = getBodyUnitary(0); + auto bodyUnitary = utils::getSoleBodyUnitary(*getBody()); if (!bodyUnitary) { return std::nullopt; } From 13ee270cc663454e843b84602d464ecde5eb20a0 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 14:26:17 +0200 Subject: [PATCH 25/41] :recycle: Share region-inlining helper across dialect conversions Extract the "inline a region and convert its entry-block signature" boilerplate, previously duplicated inline in the QC-to-QCO and QCO-to-QC modifier conversions and defined locally in QCOToJeff, into a shared `moveRegion` helper in `mlir/Conversion/ConversionUtils.h`. Assisted-By: Claude Opus 4.8 (1M context) --- .../include/mlir/Conversion/ConversionUtils.h | 49 +++++++++++++++++++ mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 18 +------ mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 19 ++----- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 23 +++------ 4 files changed, 62 insertions(+), 47 deletions(-) create mode 100644 mlir/include/mlir/Conversion/ConversionUtils.h diff --git a/mlir/include/mlir/Conversion/ConversionUtils.h b/mlir/include/mlir/Conversion/ConversionUtils.h new file mode 100644 index 0000000000..1144ba66d0 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConversionUtils.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include +#include +#include + +namespace mlir { + +/** + * @brief Inlines @p source into @p dest and converts the entry block signature. + * + * @details Moves all blocks of @p source to the end of @p dest and converts the + * argument types of the resulting entry block using @p typeConverter. This is + * the canonical way to migrate a region from one dialect to another during a + * dialect conversion when the block arguments change type. + * + * @param source The region whose blocks are moved out. + * @param dest The region the blocks are moved into. + * @param rewriter The conversion rewriter driving the current pass. + * @param typeConverter The type converter used to convert the entry block + * signature. + * @return Whether converting the entry block signature succeeded. + */ +inline LogicalResult moveRegion(Region& source, Region& dest, + ConversionPatternRewriter& rewriter, + const TypeConverter* typeConverter) { + rewriter.inlineRegionBefore(source, dest, dest.end()); + auto* block = &dest.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed( + typeConverter->convertSignatureArgs(block->getArgumentTypes(), sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); + return success(); +} + +} // namespace mlir diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 96598b2c82..a6c947ad3d 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/QCOToJeff/QCOToJeff.h" +#include "mlir/Conversion/ConversionUtils.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" @@ -341,23 +342,6 @@ static LogicalResult cleanUp(Operation* op, LoweringState& state) { return success(); } -/** - * @brief Move a region from QCO/SCF operation to a jeff operation - */ -static LogicalResult moveRegion(Region& source, Region& dest, - ConversionPatternRewriter& rewriter, - const TypeConverter* typeConverter) { - rewriter.inlineRegionBefore(source, dest, dest.end()); - auto* block = &dest.front(); - TypeConverter::SignatureConversion sc(block->getNumArguments()); - if (failed( - typeConverter->convertSignatureArgs(block->getArgumentTypes(), sc))) { - return failure(); - } - rewriter.applySignatureConversion(block, sc); - return success(); -} - namespace { /** diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index a8ba24090d..768e2d2379 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/QCOToQC/QCOToQC.h" +#include "mlir/Conversion/ConversionUtils.h" #include "mlir/Conversion/GateTable.h" #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" @@ -645,15 +646,10 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { auto qcOp = qc::CtrlOp::create( rewriter, op.getLoc(), adaptor.getControlsIn(), adaptor.getTargetsIn()); - auto& dstRegion = qcOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); - auto* block = &dstRegion.front(); - TypeConverter::SignatureConversion sc(block->getNumArguments()); - if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - sc))) { + if (failed(moveRegion(op.getRegion(), qcOp.getRegion(), rewriter, + getTypeConverter()))) { return failure(); } - rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -688,15 +684,10 @@ struct ConvertQCOInvOp final : OpConversionPattern { // Create qc.inv operation auto qcOp = qc::InvOp::create(rewriter, op.getLoc(), adaptor.getQubitsIn()); - auto& dstRegion = qcOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); - auto* block = &dstRegion.front(); - TypeConverter::SignatureConversion sc(block->getNumArguments()); - if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - sc))) { + if (failed(moveRegion(op.getRegion(), qcOp.getRegion(), rewriter, + getTypeConverter()))) { return failure(); } - rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 65877a3120..f778521d06 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/QCToQCO/QCToQCO.h" +#include "mlir/Conversion/ConversionUtils.h" #include "mlir/Conversion/GateTable.h" #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" @@ -1110,16 +1111,11 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { auto qcArgs = op.getRegion().front().getArguments(); - // Inline region - auto& dstRegion = qcoOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); - auto* block = &dstRegion.front(); - TypeConverter::SignatureConversion sc(block->getNumArguments()); - if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - sc))) { + // Inline region and convert the block signature to QCO types. + if (failed(moveRegion(op.getRegion(), qcoOp.getRegion(), rewriter, + getTypeConverter()))) { return failure(); } - rewriter.applySignatureConversion(block, sc); pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); @@ -1163,16 +1159,11 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { auto qcArgs = op.getRegion().front().getArguments(); - // Inline region - auto& dstRegion = qcoOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); - auto* block = &dstRegion.front(); - TypeConverter::SignatureConversion sc(block->getNumArguments()); - if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - sc))) { + // Inline region and convert the block signature to QCO types. + if (failed(moveRegion(op.getRegion(), qcoOp.getRegion(), rewriter, + getTypeConverter()))) { return failure(); } - rewriter.applySignatureConversion(block, sc); pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); From 822055f62179dac946b17702f916197bb567f093 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 14:50:35 +0200 Subject: [PATCH 26/41] :recycle: Pull out an `inlineModifierBody` helper for simplifying canonicalizations of modifier operations Assisted-By: Claude Opus 4.8 (1M context) Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/Utils/Utils.h | 18 ++++++++ mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 12 +----- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 44 +++++++------------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 17 +++----- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 44 +++++++------------- 5 files changed, 55 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 2bd3ebcb9f..1d5cd96884 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -243,4 +244,21 @@ template return unitary; } +/** + * @brief Inlines a modifier body and replaces the modifier with its results. + * + * @details Inlines the operations of @p body in front of @p op, substituting + * the block arguments of @p body with @p blockArgReplacements, and replaces + * @p op with the values yielded by the body's terminator. + */ +inline void inlineModifierBody(Operation* op, Block& body, + ValueRange blockArgReplacements, + RewriterBase& rewriter) { + auto* terminator = body.getTerminator(); + const SmallVector results(terminator->getOperands()); + rewriter.inlineBlockBefore(&body, op, blockArgReplacements); + rewriter.eraseOp(terminator); + rewriter.replaceOp(op, results); +} + } // namespace mlir::utils diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 4d74d11a79..80549da414 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -101,17 +101,7 @@ struct ReduceCtrl final : OpRewritePattern { // Inline ops from empty control modifiers, IdOp and BarrierOp if (op.getNumControls() == 0 || isa(innerOp)) { - const auto numTargets = op.getNumTargets(); - auto outerTargets = op.getTargets(); - SmallVector targets; - for (auto target : innerOp->getOperands().take_front(numTargets)) { - targets.push_back( - utils::getValueFromBlockArgument(target, outerTargets)); - } - - rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, numTargets, targets); - rewriter.eraseOp(op); + utils::inlineModifierBody(op, *op.getBody(), op.getTargets(), rewriter); return success(); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 39d7541b23..db53f95fb2 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -90,22 +90,15 @@ struct InlineSelfAdjoint final : OpRewritePattern { if (!inner) { return failure(); } - auto* innerOp = inner.getOperation(); - if (!isa(innerOp)) { + if (!isa( + inner.getOperation())) { return failure(); } - const auto numQubits = op.getNumQubits(); - auto outerQubits = op.getQubits(); - SmallVector qubits; - for (auto qubit : innerOp->getOperands().take_front(numQubits)) { - qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); - } - - rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, numQubits, qubits); - rewriter.replaceOp(op, innerOp->getResults()); + // A self-adjoint gate is its own inverse, so the modifier can be dropped + // and its body applied directly to the involved qubits. + utils::inlineModifierBody(op, *op.getBody(), op.getQubits(), rewriter); return success(); } }; @@ -318,27 +311,20 @@ struct CancelNestedInv final : OpRewritePattern { if (!innerInvOp) { return failure(); } - - auto innerInner = - utils::getSoleBodyUnitary(*innerInvOp.getBody()); - if (!innerInner) { + if (!utils::getSoleBodyUnitary(*innerInvOp.getBody())) { return failure(); } - auto* innerInnerOp = innerInner.getOperation(); - const auto numQubits = op.getNumQubits(); - auto outerQubits = op.getQubits(); - auto innerQubits = innerInvOp.getQubits(); - SmallVector qubits; - for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { - auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); - qubits.push_back( - utils::getValueFromBlockArgument(innerQubit, outerQubits)); + // inv(inv(x)) == x: inline the doubly-nested body directly onto the outer + // qubits. The inner body's block arguments alias the inner modifier's + // inputs, which in turn alias the outer qubits. + SmallVector replacements; + for (auto innerInput : innerInvOp.getQubits()) { + replacements.push_back( + utils::getValueFromBlockArgument(innerInput, op.getQubits())); } - - rewriter.moveOpBefore(innerInnerOp, op); - innerInnerOp->setOperands(0, numQubits, qubits); - rewriter.replaceOp(op, innerInnerOp->getResults()); + utils::inlineModifierBody(op, *innerInvOp.getBody(), replacements, + rewriter); return success(); } }; diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 1f99383994..935d4e43b2 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -114,18 +114,13 @@ struct ReduceCtrl final : OpRewritePattern { // Inline ops from empty control modifiers, IdOp and BarrierOp if (op.getNumControls() == 0 || isa(innerOp)) { - const auto numTargets = op.getNumTargets(); - auto outerTargets = op.getTargetsIn(); - SmallVector targets; - for (auto target : innerOp->getOperands().take_front(numTargets)) { - targets.push_back( - utils::getValueFromBlockArgument(target, outerTargets)); - } - - rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, numTargets, targets); + auto* body = op.getBody(); + auto* terminator = body->getTerminator(); + const SmallVector targets(terminator->getOperands()); + rewriter.inlineBlockBefore(body, op, op.getTargetsIn()); + rewriter.eraseOp(terminator); rewriter.replaceAllUsesWith(op.getControlsOut(), op.getControlsIn()); - rewriter.replaceAllUsesWith(op.getTargetsOut(), innerOp->getResults()); + rewriter.replaceAllUsesWith(op.getTargetsOut(), targets); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index dd0d01c984..924f80eaa2 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -102,22 +102,15 @@ struct InlineSelfAdjoint final : OpRewritePattern { if (!inner) { return failure(); } - auto* innerOp = inner.getOperation(); - if (!isa(innerOp)) { + if (!isa( + inner.getOperation())) { return failure(); } - const auto numQubits = op.getNumQubits(); - auto outerQubits = op.getInputQubits(); - SmallVector qubits; - for (auto qubit : innerOp->getOperands().take_front(numQubits)) { - qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); - } - - rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, numQubits, qubits); - rewriter.replaceOp(op, innerOp->getResults()); + // A self-adjoint gate is its own inverse, so the modifier can be dropped + // and its body applied directly to the input qubits. + utils::inlineModifierBody(op, *op.getBody(), op.getInputQubits(), rewriter); return success(); } }; @@ -342,27 +335,20 @@ struct CancelNestedInv final : OpRewritePattern { if (!innerInvOp) { return failure(); } - - auto innerInner = - utils::getSoleBodyUnitary(*innerInvOp.getBody()); - if (!innerInner) { + if (!utils::getSoleBodyUnitary(*innerInvOp.getBody())) { return failure(); } - auto* innerInnerOp = innerInner.getOperation(); - const auto numQubits = op.getNumQubits(); - auto outerQubits = op.getInputQubits(); - auto innerQubits = innerInvOp.getInputQubits(); - SmallVector qubits; - for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { - auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); - qubits.push_back( - utils::getValueFromBlockArgument(innerQubit, outerQubits)); + // inv(inv(x)) == x: inline the doubly-nested body directly onto the outer + // input qubits. The inner body's block arguments alias the inner modifier's + // inputs, which in turn alias the outer input qubits. + SmallVector replacements; + for (auto innerInput : innerInvOp.getInputQubits()) { + replacements.push_back( + utils::getValueFromBlockArgument(innerInput, op.getInputQubits())); } - - rewriter.moveOpBefore(innerInnerOp, op); - innerInnerOp->setOperands(0, numQubits, qubits); - rewriter.replaceOp(op, innerInnerOp->getResults()); + utils::inlineModifierBody(op, *innerInvOp.getBody(), replacements, + rewriter); return success(); } }; From ff1e4f5486c6b5e1bb0e817b861353d6d4a140e8 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 16:27:16 +0200 Subject: [PATCH 27/41] :recycle: Simplify CtrlOp and InvOp canonicalizations by relying more on `inlineRegionBefore` Assisted-By: Claude Opus 4.8 (1M context) Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/Utils/Utils.h | 24 -- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 43 +-- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 338 ++++++++--------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 70 ++-- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 362 ++++++++----------- 5 files changed, 354 insertions(+), 483 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 1d5cd96884..33562fb197 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -15,7 +15,6 @@ #include #include #include -#include #include #include #include @@ -177,29 +176,6 @@ inline Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { return qubit; } -/** - * @brief Create a mapping between block arguments and qubit values. - * - * @details This helper function is used to resolve block arguments for nested - * modifiers. - */ -inline void populateMapping(IRMapping& mapping, Block& block, - ValueRange innerQubits, ValueRange outerQubits, - ValueRange newQubits, ValueRange qubitArgs) { - assert(innerQubits.size() == block.getNumArguments() && - "Size of innerQubits must match number of block arguments"); - for (auto arg : block.getArguments()) { - auto innerQubit = innerQubits[arg.getArgNumber()]; - auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); - if (auto it = llvm::find(newQubits, outerQubit); it != newQubits.end()) { - auto index = std::distance(newQubits.begin(), it); - mapping.map(arg, qubitArgs[index]); - } else { - llvm::reportFatalInternalError("Outer qubit not found in new qubits"); - } - } -} - /** * @brief Returns the number of operations implementing @p UnitaryInterface in * @p block. diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 80549da414..db9fed7a8c 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -53,33 +52,27 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - auto outerControls = op.getControls(); + // The inner control's controls and targets are block arguments of the outer + // body that alias outer targets. Re-resolve them to the outer qubits: inner + // controls join the outer controls, inner targets become the merged + // targets. Keeping the inner-target order lets the inner body be reused + // verbatim, since its block arguments already line up with the merged + // targets. auto outerTargets = op.getTargets(); - auto innerTargets = innerCtrlOp.getTargets(); - - SmallVector controls; - SmallVector targets; - llvm::append_range(controls, outerControls); - for (auto [arg, qubit] : - llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { - if (llvm::is_contained(innerTargets, arg)) { - targets.push_back(qubit); - } else { - controls.push_back(qubit); - } + SmallVector controls(op.getControls()); + for (auto control : innerCtrlOp.getControls()) { + controls.push_back( + utils::getValueFromBlockArgument(control, outerTargets)); } - - rewriter.replaceOpWithNewOp( - op, controls, targets, [&](ValueRange targetArgs) { - auto* innerCtrlBody = innerCtrlOp.getBody(); - IRMapping mapping; - utils::populateMapping(mapping, *innerCtrlBody, innerTargets, - outerTargets, targets, targetArgs); - for (auto& op : innerCtrlBody->without_terminator()) { - rewriter.clone(op, mapping); - } + const auto targets = + llvm::map_to_vector(innerCtrlOp.getTargets(), [&](Value t) { + return utils::getValueFromBlockArgument(t, outerTargets); }); + auto merged = CtrlOp::create(rewriter, op.getLoc(), controls, targets); + rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), merged.getRegion(), + merged.getRegion().end()); + rewriter.eraseOp(op); return success(); } }; @@ -141,7 +134,7 @@ struct ReduceCtrl final : OpRewritePattern { }; /** - * @brief Erase control modifiers that do not have any body unitaries. + * @brief Erase control modifiers without unitary operations in the body. */ struct EraseEmptyCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index db53f95fb2..3ae36f1fd6 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -51,25 +50,26 @@ struct MoveCtrlOutside final : OpRewritePattern { return failure(); } - const auto numControls = innerCtrlOp.getNumControls(); - const auto numTargets = innerCtrlOp.getNumTargets(); + // The inner control's controls and targets are block arguments aliasing the + // inverse modifier's qubits. Pull the controls out to a new control + // modifier and wrap the inner body in an inverse modifier whose block + // arguments match the inner targets, so the inner body is reused verbatim. auto outerQubits = op.getQubits(); - auto controls = outerQubits.take_front(numControls); - auto targets = outerQubits.take_back(numTargets); + const auto controls = + llvm::map_to_vector(innerCtrlOp.getControls(), [&](Value c) { + return utils::getValueFromBlockArgument(c, outerQubits); + }); + const auto targets = + llvm::map_to_vector(innerCtrlOp.getTargets(), [&](Value t) { + return utils::getValueFromBlockArgument(t, outerQubits); + }); rewriter.replaceOpWithNewOp( op, controls, targets, [&](ValueRange targetArgs) { - InvOp::create( - rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { - auto* innerCtrlBody = innerCtrlOp.getBody(); - IRMapping mapping; - utils::populateMapping(mapping, *innerCtrlBody, - innerCtrlOp.getTargets(), outerQubits, - targets, qubitArgs); - for (auto& op : innerCtrlBody->without_terminator()) { - rewriter.clone(op, mapping); - } - }); + auto innerInv = InvOp::create(rewriter, op.getLoc(), targetArgs); + rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), + innerInv.getRegion(), + innerInv.getRegion().end()); }); return success(); @@ -120,179 +120,132 @@ struct ReplaceWithKnownGates final : OpRewritePattern { } auto* innerOp = inner.getOperation(); - auto loc = op.getLoc(); - auto outerQubits = op.getQubits(); + // Replace the body gate in place with its inverse, operating on the same + // (block-argument) operands; inlining the body afterwards substitutes those + // block arguments with the modifier's qubits. + const auto loc = innerOp->getLoc(); + rewriter.setInsertionPoint(innerOp); + const auto negTheta = [&](auto g) { + return arith::NegFOp::create(rewriter, loc, g.getTheta()).getResult(); + }; + const auto replaced = + TypeSwitch(innerOp) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), negTheta(g), + g.getPhi()); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + Value newPhi = + arith::NegFOp::create(rewriter, loc, g.getLambda()); + Value newLambda = + arith::NegFOp::create(rewriter, loc, g.getPhi()); + Value newTheta = + arith::NegFOp::create(rewriter, loc, g.getTheta()); + rewriter.replaceOpWithNewOp(g, g.getTarget(0), newTheta, + newPhi, newLambda); + return success(); + }) + .Case([&](auto g) { + Value pi = arith::ConstantOp::create( + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = + arith::NegFOp::create(rewriter, loc, g.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = + arith::NegFOp::create(rewriter, loc, g.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp(g, g.getTarget(0), newPhi, + newLambda); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(1), + g.getTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), + g.getTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), + g.getTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), + g.getTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getTarget(0), + g.getTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getTarget(0), g.getTarget(1), negTheta(g), g.getBeta()); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getTarget(0), g.getTarget(1), negTheta(g), g.getBeta()); + return success(); + }) + .Default([&](auto) { return failure(); }); + + if (failed(replaced)) { + return failure(); + } - return TypeSwitch(innerOp) - .Case([&](auto g) { - Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); - rewriter.replaceOpWithNewOp(op, negTheta); - return success(); - }) - .Case([&](auto t) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(t.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto tdg) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(tdg.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto s) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(s.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto sdg) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(sdg.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto sx) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(sx.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto sxdg) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(sxdg.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto p) { - Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(p.getTarget(0), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto r) { - auto negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(r.getTarget(0), outerQubits), - negTheta, r.getPhi()); - return success(); - }) - .Case([&](auto rx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rx.getTarget(0), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto u) { - Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); - Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); - Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(u.getTarget(0), outerQubits), - newTheta, newPhi, newLambda); - return success(); - }) - .Case([&](auto u2) { - Value pi = arith::ConstantOp::create( - rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); - newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); - Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); - newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(u2.getTarget(0), outerQubits), - newPhi, newLambda); - return success(); - }) - .Case([&](auto dcx) { - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(dcx.getTarget(1), outerQubits), - utils::getValueFromBlockArgument(dcx.getTarget(0), outerQubits)); - return success(); - }) - .Case([&](auto rxx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rxx.getTarget(0), outerQubits), - utils::getValueFromBlockArgument(rxx.getTarget(1), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto ry) { - Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(ry.getTarget(0), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto ryy) { - Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(ryy.getTarget(0), outerQubits), - utils::getValueFromBlockArgument(ryy.getTarget(1), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rz) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rz.getTarget(0), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rzx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rzx.getTarget(0), outerQubits), - utils::getValueFromBlockArgument(rzx.getTarget(1), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rzz) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rzz.getTarget(0), outerQubits), - utils::getValueFromBlockArgument(rzz.getTarget(1), outerQubits), - negTheta); - return success(); - }) - .Case([&](auto xxminusyy) { - Value negTheta = - arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(xxminusyy.getTarget(0), - outerQubits), - utils::getValueFromBlockArgument(xxminusyy.getTarget(1), - outerQubits), - negTheta, xxminusyy.getBeta()); - return success(); - }) - .Case([&](auto xxplusyy) { - Value negTheta = - arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(xxplusyy.getTarget(0), - outerQubits), - utils::getValueFromBlockArgument(xxplusyy.getTarget(1), - outerQubits), - negTheta, xxplusyy.getBeta()); - return success(); - }) - .Default([&](auto) { return failure(); }); + utils::inlineModifierBody(op, *op.getBody(), op.getQubits(), rewriter); + return success(); } }; @@ -318,11 +271,10 @@ struct CancelNestedInv final : OpRewritePattern { // inv(inv(x)) == x: inline the doubly-nested body directly onto the outer // qubits. The inner body's block arguments alias the inner modifier's // inputs, which in turn alias the outer qubits. - SmallVector replacements; - for (auto innerInput : innerInvOp.getQubits()) { - replacements.push_back( - utils::getValueFromBlockArgument(innerInput, op.getQubits())); - } + const auto replacements = + llvm::map_to_vector(innerInvOp.getQubits(), [&](Value q) { + return utils::getValueFromBlockArgument(q, op.getQubits()); + }); utils::inlineModifierBody(op, *innerInvOp.getBody(), replacements, rewriter); return success(); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 935d4e43b2..b1fb647d95 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -60,39 +59,54 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } + // The inner control's controls and targets are block arguments of the outer + // body that alias outer targets. Re-resolve them to the outer qubits: inner + // controls join the outer controls, inner targets become the merged + // targets. Inner-target order is kept so the inner body's block arguments + // line up with the merged targets and the body can be reused verbatim. auto outerTargets = op.getTargetsIn(); - auto outerControls = op.getControlsIn(); + auto innerControls = innerCtrlOp.getControlsIn(); auto innerTargets = innerCtrlOp.getTargetsIn(); - SmallVector controls; - SmallVector targets; - llvm::append_range(controls, outerControls); - for (auto [arg, qubit] : - llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { - if (llvm::is_contained(innerTargets, arg)) { - targets.push_back(qubit); + SmallVector controls(op.getControlsIn()); + for (auto control : innerControls) { + controls.push_back( + utils::getValueFromBlockArgument(control, outerTargets)); + } + const auto targets = llvm::map_to_vector(innerTargets, [&](Value t) { + return utils::getValueFromBlockArgument(t, outerTargets); + }); + + auto merged = CtrlOp::create(rewriter, op.getLoc(), controls, targets); + rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), merged.getRegion(), + merged.getRegion().end()); + + // Outer and inner controls pass through to the merged controls; each outer + // target follows its block argument to either a merged control output (if + // it was an inner control) or a merged target output (if it was an inner + // target). + const auto numOuterControls = op.getNumControls(); + rewriter.replaceAllUsesWith( + op.getControlsOut(), + merged.getControlsOut().take_front(numOuterControls)); + auto innerControlsOut = + merged.getControlsOut().drop_front(numOuterControls); + auto mergedTargetsOut = merged.getTargetsOut(); + for (auto [blockArg, outerTargetOut] : + llvm::zip_equal(op.getBody()->getArguments(), op.getTargetsOut())) { + if (auto it = llvm::find(innerControls, blockArg); + it != innerControls.end()) { + rewriter.replaceAllUsesWith( + outerTargetOut, + innerControlsOut[std::distance(innerControls.begin(), it)]); } else { - controls.push_back(qubit); + const auto it2 = llvm::find(innerTargets, blockArg); + rewriter.replaceAllUsesWith( + outerTargetOut, + mergedTargetsOut[std::distance(innerTargets.begin(), it2)]); } } - - rewriter.replaceOpWithNewOp( - op, controls, targets, - [&](ValueRange targetArgs) -> SmallVector { - auto* innerCtrlBody = innerCtrlOp.getBody(); - IRMapping mapping; - utils::populateMapping(mapping, *innerCtrlBody, innerTargets, - outerTargets, targets, targetArgs); - for (auto& op : innerCtrlBody->without_terminator()) { - rewriter.clone(op, mapping); - } - SmallVector yields; - for (auto value : innerCtrlBody->getTerminator()->getOperands()) { - yields.push_back(mapping.lookup(value)); - } - return yields; - }); - + rewriter.eraseOp(op); return success(); } }; diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 924f80eaa2..d0f7132f81 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -54,34 +53,29 @@ struct MoveCtrlOutside final : OpRewritePattern { return failure(); } - const auto numControls = innerCtrlOp.getNumControls(); - const auto numTargets = innerCtrlOp.getNumTargets(); + // inv(ctrl(x)) == ctrl(inv(x)). The inner control's controls and targets + // are block arguments aliasing the inverse modifier's qubits. Pull the + // controls out to a new control modifier and wrap the inner body in an + // inverse modifier whose block arguments match the inner targets, so the + // inner body is reused verbatim. auto outerQubits = op.getQubitsIn(); - auto controls = outerQubits.take_front(numControls); - auto targets = outerQubits.take_back(numTargets); + const auto controls = + llvm::map_to_vector(innerCtrlOp.getControlsIn(), [&](Value c) { + return utils::getValueFromBlockArgument(c, outerQubits); + }); + const auto targets = + llvm::map_to_vector(innerCtrlOp.getTargetsIn(), [&](Value t) { + return utils::getValueFromBlockArgument(t, outerQubits); + }); rewriter.replaceOpWithNewOp( op, controls, targets, [&](ValueRange targetArgs) -> SmallVector { - return InvOp::create( - rewriter, op.getLoc(), targetArgs, - [&](ValueRange qubitArgs) -> SmallVector { - auto* innerCtrlBody = innerCtrlOp.getBody(); - IRMapping mapping; - utils::populateMapping(mapping, *innerCtrlBody, - innerCtrlOp.getTargetsIn(), - outerQubits, targets, qubitArgs); - for (auto& op : innerCtrlBody->without_terminator()) { - rewriter.clone(op, mapping); - } - SmallVector yields; - for (auto value : - innerCtrlBody->getTerminator()->getOperands()) { - yields.push_back(mapping.lookup(value)); - } - return yields; - }) - .getResults(); + auto innerInv = InvOp::create(rewriter, op.getLoc(), targetArgs); + rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), + innerInv.getRegion(), + innerInv.getRegion().end()); + return innerInv.getResults(); }); return success(); @@ -132,190 +126,133 @@ struct ReplaceWithKnownGates final : OpRewritePattern { } auto* innerOp = inner.getOperation(); - auto loc = op.getLoc(); - auto outerQubits = op.getInputQubits(); - - return TypeSwitch(innerOp) - .Case([&](auto g) { - Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); - rewriter.replaceOpWithNewOp(op, negTheta); - return success(); - }) - .Case([&](auto t) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(t.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto tdg) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(tdg.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto s) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(s.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto sdg) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(sdg.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto sx) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(sx.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto sxdg) { - rewriter.replaceOpWithNewOp( - op, utils::getValueFromBlockArgument(sxdg.getInputTarget(0), - outerQubits)); - return success(); - }) - .Case([&](auto p) { - Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(p.getInputTarget(0), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto r) { - Value negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(r.getInputTarget(0), - outerQubits), - negTheta, r.getPhi()); - return success(); - }) - .Case([&](auto rx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rx.getInputTarget(0), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto u) { - Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); - Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); - Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(u.getInputTarget(0), - outerQubits), - newTheta, newPhi, newLambda); - return success(); - }) - .Case([&](auto u2) { - auto pi = arith::ConstantOp::create( - rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); - newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); - Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); - newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(u2.getInputTarget(0), - outerQubits), - newPhi, newLambda); - return success(); - }) - .Case([&](auto rxx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rxx.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(rxx.getInputTarget(1), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto ry) { - Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(ry.getInputTarget(0), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto ryy) { - Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(ryy.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(ryy.getInputTarget(1), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rz) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rz.getInputTarget(0), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rzx) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rzx.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(rzx.getInputTarget(1), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto rzz) { - Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(rzz.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(rzz.getInputTarget(1), - outerQubits), - negTheta); - return success(); - }) - .Case([&](auto xxminusyy) { - Value negTheta = - arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(xxminusyy.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(xxminusyy.getInputTarget(1), - outerQubits), - negTheta, xxminusyy.getBeta()); - return success(); - }) - .Case([&](auto xxplusyy) { - Value negTheta = - arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp( - op, - utils::getValueFromBlockArgument(xxplusyy.getInputTarget(0), - outerQubits), - utils::getValueFromBlockArgument(xxplusyy.getInputTarget(1), - outerQubits), - negTheta, xxplusyy.getBeta()); - return success(); - }) - .Default([&](auto) { return failure(); }); + // Replace the body gate in place with its inverse, operating on the same + // (block-argument) operands; inlining the body afterwards substitutes those + // block arguments with the modifier's input qubits. + const auto loc = innerOp->getLoc(); + rewriter.setInsertionPoint(innerOp); + const auto negTheta = [&](auto g) { + return arith::NegFOp::create(rewriter, loc, g.getTheta()).getResult(); + }; + const auto replaced = + TypeSwitch(innerOp) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + negTheta(g), g.getPhi()); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + negTheta(g)); + return success(); + }) + .Case([&](auto g) { + Value newPhi = + arith::NegFOp::create(rewriter, loc, g.getLambda()); + Value newLambda = + arith::NegFOp::create(rewriter, loc, g.getPhi()); + Value newTheta = + arith::NegFOp::create(rewriter, loc, g.getTheta()); + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), newTheta, + newPhi, newLambda); + return success(); + }) + .Case([&](auto g) { + Value pi = arith::ConstantOp::create( + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = + arith::NegFOp::create(rewriter, loc, g.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = + arith::NegFOp::create(rewriter, loc, g.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), newPhi, + newLambda); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getInputTarget(0), g.getInputTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getInputTarget(0), g.getInputTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getInputTarget(0), g.getInputTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getInputTarget(0), g.getInputTarget(1), negTheta(g)); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp( + g, g.getInputTarget(0), g.getInputTarget(1), negTheta(g), + g.getBeta()); + return success(); + }) + .Case([&](auto g) { + rewriter.replaceOpWithNewOp(g, g.getInputTarget(0), + g.getInputTarget(1), + negTheta(g), g.getBeta()); + return success(); + }) + .Default([&](auto) { return failure(); }); + + if (failed(replaced)) { + return failure(); + } + + utils::inlineModifierBody(op, *op.getBody(), op.getInputQubits(), rewriter); + return success(); } }; @@ -342,11 +279,10 @@ struct CancelNestedInv final : OpRewritePattern { // inv(inv(x)) == x: inline the doubly-nested body directly onto the outer // input qubits. The inner body's block arguments alias the inner modifier's // inputs, which in turn alias the outer input qubits. - SmallVector replacements; - for (auto innerInput : innerInvOp.getInputQubits()) { - replacements.push_back( - utils::getValueFromBlockArgument(innerInput, op.getInputQubits())); - } + const auto replacements = + llvm::map_to_vector(innerInvOp.getInputQubits(), [&](Value q) { + return utils::getValueFromBlockArgument(q, op.getInputQubits()); + }); utils::inlineModifierBody(op, *innerInvOp.getBody(), replacements, rewriter); return success(); From 55b78ae835e73427506495d9b2f2f6d678cdfcdd Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 17:40:34 +0200 Subject: [PATCH 28/41] :recycle: Further simplify CtrlOp and InvOp canonicalizations by relying more on `getOutputForInput` Assisted-By: Claude Opus 4.8 (1M context) Signed-off-by: burgholzer --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 40 +++++--------------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 10 ++++- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index b1fb647d95..6604e70a46 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include using namespace mlir; @@ -81,32 +80,12 @@ struct MergeNestedCtrl final : OpRewritePattern { rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), merged.getRegion(), merged.getRegion().end()); - // Outer and inner controls pass through to the merged controls; each outer - // target follows its block argument to either a merged control output (if - // it was an inner control) or a merged target output (if it was an inner - // target). - const auto numOuterControls = op.getNumControls(); - rewriter.replaceAllUsesWith( - op.getControlsOut(), - merged.getControlsOut().take_front(numOuterControls)); - auto innerControlsOut = - merged.getControlsOut().drop_front(numOuterControls); - auto mergedTargetsOut = merged.getTargetsOut(); - for (auto [blockArg, outerTargetOut] : - llvm::zip_equal(op.getBody()->getArguments(), op.getTargetsOut())) { - if (auto it = llvm::find(innerControls, blockArg); - it != innerControls.end()) { - rewriter.replaceAllUsesWith( - outerTargetOut, - innerControlsOut[std::distance(innerControls.begin(), it)]); - } else { - const auto it2 = llvm::find(innerTargets, blockArg); - rewriter.replaceAllUsesWith( - outerTargetOut, - mergedTargetsOut[std::distance(innerTargets.begin(), it2)]); - } - } - rewriter.eraseOp(op); + // Every qubit output of the original control follows its input qubit to the + // corresponding output of the merged control. + rewriter.replaceOp(op, + llvm::map_to_vector(op.getInputQubits(), [&](Value in) { + return merged.getOutputForInput(in); + })); return success(); } }; @@ -130,12 +109,11 @@ struct ReduceCtrl final : OpRewritePattern { if (op.getNumControls() == 0 || isa(innerOp)) { auto* body = op.getBody(); auto* terminator = body->getTerminator(); - const SmallVector targets(terminator->getOperands()); + SmallVector outputs(op.getControlsIn()); + llvm::append_range(outputs, terminator->getOperands()); rewriter.inlineBlockBefore(body, op, op.getTargetsIn()); rewriter.eraseOp(terminator); - rewriter.replaceAllUsesWith(op.getControlsOut(), op.getControlsIn()); - rewriter.replaceAllUsesWith(op.getTargetsOut(), targets); - rewriter.eraseOp(op); + rewriter.replaceOp(op, outputs); return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d0f7132f81..df9ea3d785 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -68,8 +68,8 @@ struct MoveCtrlOutside final : OpRewritePattern { return utils::getValueFromBlockArgument(t, outerQubits); }); - rewriter.replaceOpWithNewOp( - op, controls, targets, + auto newCtrl = CtrlOp::create( + rewriter, op.getLoc(), controls, targets, [&](ValueRange targetArgs) -> SmallVector { auto innerInv = InvOp::create(rewriter, op.getLoc(), targetArgs); rewriter.inlineRegionBefore(innerCtrlOp.getRegion(), @@ -78,6 +78,12 @@ struct MoveCtrlOutside final : OpRewritePattern { return innerInv.getResults(); }); + // Each qubit output of the inverse modifier follows its input qubit to the + // corresponding output of the new control modifier. + rewriter.replaceOp(op, + llvm::map_to_vector(op.getInputQubits(), [&](Value in) { + return newCtrl.getOutputForInput(in); + })); return success(); } }; From 2db1271772d7ca40058e0249e03945b96c4182cf Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 17:58:04 +0200 Subject: [PATCH 29/41] :recycle: Simplify verifiers of modifier operations Signed-off-by: burgholzer --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 16 +++------------ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 5 ----- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 21 +++++++------------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 9 +++------ 4 files changed, 13 insertions(+), 38 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index db9fed7a8c..88a286b05f 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -195,26 +195,16 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, } LogicalResult CtrlOp::verify() { - auto& block = *getBody(); if (llvm::any_of(*getBody(), [](Operation& op) { return isa(op); })) { return emitOpError("body must not contain non-unitary quantum operations"); } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } SmallPtrSet uniqueQubits; - for (const auto& control : getControls()) { - if (!uniqueQubits.insert(control).second) { - return emitOpError("duplicate control qubit found"); - } - } - for (const auto& target : getTargets()) { - if (!uniqueQubits.insert(target).second) { - return emitOpError("duplicate target qubit found"); + for (const auto& qubit : getQubits()) { + if (!uniqueQubits.insert(qubit).second) { + return emitOpError("duplicate qubit found"); } } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 3ae36f1fd6..b8f49fd41f 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -325,16 +325,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, } LogicalResult InvOp::verify() { - auto& block = *getBody(); if (llvm::any_of(*getBody(), [](Operation& op) { return isa(op); })) { return emitOpError("body must not contain non-unitary quantum operations"); } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 6604e70a46..63deb3f816 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -285,15 +285,11 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (llvm::any_of(*getBody(), [](Operation& op) { + if (llvm::any_of(block, [](Operation& op) { return isa(op); })) { return emitOpError("body must not contain non-unitary quantum operations"); } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { @@ -307,21 +303,17 @@ LogicalResult CtrlOp::verify() { << i << " does not match target type"; } } - if (const auto numYieldOperands = block.back().getNumOperands(); + auto blockTerminator = block.getTerminator(); + if (const auto numYieldOperands = blockTerminator->getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } SmallPtrSet uniqueQubitsIn; - for (const auto& control : getControlsIn()) { + for (const auto& control : getInputQubits()) { if (!uniqueQubitsIn.insert(control).second) { - return emitOpError("duplicate control qubit found"); - } - } - for (const auto& target : getTargetsIn()) { - if (!uniqueQubitsIn.insert(target).second) { - return emitOpError("duplicate target qubit found"); + return emitOpError("duplicate qubit found"); } } @@ -331,8 +323,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } + for (size_t i = 0; i < numTargets; i++) { - if (!uniqueQubitsOut.insert(block.back().getOperand(i)).second) { + if (!uniqueQubitsOut.insert(blockTerminator->getOperand(i)).second) { return emitOpError("duplicate qubit found"); } } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index df9ea3d785..fb8aaef198 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -372,15 +372,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (llvm::any_of(*getBody(), [](Operation& op) { + if (llvm::any_of(block, [](Operation& op) { return isa(op); })) { return emitOpError("body must not contain non-unitary quantum operations"); } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { @@ -394,7 +390,8 @@ LogicalResult InvOp::verify() { << i << " does not match target type"; } } - if (const auto numYieldOperands = block.back().getNumOperands(); + auto blockTerminator = block.getTerminator(); + if (const auto numYieldOperands = blockTerminator->getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; From 0b3d8af38cda56b4778f4e935ab1ba770e38aadb Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 19:00:56 +0200 Subject: [PATCH 30/41] :art: Small simplification in QIR conversion Signed-off-by: burgholzer --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index e7a6d2910f..45a71ab2ba 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -874,9 +874,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { // Update modifier information state.inCtrlOp = op.getNumBodyUnitaries(); - const SmallVector controls(adaptor.getControls().begin(), - adaptor.getControls().end()); - state.controls = controls; + state.controls = llvm::to_vector(adaptor.getControls()); // Inline block and remove operation rewriter.inlineBlockBefore(&op.getRegion().front(), op, From da147e0fd246eadfa578a52d5f7391be80f05c62 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 19:06:11 +0200 Subject: [PATCH 31/41] :art: Simplify implementation of `getQubit` and `getControl` for QC CtrlOp Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 6 +++--- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 18 ------------------ 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index dd120e52a5..e3fcbec359 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -957,10 +957,10 @@ def CtrlOp size_t getNumQubits() { return getNumTargets() + getNumControls(); } size_t getNumTargets() { return getTargets().size(); } size_t getNumControls() { return getControls().size(); } - Value getQubit(size_t i); - Value getTarget(size_t i) { return getTargets()[i]; } - Value getControl(size_t i); OperandRange getQubits() { return getOperands(); } + Value getQubit(size_t i) { return getQubits()[i]; } + Value getTarget(size_t i) { return getTargets()[i]; } + Value getControl(size_t i) { return getControls()[i]; } size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } OperandRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 88a286b05f..395f851c76 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -159,24 +159,6 @@ UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { return utils::getBodyUnitary(*getBody(), i); } -Value CtrlOp::getQubit(const size_t i) { - const auto numControls = getNumControls(); - if (i < numControls) { - return getControls()[i]; - } - if (numControls <= i && i < getNumQubits()) { - return getTarget(i - numControls); - } - llvm::reportFatalUsageError("Qubit index out of bounds"); -} - -Value CtrlOp::getControl(const size_t i) { - if (i >= getNumControls()) { - llvm::reportFatalUsageError("Control index out of bounds"); - } - return getControls()[i]; -} - void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, ValueRange controls, ValueRange targets, const function_ref& body) { From 15ce87b1689b4f716b7ccef90e4566fba1f60e12 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 20:01:32 +0200 Subject: [PATCH 32/41] :art: Simplify implementation of UnitaryOpInterface methods for QCO BarrierOp, CtrlOp, and InvOp Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 20 ++--- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 73 ++----------------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 27 ++----- .../IR/Operations/StandardGates/BarrierOp.cpp | 27 ++----- 4 files changed, 28 insertions(+), 119 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 3e4980b2b0..17bab8174b 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1017,8 +1017,8 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { OperandRange getInputTargets() { return getInputQubits(); } Value getOutputQubit(size_t i) { return getOutputTarget(i); } ResultRange getOutputQubits() { return getQubitsOut(); } - Value getInputTarget(size_t i); - Value getOutputTarget(size_t i); + Value getInputTarget(size_t i) { return getQubitsIn()[i]; } + Value getOutputTarget(size_t i) { return getQubitsOut()[i]; } static Value getInputControl(size_t i) { llvm::reportFatalUsageError("BarrierOp cannot be controlled"); } static OperandRange getInputControls() { return {nullptr, 0}; } static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("BarrierOp cannot be controlled"); } @@ -1108,16 +1108,16 @@ def CtrlOp : QCOOp<"ctrl", size_t getNumQubits() { return getNumControls() + getNumTargets(); } size_t getNumTargets() { return getTargetsIn().size(); } size_t getNumControls() { return getControlsIn().size(); } - Value getInputQubit(size_t i); OperandRange getInputQubits() { return getOperands(); } + Value getInputQubit(size_t i) { return getInputQubits()[i]; } OperandRange getInputTargets() { return getTargetsIn(); } - Value getOutputQubit(size_t i); ResultRange getOutputQubits() { return getResults(); } - Value getInputTarget(size_t i); - Value getOutputTarget(size_t i); - Value getInputControl(size_t i); + Value getOutputQubit(size_t i) { return getOutputQubits()[i]; } + Value getInputTarget(size_t i) { return getInputTargets()[i]; } + Value getOutputTarget(size_t i) { return getOutputTargets()[i]; } + Value getInputControl(size_t i) { return getInputControls()[i]; } OperandRange getInputControls() { return getControlsIn(); } - Value getOutputControl(size_t i); + Value getOutputControl(size_t i) { return getOutputControls()[i]; } ResultRange getOutputTargets() { return getTargetsOut(); } ResultRange getOutputControls() { return getControlsOut(); } Value getInputForOutput(Value output); @@ -1181,11 +1181,11 @@ def InvOp : QCOOp<"inv", traits = [UnitaryOpInterface, size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubitsIn().size(); } static size_t getNumControls() { return 0; } - Value getInputQubit(size_t i); OperandRange getInputQubits() { return getQubitsIn(); } + Value getInputQubit(size_t i) { return getQubitsIn()[i]; } OperandRange getInputTargets() { return getInputQubits(); } - Value getOutputQubit(size_t i); ResultRange getOutputQubits() { return getQubitsOut(); } + Value getOutputQubit(size_t i) { return getQubitsOut()[i]; } Value getInputTarget(size_t i) { return getInputQubit(i); } Value getOutputTarget(size_t i) { return getOutputQubit(i); } static Value getInputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 63deb3f816..b54ac9e392 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -188,79 +188,18 @@ UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { return utils::getBodyUnitary(*getBody(), i); } -Value CtrlOp::getInputQubit(const size_t i) { - const auto numControls = getNumControls(); - if (i < numControls) { - return getControlsIn()[i]; - } - if (numControls <= i && i < getNumQubits()) { - return getTargetsIn()[i - numControls]; - } - llvm::reportFatalUsageError("Qubit index out of bounds"); -} - -Value CtrlOp::getOutputQubit(const size_t i) { - const auto numControls = getNumControls(); - if (i < numControls) { - return getControlsOut()[i]; - } - if (numControls <= i && i < getNumQubits()) { - return getTargetsOut()[i - numControls]; - } - llvm::reportFatalUsageError("Qubit index out of bounds"); -} - -Value CtrlOp::getInputTarget(const size_t i) { - if (i >= getNumTargets()) { - llvm::reportFatalUsageError("Target index out of bounds"); - } - return getTargetsIn()[i]; -} - -Value CtrlOp::getOutputTarget(const size_t i) { - if (i >= getNumTargets()) { - llvm::reportFatalUsageError("Target index out of bounds"); - } - return getTargetsOut()[i]; -} - -Value CtrlOp::getInputControl(const size_t i) { - if (i >= getNumControls()) { - llvm::reportFatalUsageError("Control index out of bounds"); - } - return getControlsIn()[i]; -} - -Value CtrlOp::getOutputControl(const size_t i) { - if (i >= getNumControls()) { - llvm::reportFatalUsageError("Control index out of bounds"); - } - return getControlsOut()[i]; -} - Value CtrlOp::getInputForOutput(Value output) { - for (size_t i = 0; i < getNumControls(); ++i) { - if (output == getControlsOut()[i]) { - return getControlsIn()[i]; - } - } - for (size_t i = 0; i < getNumTargets(); ++i) { - if (output == getTargetsOut()[i]) { - return getTargetsIn()[i]; - } + if (const auto result = dyn_cast(output); + result && result.getOwner() == getOperation()) { + return getInputQubit(result.getResultNumber()); } llvm::reportFatalUsageError("Given qubit is not an output of the operation"); } Value CtrlOp::getOutputForInput(Value input) { - for (size_t i = 0; i < getNumControls(); ++i) { - if (input == getControlsIn()[i]) { - return getControlsOut()[i]; - } - } - for (size_t i = 0; i < getNumTargets(); ++i) { - if (input == getTargetsIn()[i]) { - return getTargetsOut()[i]; + for (auto [in, out] : llvm::zip_equal(getInputQubits(), getOutputQubits())) { + if (in == input) { + return out; } } llvm::reportFatalUsageError("Given qubit is not an input of the operation"); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index fb8aaef198..b24e9f9d19 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -321,33 +321,18 @@ UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { return utils::getBodyUnitary(*getBody(), i); } -Value InvOp::getInputQubit(const size_t i) { - if (i >= getNumTargets()) { - llvm::reportFatalUsageError("Qubit index out of bounds"); - } - return getQubitsIn()[i]; -} - -Value InvOp::getOutputQubit(const size_t i) { - if (i >= getNumTargets()) { - llvm::reportFatalUsageError("Qubit index out of bounds"); - } - return getQubitsOut()[i]; -} - Value InvOp::getInputForOutput(Value output) { - for (size_t i = 0; i < getNumTargets(); ++i) { - if (output == getQubitsOut()[i]) { - return getQubitsIn()[i]; - } + if (const auto result = dyn_cast(output); + result && result.getOwner() == getOperation()) { + return getInputQubit(result.getResultNumber()); } llvm::reportFatalUsageError("Given qubit is not an output of the operation"); } Value InvOp::getOutputForInput(Value input) { - for (size_t i = 0; i < getNumTargets(); ++i) { - if (input == getQubitsIn()[i]) { - return getQubitsOut()[i]; + for (auto [in, out] : llvm::zip_equal(getInputQubits(), getOutputQubits())) { + if (in == input) { + return out; } } llvm::reportFatalUsageError("Given qubit is not an input of the operation"); diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp index 0bd50dd3a3..1f9552c0a2 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp @@ -77,33 +77,18 @@ struct MergeSubsequentBarrier final : OpRewritePattern { } // namespace -Value BarrierOp::getInputTarget(const size_t i) { - if (i < getNumTargets()) { - return getQubitsIn()[i]; - } - llvm::reportFatalUsageError("Invalid qubit index"); -} - -Value BarrierOp::getOutputTarget(const size_t i) { - if (i < getNumTargets()) { - return getQubitsOut()[i]; - } - llvm::reportFatalUsageError("Invalid qubit index"); -} - Value BarrierOp::getInputForOutput(Value output) { - for (size_t i = 0; i < getNumTargets(); ++i) { - if (output == getQubitsOut()[i]) { - return getQubitsIn()[i]; - } + if (const auto result = dyn_cast(output); + result && result.getOwner() == getOperation()) { + return getQubitsIn()[result.getResultNumber()]; } llvm::reportFatalUsageError("Given qubit is not an output of the operation"); } Value BarrierOp::getOutputForInput(Value input) { - for (size_t i = 0; i < getNumTargets(); ++i) { - if (input == getQubitsIn()[i]) { - return getQubitsOut()[i]; + for (auto [in, out] : llvm::zip_equal(getQubitsIn(), getQubitsOut())) { + if (in == input) { + return out; } } llvm::reportFatalUsageError("Given qubit is not an input of the operation"); From 51c56414559162b0437b31d7daa6064db13ef4ce Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 20:04:24 +0200 Subject: [PATCH 33/41] :art: Use `getSoleBodyUnitary` helper in HadamardLifting pass Signed-off-by: burgholzer --- .../QCO/Transforms/Optimizations/HadamardLifting.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp index 3d874533b6..8392f09bac 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOInterfaces.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -162,8 +163,9 @@ struct LiftHadamardAboveCNOTPattern final : OpRewritePattern { if (!cnotGate) { return failure(); } - if (cnotGate.getNumBodyUnitaries() != 1 || - !isa(cnotGate.getBodyUnitary(0)) || + if (auto innerUnitary = + utils::getSoleBodyUnitary(*cnotGate.getBody()); + !innerUnitary || !isa(innerUnitary.getOperation()) || cnotGate.getOutputTarget(0) != inQubitHadamard) { return failure(); } From 936aa149b4d65688cac287c0d5aa3186e7d78789 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 20:17:18 +0200 Subject: [PATCH 34/41] :art: Simplify QC program construction slightly Signed-off-by: burgholzer --- mlir/unittests/programs/qc_programs.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 232b8cbdee..21da2d43f2 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1594,7 +1594,7 @@ void nestedForLoopCtrlOpWithSeparateQubit(QCProgramBuilder& b) { b.scfFor(0, 3, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(control, q0, [&](ValueRange targets) { b.x(targets[0]); }); + b.cx(control, q0); }); } @@ -1604,7 +1604,7 @@ void nestedForLoopCtrlOpWithExtractedQubit(QCProgramBuilder& b) { b.scfFor(1, 4, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(reg[0], q0, [&](ValueRange targets) { b.x(targets[0]); }); + b.cx(reg[0], q0); }); } From e028c35fc427be98412ad8b0bb2dbaabbb40f7a2 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 20:37:22 +0200 Subject: [PATCH 35/41] =?UTF-8?q?=F0=9F=94=A5=20Remove=20tests=20that=20do?= =?UTF-8?q?=20not=20exercise=20the=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp | 4 +--- mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 7dda9ccfda..f25ce73b6f 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -176,9 +176,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qco::inverseMultipleControlledDcx), MQT_NAMED_BUILDER(qc::multipleControlledDcx)}, QCOToQCTestCase{"InvTwo", MQT_NAMED_BUILDER(qco::invTwo), - MQT_NAMED_BUILDER(qc::invTwo)}, - QCOToQCTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qco::invCtrlTwo), - MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); + MQT_NAMED_BUILDER(qc::invTwo)})); /// @} /// \name QCOToQC/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 00b2c7fe7b..9e66a0b655 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -169,9 +169,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)}, QCToQCOTestCase{"InvTwo", MQT_NAMED_BUILDER(qc::invTwo), - MQT_NAMED_BUILDER(qco::invTwo)}, - QCToQCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qc::invCtrlTwo), - MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); + MQT_NAMED_BUILDER(qco::invTwo)})); /// @} /// \name QCToQCO/Operations/StandardGates/BarrierOp.cpp From 45a59286e10101968250633c3ee7005f64ac67a4 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 20:46:14 +0200 Subject: [PATCH 36/41] =?UTF-8?q?=F0=9F=9A=A8=20Address=20clang-tidy=20war?= =?UTF-8?q?nings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 4 ++-- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 4 ++-- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 4 +++- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 395f851c76..b3a032f8a0 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -9,11 +9,12 @@ */ #include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/QC/IR/QCInterfaces.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include -#include +#include #include #include #include @@ -23,7 +24,6 @@ #include #include -#include using namespace mlir; using namespace mlir::qc; diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b8f49fd41f..efada34b74 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -9,12 +9,13 @@ */ #include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/QC/IR/QCInterfaces.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include +#include #include -#include #include #include #include @@ -24,7 +25,6 @@ #include #include -#include #include using namespace mlir; diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index b54ac9e392..427f5ec491 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Utils/Matrix.h" #include "mlir/Dialect/Utils/Utils.h" @@ -16,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -242,7 +244,7 @@ LogicalResult CtrlOp::verify() { << i << " does not match target type"; } } - auto blockTerminator = block.getTerminator(); + auto* blockTerminator = block.getTerminator(); if (const auto numYieldOperands = blockTerminator->getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index b24e9f9d19..d91442f4d4 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -9,12 +9,14 @@ */ #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Utils/Matrix.h" #include "mlir/Dialect/Utils/Utils.h" #include #include +#include #include #include #include @@ -26,7 +28,6 @@ #include #include -#include #include #include @@ -375,7 +376,7 @@ LogicalResult InvOp::verify() { << i << " does not match target type"; } } - auto blockTerminator = block.getTerminator(); + auto* blockTerminator = block.getTerminator(); if (const auto numYieldOperands = blockTerminator->getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") From c160d6c29828374b28cee41b1399c8d2880978e4 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 21:01:28 +0200 Subject: [PATCH 37/41] =?UTF-8?q?=F0=9F=9A=A8=20Address=20clang-tidy=20war?= =?UTF-8?q?nings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: burgholzer --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 1 + mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 1 + mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp | 1 + 3 files changed, 3 insertions(+) diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 427f5ec491..8141744a23 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d91442f4d4..eb30c5fb40 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp index 1f9552c0a2..1b6b50fd6d 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include From e7711c0e534d41606d7811ee73a2bf6de2715849 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 22:18:46 +0200 Subject: [PATCH 38/41] :ok_hand: Address review comment Signed-off-by: burgholzer --- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index e3fcbec359..88df7b5586 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -963,7 +963,7 @@ def CtrlOp Value getControl(size_t i) { return getControls()[i]; } size_t getNumParams() { return 0; } Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } - OperandRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + OperandRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "ctrl"; } }]; From 9d2217ada8fe3df4ddd4e7b86488909559cebd01 Mon Sep 17 00:00:00 2001 From: burgholzer Date: Wed, 10 Jun 2026 22:22:48 +0200 Subject: [PATCH 39/41] :ok_hand: Conservatively address second review comment Signed-off-by: burgholzer --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index b3a032f8a0..8d56834d0c 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -43,6 +43,11 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } + // Only proceed when the outer body is exactly [innerCtrlOp, yield]. + if (op.getBody()->getOperations().size() != 2) { + return failure(); + } + auto inner = utils::getSoleBodyUnitary(*op.getBody()); if (!inner) { return failure(); From bcea1f8881d12bc73b8e90d1a23e8b6bbecf9488 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:29:04 +0200 Subject: [PATCH 40/41] Address the Rabbit's out-of-diff comments --- mlir/include/mlir/Dialect/Utils/Utils.h | 37 +++++++++++++++++--- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 17 ++++++--- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 7 ++-- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 19 ++++++++-- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 7 ++-- 5 files changed, 71 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 33562fb197..05f78dc8df 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -83,6 +83,21 @@ template return std::nullopt; } +/** + * @brief Parse a list of aliased qubits. + * + * @details + * The modifier operations use aliased qubits inside of their region. This + * function resolves the relationship between the block arguments and the qubit + * operands. In the example below, the block argument `%a0` aliases the operand + * `%q1`. + * + * ```mlir + * qc.ctrl(%q0) targets(%a0 = %q1) { + * qc.x %a0 : !qc.qubit + * } : !qc.qubit + * ``` + */ template [[nodiscard]] ParseResult @@ -118,12 +133,9 @@ parseTargetAliasing(OpAsmParser& parser, Region& region, } operands.push_back(oldOperand); - // Hard-code QubitType since targets in CtrlOp are always qubits. - // This avoids double-binding type($targets_in) in the assembly format - // while keeping the parser simple and the assembly format clean. + // Hard-code QubitType because the modifiers only alias qubits newArg.type = QubitType::get(parser.getBuilder().getContext()); blockArgs.push_back(newArg); - } while (succeeded(parser.parseOptionalComma())); if (parser.parseRParen()) { @@ -141,6 +153,21 @@ parseTargetAliasing(OpAsmParser& parser, Region& region, return success(); } +/** + * @brief Print a list of aliased qubits. + * + * @details + * The modifier operations use aliased qubits inside of their region. This + * function prints a representation of the relationship between the block + * arguments and the qubit operands. In the example below, the block argument + * `%a0` aliases the operand `%q1`. + * + * ```mlir + * qc.ctrl(%q0) targets(%a0 = %q1) { + * qc.x %a0 : !qc.qubit + * } : !qc.qubit + * ``` + */ inline void printTargetAliasing(OpAsmPrinter& printer, Region& region, OperandRange targetsIn) { printer << "("; @@ -149,7 +176,7 @@ inline void printTargetAliasing(OpAsmPrinter& printer, Region& region, printer.printRegion(region, false); return; } - Block& entryBlock = region.front(); + auto& entryBlock = region.front(); const auto numTargets = targetsIn.size(); for (unsigned i = 0; i < numTargets; ++i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 8d56834d0c..86dd60e8d1 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -43,7 +44,7 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - // Only proceed when the outer body is exactly [innerCtrlOp, yield]. + // Only proceed if body contains only one operation besides terminator if (op.getBody()->getOperations().size() != 2) { return failure(); } @@ -108,6 +109,12 @@ struct ReduceCtrl final : OpRewritePattern { if (!gPhaseOp) { return failure(); } + + // Only proceed if the GPhaseOp is the only operation besides the terminator + if (op.getBody()->getOperations().size() != 2) { + return failure(); + } + // Special case for single control: replace with a single POp if (op.getNumControls() == 1) { rewriter.replaceOpWithNewOp(op, op.getControl(0), @@ -116,7 +123,7 @@ struct ReduceCtrl final : OpRewritePattern { } // Reinterpret the last control as a target qubit and apply a phase gate to - // it inside the (smaller) controlled region. + // it inside the (smaller) controlled region const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); auto segmentsAttr = op->getAttrOfType(opSegmentsAttrName); @@ -183,9 +190,11 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { if (llvm::any_of(*getBody(), [](Operation& op) { - return isa(op); + return isa(op); })) { - return emitOpError("body must not contain non-unitary quantum operations"); + return emitOpError("body must not contain non-unitary quantum operations " + "or modify a quantum register"); } SmallPtrSet uniqueQubits; diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index efada34b74..b1caa45344 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -326,9 +327,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { if (llvm::any_of(*getBody(), [](Operation& op) { - return isa(op); + return isa(op); })) { - return emitOpError("body must not contain non-unitary quantum operations"); + return emitOpError("body must not contain non-unitary quantum operations " + "or modify a quantum register"); } return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 8141744a23..9c5a46837a 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,11 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } + // Only proceed if body contains only one operation besides terminator + if (op.getBody()->getOperations().size() != 2) { + return failure(); + } + auto inner = utils::getSoleBodyUnitary(*op.getBody()); if (!inner) { return failure(); @@ -126,6 +132,11 @@ struct ReduceCtrl final : OpRewritePattern { return failure(); } + // Only proceed if the GPhaseOp is the only operation besides the terminator + if (op.getBody()->getOperations().size() != 2) { + return failure(); + } + // Special case for single control: replace with a single POp if (op.getNumControls() == 1) { rewriter.replaceOpWithNewOp(op, op.getInputControl(0), @@ -134,7 +145,7 @@ struct ReduceCtrl final : OpRewritePattern { } // Reinterpret the last control as a target qubit and apply a phase gate to - // it inside the (smaller) controlled region. + // it inside the (smaller) controlled region const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); auto segmentsAttr = op->getAttrOfType(opSegmentsAttrName); @@ -228,9 +239,11 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); if (llvm::any_of(block, [](Operation& op) { - return isa(op); + return isa(op); })) { - return emitOpError("body must not contain non-unitary quantum operations"); + return emitOpError("body must not contain non-unitary quantum operations " + "or modify a quantum register"); } const auto numTargets = getNumTargets(); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index eb30c5fb40..4bb5bf4d9f 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -360,9 +361,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); if (llvm::any_of(block, [](Operation& op) { - return isa(op); + return isa(op); })) { - return emitOpError("body must not contain non-unitary quantum operations"); + return emitOpError("body must not contain non-unitary quantum operations " + "or modify a quantum register"); } const auto numTargets = getNumTargets(); From db03b3ce8c733dc60324a16b2cc09af676d948f7 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Thu, 11 Jun 2026 01:13:42 +0200 Subject: [PATCH 41/41] Add defensive check --- mlir/include/mlir/Dialect/Utils/Utils.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 05f78dc8df..91c8d341f4 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -198,6 +198,8 @@ inline void printTargetAliasing(OpAsmPrinter& printer, Region& region, */ inline Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { if (auto blockArg = dyn_cast(qubit)) { + assert(blockArg.getArgNumber() < qubits.size() && + "block argument index must be within qubits range"); return qubits[blockArg.getArgNumber()]; } return qubit;