diff --git a/CHANGELOG.md b/CHANGELOG.md index 0412800d9d..941d811227 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706], [#1776]) ([**@denialhaag**], [**@burgholzer**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588], [#1600], [#1664], [#1709], [#1716], [#1748]) ([**@MatthiasReumann**], [**@burgholzer**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], [#1774]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], [#1774], [#1781]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -402,6 +402,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool +[#1781]: https://github.com/munich-quantum-toolkit/core/pull/1781 [#1776]: https://github.com/munich-quantum-toolkit/core/pull/1776 [#1774]: https://github.com/munich-quantum-toolkit/core/pull/1774 [#1765]: https://github.com/munich-quantum-toolkit/core/pull/1765 diff --git a/mlir/include/mlir/Support/IRVerification.h b/mlir/include/mlir/Support/IRVerification.h index 9ecb49a59c..5fd3ac9fd4 100644 --- a/mlir/include/mlir/Support/IRVerification.h +++ b/mlir/include/mlir/Support/IRVerification.h @@ -12,9 +12,10 @@ namespace mlir { class ModuleOp; -} +} // namespace mlir -/// Compare two MLIR modules for structural equivalence, allowing permutations -/// of speculatable operations. -[[nodiscard]] bool areModulesEquivalentWithPermutations(mlir::ModuleOp lhs, - mlir::ModuleOp rhs); +/// Compare two (quantum) module operations for structural equivalence, allowing +/// some permutations. This function is especially tailored to compare quantum +/// computations. +[[nodiscard]] bool areModulesEquivalentWithPermutations(mlir::ModuleOp, + mlir::ModuleOp); diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 0221464606..a1700fa319 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -11,407 +11,249 @@ #include "mlir/Support/IRVerification.h" #include "mlir/Dialect/QC/IR/QCOps.h" -#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QTensor/Utils/TensorIterator.h" -#include -#include -#include -#include -#include #include #include #include -#include -#include -#include +#include +#include #include -#include +#include #include #include +#include #include #include -#include +#include +#include #include -#include #include -#include -#include #include +#include #include #include #include -#include +#include using namespace mlir; namespace { +struct TensorMapping { + /// Maps all tensor values of the lhs to its equiv group. + DenseMap lhsEquivGroups; + /// Maps all tensor values of the rhs to its equiv group. + DenseMap rhsEquivGroups; + /// Maps the i-th group of lhs to the j-th group of rhs. + DenseMap equivGroupMapping; -/// Compute a structural hash for an operation (excluding SSA value identities). -/// This hash is based on operation name, types, and attributes only. -struct OperationStructuralHash { - size_t operator()(Operation* op) const { - size_t hash = llvm::hash_value(op->getName().getStringRef()); - - // Hash result types - for (auto type : op->getResultTypes()) { - hash = llvm::hash_combine(hash, type.getAsOpaquePointer()); - } - - // Hash operand types (not values) - for (auto operand : op->getOperands()) { - hash = llvm::hash_combine(hash, operand.getType().getAsOpaquePointer()); - } - - // Hash attributes - // for (const auto& attr : op->getAttrDictionary()) { - // hash = llvm::hash_combine(hash, attr.getName().str()); - // hash = llvm::hash_combine(hash, attr.getValue().getAsOpaquePointer()); - // } - - return hash; - } -}; - -/// Check if two operations are structurally equivalent (excluding SSA value -/// identities). -struct OperationStructuralEquality { - bool operator()(Operation* lhs, Operation* rhs) const { - // Check operation name - if (lhs->getName() != rhs->getName()) { - return false; - } - - // Check result types - if (lhs->getResultTypes() != rhs->getResultTypes()) { - return false; - } - - // Check operand types (not values) - auto lhsOperandTypes = lhs->getOperandTypes(); - auto rhsOperandTypes = rhs->getOperandTypes(); - return llvm::equal(lhsOperandTypes, rhsOperandTypes); - - // Note: Attributes are intentionally not checked here to allow relaxed - // comparison. Attributes like function names, parameter names, etc. may - // differ while operations are still structurally equivalent. - } -}; - -/// Wrapper for Operation* with structural comparison semantics -struct StructuralOperationKey { - Operation* op; - - explicit StructuralOperationKey(Operation* operation = nullptr) - : op(operation) {} - - bool operator==(const StructuralOperationKey& other) const { - if (op == other.op) { - return true; - } - if (op == nullptr || other.op == nullptr) { - return false; - } - return OperationStructuralEquality{}(op, other.op); + /// Map equivalence group identifiers of two tensors. + void map(Value lhs, Value rhs) { + equivGroupMapping[lhsEquivGroups[lhs]] = rhsEquivGroups[rhs]; } - bool operator!=(const StructuralOperationKey& other) const { - return !(*this == other); + /// Return true if the given tensor values have the same equiv group. + [[nodiscard]] bool equals(Value lhs, Value rhs) const { + const auto i = lhsEquivGroups.at(lhs); + return equivGroupMapping.at(i) == rhsEquivGroups.at(rhs); } }; - -/// Map to track value equivalence between two modules. -using ValueEquivalenceMap = DenseMap; - -using OperationSet = DenseSet; - -struct InsertWrite { - Value scalar; - Value index; -}; - -struct InsertChainSummary { - Value baseTensor; - Value finalTensor; - SmallVector writes; -}; - } // namespace -static bool areValuesEquivalent(Value lhs, Value rhs, - ValueEquivalenceMap& valueMap) { - if (auto it = valueMap.find(lhs); it != valueMap.end()) { - return it->second == rhs; - } - valueMap[lhs] = rhs; - return true; -} - -static bool areIndexValuesEquivalent(Value lhs, Value rhs, - ValueEquivalenceMap& valueMap) { - if (qtensor::areEquivalentIndices(lhs, rhs)) { - return true; - } - return areValuesEquivalent(lhs, rhs, valueMap); -} - -static bool isQTensorInsertOp(Operation* op) { - return isa(op); -} +static bool compareRegions(Region& lhs, Region& rhs, + SetVector& lhsClosed, + SetVector& rhsClosed, IRMapping& m, + TensorMapping& tm); -static bool isCommutableQTensorInsertDependency(Operation* dependent, - Operation* dependency) { - auto dependentInsert = dyn_cast(dependent); - auto dependencyInsert = dyn_cast(dependency); - if (!dependentInsert || !dependencyInsert) { +/// Return true, if the given value has the type `tensor`. +static bool hasTypeQubitTensor(Value v) { + auto tensor = dyn_cast(v.getType()); + if (!tensor) { return false; } - if (dependentInsert.getDest() != dependencyInsert.getResult()) { - return false; - } - auto dependentIndex = dependentInsert.getIndex(); - auto dependencyIndex = dependencyInsert.getIndex(); - if (!getConstantIntValue(dependentIndex) || - !getConstantIntValue(dependencyIndex)) { - return false; - } - return !qtensor::areEquivalentIndices(dependentIndex, dependencyIndex); -} -static Value getInsertChainBaseTensor(Value tensor, const OperationSet& group) { - auto current = tensor; - while (auto insertOp = current.getDefiningOp()) { - if (!group.contains(insertOp.getOperation())) { - break; - } - current = insertOp.getDest(); - } - return current; + return isa(tensor.getElementType()); } -static bool summarizeInsertGroup(ArrayRef ops, - SmallVectorImpl& chains) { - OperationSet groupOps; - for (Operation* op : ops) { - groupOps.insert(op); - } - - DenseSet consumedInsertResults; - for (Operation* op : ops) { - auto insertOp = cast(op); - if (auto definingInsert = - insertOp.getDest().getDefiningOp()) { - if (groupOps.contains(definingInsert.getOperation())) { - consumedInsertResults.insert(insertOp.getDest()); - } +/// Recursively initialize the equivalence group for a tensor value. +static void initEquivGroup(TypedValue v, size_t id, + DenseMap& group) { + for (qtensor::TensorIterator it(v); it != std::default_sentinel; ++it) { + if (it.tensor() == nullptr) { + continue; } - } - DenseMap chainByBaseTensor; - for (Operation* op : ops) { - auto insertOp = cast(op); - const Value baseTensor = - getInsertChainBaseTensor(insertOp.getDest(), groupOps); + group[it.tensor()] = id; - size_t chainIdx = 0; - if (auto it = chainByBaseTensor.find(baseTensor); - it != chainByBaseTensor.end()) { - chainIdx = it->second; - } else { - chainIdx = chains.size(); - chainByBaseTensor[baseTensor] = chainIdx; - InsertChainSummary summary; - summary.baseTensor = baseTensor; - chains.emplace_back(std::move(summary)); + if (isa(it.tensor())) { + continue; } - auto& chain = chains[chainIdx]; - chain.writes.push_back(InsertWrite{.scalar = insertOp.getScalar(), - .index = insertOp.getIndex()}); + if (auto op = dyn_cast(it.operation())) { + const auto prev = std::prev(it); + const auto qIt = llvm::find(op.getQubits(), prev.tensor()); + assert(qIt != op.getQubits().end()); + const auto idx = std::distance(op.getQubits().begin(), qIt); - if (!consumedInsertResults.contains(insertOp.getResult())) { - if (chain.finalTensor) { - return false; - } - chain.finalTensor = insertOp.getResult(); - } - } + auto& thenRegion = op.getThenRegion(); + auto& elseRegion = op.getElseRegion(); - for (const auto& chain : chains) { - if (!chain.finalTensor) { - return false; - } + const auto& thenArg = thenRegion.getArgument(idx); + const auto& elseArg = elseRegion.getArgument(idx); - // Reordering writes to the same index is not semantics-preserving. - SmallVector seenIndices; - for (const auto& write : chain.writes) { - if (llvm::any_of(seenIndices, [&](Value seenIndex) { - return qtensor::areEquivalentIndices(seenIndex, write.index); - })) { - return false; - } - seenIndices.push_back(write.index); + initEquivGroup(cast>(thenArg), id, group); + initEquivGroup(cast>(elseArg), id, group); + } else if (auto forOp = dyn_cast(it.operation())) { + const auto& arg = + forOp.getTiedLoopRegionIterArg(cast(it.tensor())); + initEquivGroup(cast>(arg), id, group); } } - - return true; } -static bool areInsertWritesEquivalentRec(const size_t lhsIdx, - ArrayRef lhsWrites, - ArrayRef rhsWrites, - SmallVectorImpl& rhsUsed, - ValueEquivalenceMap& valueMap) { - if (lhsIdx == lhsWrites.size()) { - return true; - } - - for (size_t rhsIdx = 0; rhsIdx < rhsWrites.size(); ++rhsIdx) { - if (rhsUsed[rhsIdx] != 0) { - continue; - } - - ValueEquivalenceMap tempMap = valueMap; - if (!areValuesEquivalent(lhsWrites[lhsIdx].scalar, rhsWrites[rhsIdx].scalar, - tempMap) || - !areIndexValuesEquivalent(lhsWrites[lhsIdx].index, - rhsWrites[rhsIdx].index, tempMap)) { - continue; - } +/// Generate equivalence group for all allocated and created tensors. +static DenseMap getEquivGroup(ModuleOp mod) { + size_t id = 0; + DenseMap group; - rhsUsed[rhsIdx] = 1; - if (areInsertWritesEquivalentRec(lhsIdx + 1, lhsWrites, rhsWrites, rhsUsed, - tempMap)) { - valueMap = std::move(tempMap); - return true; + mod->walk([&](Operation* op) { + if (auto alloc = dyn_cast(op)) { + initEquivGroup(alloc.getResult(), id, group); + ++id; + } else if (auto from = dyn_cast(op)) { + initEquivGroup(cast>(from.getResult()), id, + group); + ++id; } - rhsUsed[rhsIdx] = 0; - } + }); - return false; + return group; } -static bool areInsertWritesEquivalent(ArrayRef lhsWrites, - ArrayRef rhsWrites, - ValueEquivalenceMap& valueMap) { - if (lhsWrites.size() != rhsWrites.size()) { - return false; +/// Map all results from one op to another using the given permutation. +/// Assumes that `lhs->getNumResults() == rhs->getNumResults()`. +/// Assumes that the two operations are equivalent to each other. +static void mapResults(Operation* lhs, Operation* rhs, + ArrayRef permutation, IRMapping& m) { + for (const auto& [i, lhsResult] : llvm::enumerate(lhs->getResults())) { + m.map(lhsResult, rhs->getResult(permutation[i])); } - SmallVector rhsUsed(rhsWrites.size(), 0); - return areInsertWritesEquivalentRec(0, lhsWrites, rhsWrites, rhsUsed, - valueMap); } -static bool areInsertChainsEquivalent(const InsertChainSummary& lhsChain, - const InsertChainSummary& rhsChain, - ValueEquivalenceMap& valueMap) { - ValueEquivalenceMap tempMap = valueMap; - if (!areValuesEquivalent(lhsChain.baseTensor, rhsChain.baseTensor, tempMap)) { - return false; - } - - if (!areInsertWritesEquivalent(lhsChain.writes, rhsChain.writes, tempMap)) { - return false; - } - - if (!areValuesEquivalent(lhsChain.finalTensor, rhsChain.finalTensor, - tempMap)) { - return false; +/// Map arguments from one block to another using the given permutation. +/// Assumes that `lhs.getNumArguments() == rhs.getNumArguments()`. +/// Assumes that `permutation.size() == lhs.getNumArguments()`. +static void mapArguments(Block& lhs, Block& rhs, ArrayRef permutation, + IRMapping& m) { + for (const auto& [i, lhsArg] : enumerate(lhs.getArguments())) { + m.map(lhsArg, rhs.getArgument(permutation[i])); } - - valueMap = std::move(tempMap); - return true; } -static bool areInsertGroupsEquivalentRec(const size_t lhsChainIdx, - ArrayRef lhsChains, - ArrayRef rhsChains, - SmallVectorImpl& rhsChainUsed, - ValueEquivalenceMap& valueMap) { - if (lhsChainIdx == lhsChains.size()) { - return true; - } +/// Return a permutation vector, where permutation[i] maps the i-th target of +/// the lhs to the j-th target of the rhs. +static SmallVector getTargetPermutation(qc::CtrlOp lhs, qc::CtrlOp rhs, + const IRMapping& m) { + SmallVector permutation(lhs.getNumTargets()); + for (const auto& [i, trgt] : llvm::enumerate(lhs.getTargets())) { + const auto it = llvm::find(rhs.getTargets(), m.lookup(trgt)); + const auto j = std::distance(rhs.getTargets().begin(), it); + permutation[i] = j; + } + return permutation; +} - for (size_t rhsChainIdx = 0; rhsChainIdx < rhsChains.size(); ++rhsChainIdx) { - if (rhsChainUsed[rhsChainIdx] != 0) { - continue; - } +/// Return a permutation vector, where permutation[i] maps the i-th input +/// target of the lhs to the j-th input target of the rhs. +static SmallVector +getTargetPermutation(qco::CtrlOp lhs, qco::CtrlOp rhs, const IRMapping& m) { + SmallVector permutation(lhs.getNumTargets()); + for (const auto& [i, trgt] : llvm::enumerate(lhs.getInputTargets())) { + const auto it = llvm::find(rhs.getInputTargets(), m.lookup(trgt)); + const auto j = std::distance(rhs.getInputTargets().begin(), it); + permutation[i] = j; + } + return permutation; +} - ValueEquivalenceMap tempMap = valueMap; - if (!areInsertChainsEquivalent(lhsChains[lhsChainIdx], - rhsChains[rhsChainIdx], tempMap)) { - continue; - } +/// Return a permutation vector, where permutation[i] maps the i-th input +/// target of the lhs to the j-th input target of the rhs. +static SmallVector +getControlPermutation(qco::CtrlOp lhs, qco::CtrlOp rhs, const IRMapping& m) { + SmallVector permutation(lhs.getNumControls()); + for (const auto& [i, trgt] : llvm::enumerate(lhs.getInputControls())) { + const auto it = llvm::find(rhs.getInputControls(), m.lookup(trgt)); + const auto j = std::distance(rhs.getInputControls().begin(), it); + permutation[i] = j; + } + return permutation; +} - rhsChainUsed[rhsChainIdx] = 1; - if (areInsertGroupsEquivalentRec(lhsChainIdx + 1, lhsChains, rhsChains, - rhsChainUsed, tempMap)) { - valueMap = std::move(tempMap); - return true; +/// Compare two ctrl operations, allowing permutations of control and target +/// qubits. +static bool compareCtrlOps(qc::CtrlOp lhs, qc::CtrlOp rhs, const IRMapping& m) { + DenseSet workset; + workset.insert_range(rhs.getControls()); + for (const auto& ctrl : lhs.getControls()) { + const auto& v = m.lookup(ctrl); + if (!workset.contains(v)) { + return false; } - rhsChainUsed[rhsChainIdx] = 0; + workset.erase(v); } - return false; -} - -static bool areInsertGroupsEquivalent(ArrayRef lhsOps, - ArrayRef rhsOps, - ValueEquivalenceMap& valueMap) { - if (lhsOps.size() != rhsOps.size()) { + if (!workset.empty()) { return false; } - SmallVector lhsChains; - SmallVector rhsChains; - if (!summarizeInsertGroup(lhsOps, lhsChains) || - !summarizeInsertGroup(rhsOps, rhsChains)) { - return false; - } - if (lhsChains.size() != rhsChains.size()) { - return false; + workset.insert_range(rhs.getTargets()); + for (const auto& trgt : lhs.getTargets()) { + const auto& v = m.lookup(trgt); + if (!workset.contains(v)) { + return false; + } + workset.erase(v); } - SmallVector rhsChainUsed(rhsChains.size(), 0); - return areInsertGroupsEquivalentRec(0, lhsChains, rhsChains, rhsChainUsed, - valueMap); + return workset.empty(); } -/// DenseMapInfo specialization for StructuralOperationKey -template <> struct llvm::DenseMapInfo { - static StructuralOperationKey getEmptyKey() { - return StructuralOperationKey(DenseMapInfo::getEmptyKey()); - } - - static StructuralOperationKey getTombstoneKey() { - return StructuralOperationKey(DenseMapInfo::getTombstoneKey()); +/// Compare two ctrl operations, allowing permutations of input control and +/// input target qubits. +static bool compareCtrlOps(qco::CtrlOp lhs, qco::CtrlOp rhs, + const IRMapping& m) { + DenseSet workset; + workset.insert_range(rhs.getInputControls()); + for (const auto& ctrl : lhs.getInputControls()) { + const auto& v = m.lookup(ctrl); + if (!workset.contains(v)) { + return false; + } + workset.erase(v); } - static unsigned getHashValue(const StructuralOperationKey& key) { - if (key.op == getEmptyKey().op || key.op == getTombstoneKey().op) { - return DenseMapInfo::getHashValue(key.op); - } - return OperationStructuralHash{}(key.op); + if (!workset.empty()) { + return false; } - static bool isEqual(const StructuralOperationKey& lhs, - const StructuralOperationKey& rhs) { - // Handle special keys - if (lhs.op == getEmptyKey().op) { - return rhs.op == getEmptyKey().op; - } - if (lhs.op == getTombstoneKey().op) { - return rhs.op == getTombstoneKey().op; - } - if (rhs.op == getEmptyKey().op || rhs.op == getTombstoneKey().op) { + workset.insert_range(rhs.getInputTargets()); + for (const auto& trgt : lhs.getInputTargets()) { + const auto& v = m.lookup(trgt); + if (!workset.contains(v)) { return false; } - return lhs == rhs; + workset.erase(v); } -}; -static bool areFloatValuesNear(const APFloat& lhs, const APFloat& rhs, - const unsigned width) { + return workset.empty(); +} + +/// Compare two floating point numbers for approximate equivalence. +static bool approxCompareFloats(const APFloat& lhs, const APFloat& rhs, + const unsigned width) { if (lhs.isNaN() || rhs.isNaN()) { return lhs.isNaN() && rhs.isNaN(); } @@ -439,406 +281,381 @@ static bool areFloatValuesNear(const APFloat& lhs, const APFloat& rhs, return absDiff <= absTol + (relTol * scale); } -static bool areConstantAttributesEquivalent(const Attribute& lhs, - const Attribute& rhs) { - if (lhs == rhs) { - return true; - } - - if (auto lhsFloat = dyn_cast(lhs)) { - auto rhsFloat = dyn_cast(rhs); - if (!rhsFloat) { +/// Compare two attributes for equivality. +/// Explicitly checks `UnitAttr`, `IntegerAttr`, `FloatAttr`, `StringAttr`, +/// and `FlatSymbolRefAttr`. For any other type, the function simply returns +/// true. +static bool compareAttributes(Attribute lhs, Attribute rhs) { + if (dyn_cast(lhs)) { + if (!dyn_cast(rhs)) { return false; } - return areFloatValuesNear(lhsFloat.getValue(), rhsFloat.getValue(), - lhsFloat.getType().getIntOrFloatBitWidth()); - } - - return false; -} - -/// Compare two operations for structural equivalence. -/// Updates valueMap to track corresponding SSA values. -static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, - ValueEquivalenceMap& valueMap) { - // Check operation name - if (lhs->getName() != rhs->getName()) { - return false; - } - - // Check arith::ConstantOp - if (auto lhsConst = dyn_cast(lhs)) { - auto rhsConst = dyn_cast(rhs); - if (!rhsConst) { + } else if (auto intAttrA = dyn_cast(lhs)) { + if (auto intAttrB = dyn_cast(rhs); + !intAttrB || intAttrA.getValue() != intAttrB.getValue() || + (intAttrA.getType().isInteger() && !intAttrB.getType().isInteger())) { return false; } - if (!areConstantAttributesEquivalent(lhsConst.getValue(), - rhsConst.getValue())) { + } else if (auto floatAttrA = dyn_cast(lhs)) { + if (auto floatAttrB = dyn_cast(rhs); + !floatAttrB || + !approxCompareFloats(floatAttrA.getValue(), floatAttrB.getValue(), + floatAttrA.getType().getIntOrFloatBitWidth())) { return false; } - } - - // Check LLVM::ConstantOp - if (auto lhsConst = dyn_cast(lhs)) { - auto rhsConst = dyn_cast(rhs); - if (!rhsConst) { + } else if (auto strAttrA = dyn_cast(lhs)) { + if (auto strAttrB = dyn_cast(rhs); + !strAttrB || strAttrA.getValue() != strAttrB.getValue()) { return false; } - if (!areConstantAttributesEquivalent(lhsConst.getValue(), - rhsConst.getValue())) { + } else if (auto symbolRefAttrA = + llvm::dyn_cast(lhs)) { + auto symbolRefAttrB = llvm::dyn_cast(rhs); + if (!symbolRefAttrB) { return false; } - } - // Check LLVM::CallOp - if (auto lhsCall = dyn_cast(lhs)) { - auto rhsCall = dyn_cast(rhs); - if (!rhsCall) { - return false; - } - if (lhsCall.getCallee() != rhsCall.getCallee()) { + if (symbolRefAttrA.getValue() != symbolRefAttrB.getValue()) { return false; } } - // Check number of operands and results - if (lhs->getNumOperands() != rhs->getNumOperands() || + return true; +} + +/// Compare two operations for structural equivalence, applying special +/// rules for `CtrlOp` s and `qtensor` s. +static bool compareOperations(Operation* lhs, Operation* rhs, + const IRMapping& m, const TensorMapping& tm) { + + // Compare top-level signature-like characteristics. + + if (lhs->getName() != rhs->getName() || + lhs->getNumOperands() != rhs->getNumOperands() || + lhs->getOperandTypes() != rhs->getOperandTypes() || lhs->getNumResults() != rhs->getNumResults() || + lhs->getResultTypes() != rhs->getResultTypes() || lhs->getNumRegions() != rhs->getNumRegions()) { return false; } - // Note: Attributes are intentionally not checked to allow relaxed comparison + // Compare attributes with specific types. + // Silently ignore missing ones. - // Check result types - if (lhs->getResultTypes() != rhs->getResultTypes()) { - return false; - } + for (const auto& namedAttrLhs : lhs->getAttrs()) { + const StringRef keyLhs = namedAttrLhs.getName().strref(); + if (!rhs->hasAttr(keyLhs)) { + continue; + } - ValueRange lhsOperands; - ValueRange rhsOperands; - if (auto lhsCtrl = dyn_cast(lhs)) { - auto rhsCtrl = dyn_cast(rhs); - if (!rhsCtrl) { + if (!compareAttributes(namedAttrLhs.getValue(), rhs->getAttr(keyLhs))) { return false; } - if (lhsCtrl.getTargets().size() != rhsCtrl.getTargets().size()) { + } + + // Compare operands. + // Because the order of target (control) qubits of CtrlOps doesn't matter, + // explicitly handle them here. + + if (isa(lhs)) { + assert(isa(rhs)); + if (!compareCtrlOps(cast(lhs), cast(rhs), m)) { return false; } - for (auto [lhsTarget, lhsArg] : - llvm::zip(lhsCtrl.getTargets(), lhsCtrl.getBody()->getArguments())) { - auto rhsTarget = valueMap[lhsTarget]; - if (!llvm::is_contained(rhsCtrl.getTargets(), rhsTarget)) { - return false; - } - auto it = llvm::find(rhsCtrl.getTargets(), rhsTarget); - auto index = std::distance(rhsCtrl.getTargets().begin(), it); - valueMap[lhsArg] = rhsCtrl.getBody()->getArgument(index); + } else if (isa(lhs)) { + assert(isa(rhs)); + if (!compareCtrlOps(cast(lhs), cast(rhs), m)) { + return false; } - lhsOperands = lhsCtrl.getControls(); - rhsOperands = rhsCtrl.getControls(); } else { - lhsOperands = lhs->getOperands(); - rhsOperands = rhs->getOperands(); - } + for (const auto& [lhsOperand, rhsOperand] : + llvm::zip_equal(lhs->getOperands(), rhs->getOperands())) { + if (hasTypeQubitTensor(lhsOperand)) { + assert(hasTypeQubitTensor(rhsOperand)); - // Check operands according to value mapping - for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOperands, rhsOperands)) { - if (!areValuesEquivalent(lhsOperand, rhsOperand, valueMap)) { - return false; + if (!tm.equals(lhsOperand, rhsOperand)) { + return false; + } + } else { + const auto& v = m.lookup(lhsOperand); + if (v != rhsOperand) { + return false; + } + } } } - // Update value mapping for results - for (auto [lhsResult, rhsResult] : - llvm::zip(lhs->getResults(), rhs->getResults())) { - valueMap[lhsResult] = rhsResult; - } - return true; } -/// Forward declaration for mutual recursion. -static bool areBlocksEquivalent(Block& lhs, Block& rhs, - ValueEquivalenceMap& valueMap); - -/// Compare two regions for structural equivalence. -static bool areRegionsEquivalent(Region& lhs, Region& rhs, - ValueEquivalenceMap& valueMap) { - if (lhs.getBlocks().size() != rhs.getBlocks().size()) { - return false; - } - - for (auto [lhsBlock, rhsBlock] : llvm::zip(lhs, rhs)) { - if (!areBlocksEquivalent(lhsBlock, rhsBlock, valueMap)) { - return false; +/// Extract and return "ready" operations. +/// These are operations that are independent from each other. +static SetVector getReadyOps(const SetVector& open, + const SetVector& closed) { + const auto isReady = [&closed](Value v) { + if (isa(v)) { + return true; } - } + return closed.contains(v.getDefiningOp()); + }; - return true; -} + SetVector ready; + for (Operation* op : open) { + if (ready.contains(op)) { + continue; + } -/// Check if an operation has memory effects or control flow side effects -/// that would prevent reordering. -static bool hasOrderingConstraints(Operation* op) { - // Terminators must maintain their position - if (op->hasTrait()) { - return true; - } + if (auto insert = dyn_cast(op)) { - // Symbol-defining operations (like function declarations) can be reordered - if (op->hasTrait() || - isa(op)) { - return false; - } + // If any of the inserts on the chain are ready, we consider the entire + // chain ready because the ready operations could be moved to the front + // of the chain. The analogous logic is applied to extracts. - // Check for memory effects that enforce ordering - if (auto memInterface = dyn_cast(op)) { - SmallVector effects; - memInterface.getEffects(effects); - - bool hasNonAllocFreeEffects = false; - for (const auto& effect : effects) { - // Allow operations with no effects or pure allocation/free effects - if (!isa( - effect.getEffect())) { - hasNonAllocFreeEffects = true; - break; + SmallVector chain; + for (qtensor::TensorIterator it(insert.getResult()); + it != std::default_sentinel; ++it) { + auto chainInsert = dyn_cast(it.operation()); + if (!chainInsert) { + break; + } + if (isReady(chainInsert.getScalar()) && + isReady(chainInsert.getIndex()) && !closed.contains(chainInsert)) { + chain.emplace_back(chainInsert); + } } - } - if (hasNonAllocFreeEffects) { - return true; - } - } + if (!chain.empty()) { + ready.insert_range(chain); + } - return false; -} + } else if (auto extract = dyn_cast(op)) { + SmallVector chain; + for (qtensor::TensorIterator it(extract.getOutTensor()); + it != std::default_sentinel; ++it) { + auto chainExtract = dyn_cast(it.operation()); + if (!chainExtract) { + break; + } -/// Build a dependence graph for operations. -/// Returns a map from each operation to the set of operations it depends on. -DenseMap> static buildDependenceGraph( - ArrayRef ops) { - DenseMap> dependsOn; - DenseMap valueProducers; - - // Build value-to-producer map and dependence relationships - for (Operation* op : ops) { - dependsOn[op] = DenseSet(); - - // This operation depends on the producers of its operands - for (const auto operand : op->getOperands()) { - if (auto it = valueProducers.find(operand); it != valueProducers.end()) { - dependsOn[op].insert(it->second); + if (isReady(chainExtract.getIndex()) && + !closed.contains(chainExtract)) { + chain.emplace_back(chainExtract); + } } - } - // Register this operation as the producer of its results - for (auto result : op->getResults()) { - valueProducers[result] = op; - } - } + if (!chain.empty()) { + ready.insert_range(chain); + } + } else if (auto dealloc = dyn_cast(op)) { - return dependsOn; -} + // Deallocations are ready whenever we've visited each op on the tensor + // chain. Because we initialize the iterator with its input tensor, the + // iterator already points at the previous operation. Thus use a + // do-while loop instead of a regular while. -/// Partition operations into groups that can be compared as multisets. -/// Operations in the same group are independent and can be reordered. -SmallVector> static partitionIndependentGroups( - ArrayRef ops) { - SmallVector> groups; - if (ops.empty()) { - return groups; - } + bool fullChain{true}; + qtensor::TensorIterator it(dealloc.getTensor()); - auto dependsOn = buildDependenceGraph(ops); - SmallVector currentGroup; + do { + if (!closed.contains(it.operation())) { + fullChain = false; + break; + } - for (auto* op : ops) { - bool dependsOnCurrent = false; + --it; + } while (std::prev(it) != it); - // Check if this operation depends on any operation in the current group - for (auto* groupOp : currentGroup) { - if (!dependsOn[op].contains(groupOp)) { - continue; + if (fullChain) { + ready.insert(dealloc); } - if (isCommutableQTensorInsertDependency(op, groupOp)) { - continue; - } - dependsOnCurrent = true; - break; - } - - // Check if this operation has ordering constraints - const auto hasConstraints = hasOrderingConstraints(op); - // If it depends on current group or has ordering constraints, - // finalize the current group and start a new one - if (dependsOnCurrent || (hasConstraints && !currentGroup.empty())) { - if (!currentGroup.empty()) { - groups.push_back(std::move(currentGroup)); - currentGroup = {}; - } - } + } else { - currentGroup.push_back(op); + // Otherwise, simply check if all operands are ready. - // If this operation has ordering constraints, finalize the group - if (hasConstraints) { - groups.push_back(std::move(currentGroup)); - currentGroup = {}; + if (llvm::all_of(op->getOperands(), isReady)) { + ready.insert(op); + } } } - // Add any remaining operations - if (!currentGroup.empty()) { - groups.push_back(std::move(currentGroup)); - } - - return groups; + return ready; } -/// Compare two groups of independent operations using multiset equivalence. -static bool areIndependentGroupsEquivalent(ArrayRef lhsOps, - ArrayRef rhsOps) { - if (lhsOps.size() != rhsOps.size()) { +static bool compareBlocks(Block& lhs, Block& rhs, + SetVector& lhsClosed, + SetVector& rhsClosed, IRMapping& m, + TensorMapping& tm) { + if (lhs.getNumArguments() != rhs.getNumArguments()) { return false; } - // Build frequency maps for both groups - DenseMap lhsFrequencyMap; - DenseMap rhsFrequencyMap; + // Map block arguments while allowing commutation of operands for `CtrlOp`s. - for (auto* op : lhsOps) { - lhsFrequencyMap[StructuralOperationKey(op)]++; + if (isa(lhs.getParentOp())) { + assert(isa(rhs.getParentOp())); + auto lhsCtrl = cast(lhs.getParentOp()); + auto rhsCtrl = cast(rhs.getParentOp()); + mapArguments(lhs, rhs, getTargetPermutation(lhsCtrl, rhsCtrl, m), m); + } else if (isa(lhs.getParentOp())) { + assert(isa(rhs.getParentOp())); + auto lhsCtrl = cast(lhs.getParentOp()); + auto rhsCtrl = cast(rhs.getParentOp()); + mapArguments(lhs, rhs, getTargetPermutation(lhsCtrl, rhsCtrl, m), m); + } else { + SmallVector permutation(lhs.getNumArguments()); + std::iota(permutation.begin(), permutation.end(), 0); + mapArguments(lhs, rhs, permutation, m); } - for (auto* op : rhsOps) { - rhsFrequencyMap[StructuralOperationKey(op)]++; - } + SetVector lhsOpen; + SetVector rhsOpen; - // Check structural equivalence - if (lhsFrequencyMap.size() != rhsFrequencyMap.size()) { - return false; - } + for_each(lhs.getOperations(), [&](auto& op) { lhsOpen.insert(&op); }); + for_each(rhs.getOperations(), [&](auto& op) { rhsOpen.insert(&op); }); - // NOLINTNEXTLINE(bugprone-nondeterministic-pointer-iteration-order) - for (const auto& [lhsKey, lhsCount] : lhsFrequencyMap) { - auto it = rhsFrequencyMap.find(lhsKey); - if (it == rhsFrequencyMap.end() || it->second != lhsCount) { - return false; - } - } + // Compare block operations topologically. - return true; -} + while (true) { + const auto lhsReady = getReadyOps(lhsOpen, lhsClosed); + const auto rhsReady = getReadyOps(rhsOpen, rhsClosed); -/// Compare two blocks for structural equivalence, allowing permutations -/// of independent operations. -static bool areBlocksEquivalent(Block& lhs, Block& rhs, - ValueEquivalenceMap& valueMap) { - // Check block arguments - if (lhs.getNumArguments() != rhs.getNumArguments()) { - return false; - } + if (lhsReady.empty() && rhsReady.empty()) { + break; + } - for (auto [lhsArg, rhsArg] : - llvm::zip(lhs.getArguments(), rhs.getArguments())) { - if (lhsArg.getType() != rhsArg.getType()) { + if (lhsReady.size() != rhsReady.size()) { return false; } - if (!valueMap.contains(lhsArg)) { - valueMap[lhsArg] = rhsArg; - } - } - // Collect all operations - SmallVector lhsOps; - SmallVector rhsOps; + // Because there may be multiple structural equivalent operations (think + // arith.constant, for example), we apply the assumption that the first + // occurrence on the lhs corresponds to the first one on the rhs, etc. - for (Operation& op : lhs) { - lhsOps.push_back(&op); - } + DenseSet matched; + matched.reserve(rhsReady.size()); - for (Operation& op : rhs) { - rhsOps.push_back(&op); - } + for (Operation* lhsOp : lhsReady) { + SetVector::iterator it = rhsReady.begin(); + for (; it != rhsReady.end(); it = std::next(it)) { + Operation* rhsOp = *it; - if (lhsOps.size() != rhsOps.size()) { - return false; - } + if (matched.contains(rhsOp)) { + continue; + } - // Partition operations into independent groups - auto lhsGroups = partitionIndependentGroups(lhsOps); - auto rhsGroups = partitionIndependentGroups(rhsOps); + if (compareOperations(lhsOp, rhsOp, m, tm)) { + matched.insert(rhsOp); - if (lhsGroups.size() != rhsGroups.size()) { - return false; - } + if (isa(lhsOp)) { + assert(isa(rhsOp)); + auto lhsCtrl = cast(lhsOp); + auto rhsCtrl = cast(rhsOp); - // Compare each group - for (size_t groupIdx = 0; groupIdx < lhsGroups.size(); ++groupIdx) { - auto& lhsGroup = lhsGroups[groupIdx]; - auto& rhsGroup = rhsGroups[groupIdx]; + SmallVector permutation; + permutation.reserve(lhsCtrl.getNumQubits()); + permutation.append(getControlPermutation(lhsCtrl, rhsCtrl, m)); + for (const auto i : getTargetPermutation(lhsCtrl, rhsCtrl, m)) { + permutation.emplace_back(lhsCtrl.getNumControls() + i); + } + mapResults(lhsCtrl, rhsCtrl, permutation, m); + } else if (isa(lhsOp)) { + assert(isa(rhsOp)); + auto lhsAlloc = cast(lhsOp); + auto rhsAlloc = cast(rhsOp); + tm.map(lhsAlloc.getResult(), rhsAlloc.getResult()); + } else if (isa(lhsOp)) { + assert(isa(rhsOp)); + auto lhsFrom = cast(lhsOp); + auto rhsFrom = cast(rhsOp); + tm.map(lhsFrom.getResult(), rhsFrom.getResult()); + } else if (isa(lhsOp)) { + assert(isa(rhsOp)); + auto lhsExtract = cast(lhsOp); + auto rhsExtract = cast(rhsOp); + m.map(lhsExtract.getResult(), rhsExtract.getResult()); + } else { + SmallVector permutation(lhsOp->getNumResults()); + std::iota(permutation.begin(), permutation.end(), 0); + mapResults(lhsOp, rhsOp, permutation, m); + } - const bool lhsInsertGroup = llvm::all_of(lhsGroup, isQTensorInsertOp); - const bool rhsInsertGroup = llvm::all_of(rhsGroup, isQTensorInsertOp); - if (lhsInsertGroup || rhsInsertGroup) { - if (!lhsInsertGroup || !rhsInsertGroup) { - return false; + m.map(lhsOp, rhsOp); + break; + } } - if (!areInsertGroupsEquivalent(lhsGroup, rhsGroup, valueMap)) { + + if (it == rhsReady.end()) { return false; } - continue; } - if (!areIndependentGroupsEquivalent(lhsGroup, rhsGroup)) { - return false; - } + // At this point, we've successfully matched each operation on the lhs + // with one on the rhs. Subsequently, update the open and closed sets and + // recursively compare the nested regions of each operation pair. - // Update value mappings for operations in this group - // We need to match operations and update the value map - // Since they are structurally equivalent, we can match them - // by trying all permutations (for small groups) or use a greedy approach - - // Use a simple greedy matching - DenseSet matchedRhs; - for (Operation* lhsOp : lhsGroup) { - bool matched = false; - for (Operation* rhsOp : rhsGroup) { - if (matchedRhs.contains(rhsOp)) { - continue; - } + lhsOpen.set_subtract(lhsReady); + lhsClosed.set_union(lhsReady); - ValueEquivalenceMap tempMap = valueMap; - if (areOperationsEquivalent(lhsOp, rhsOp, tempMap)) { - valueMap = std::move(tempMap); - matchedRhs.insert(rhsOp); - matched = true; - - // Recursively compare regions - for (auto [lhsRegion, rhsRegion] : - llvm::zip(lhsOp->getRegions(), rhsOp->getRegions())) { - if (!areRegionsEquivalent(lhsRegion, rhsRegion, valueMap)) { - return false; - } - } + rhsOpen.set_subtract(rhsReady); + rhsClosed.set_union(rhsReady); + + SetVector::iterator it = lhsReady.begin(); + for (; it != lhsReady.end(); it = std::next(it)) { + Operation* opLhs = *it; + + if (opLhs->getNumRegions() > 0) { + Operation* opRhs = m.lookup(opLhs); + assert(opLhs->getNumRegions() == opRhs->getNumRegions()); + const auto nequiv = range_size(make_filter_range( + llvm::zip_equal(opLhs->getRegions(), opRhs->getRegions()), + [&](const auto& zip) { + const auto& [lhsRegion, rhsRegion] = zip; + return compareRegions(lhsRegion, rhsRegion, lhsClosed, rhsClosed, + m, tm); + })); + if (nequiv != opLhs->getNumRegions()) { break; } } + } - if (!matched) { - return false; - } + if (it != lhsReady.end()) { + return false; } } return true; } +/// Compare two regions for structural equivalence. +static bool compareRegions(Region& lhs, Region& rhs, + SetVector& lhsClosed, + SetVector& rhsClosed, IRMapping& m, + TensorMapping& tm) { + if (lhs.getBlocks().size() != rhs.getBlocks().size()) { + return false; + } + + for (const auto [lhsBlock, rhsBlock] : llvm::zip_equal(lhs, rhs)) { + if (!compareBlocks(lhsBlock, rhsBlock, lhsClosed, rhsClosed, m, tm)) { + return false; + } + + m.map(&lhsBlock, &rhsBlock); + } + + return true; +} + bool areModulesEquivalentWithPermutations(ModuleOp lhs, ModuleOp rhs) { - ValueEquivalenceMap valueMap; - return areRegionsEquivalent(lhs.getBodyRegion(), rhs.getBodyRegion(), - valueMap); + IRMapping m; + SetVector lhsClosed; + SetVector rhsClosed; + TensorMapping tm{.lhsEquivGroups = getEquivGroup(lhs), + .rhsEquivGroups = getEquivGroup(rhs), + .equivGroupMapping = DenseMap{}}; + + return compareRegions(lhs.getBodyRegion(), rhs.getBodyRegion(), lhsClosed, + rhsClosed, m, tm); } diff --git a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp index 27aeec9455..a76e20cc1a 100644 --- a/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp +++ b/mlir/unittests/Dialect/QTensor/IR/test_qtensor_ir.cpp @@ -506,7 +506,10 @@ INSTANTIATE_TEST_SUITE_P( QTensorIntegrationTestCase{ "QTensorInsertExtractIndexMismatch", MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch), - MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)})); + MQT_NAMED_BUILDER(qtensorInsertExtractIndexMismatch)}, + QTensorIntegrationTestCase{"QTensorAlternativeInsertChain", + MQT_NAMED_BUILDER(qtensorAlternativeChain), + MQT_NAMED_BUILDER(qtensorChain)})); /// @} } // namespace diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index d6dca2dca4..8410afd7a8 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -287,11 +287,10 @@ void trivialControlledX(QCProgramBuilder& b) { } void repeatedControlledX(QCProgramBuilder& b) { - auto control = b.allocQubit(); - b.h(control); - for (auto i = 0; i < 50; i++) { - auto qubit = b.allocQubit(); - b.cx(control, qubit); + auto q = b.allocQubitRegister(64); + b.h(q[0]); + for (auto i = 1; i < 64; i++) { + b.cx(q[0], q[i]); } } diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 6c43d7f668..1fadf83eb2 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -272,11 +272,21 @@ void trivialControlledX(QCOProgramBuilder& b) { } void repeatedControlledX(QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto control = b.h(q0); - for (auto i = 0; i < 50; i++) { - auto qubit = b.allocQubit(); - control = b.cx(control, qubit).first; + auto tensor = b.qtensorAlloc(64); + + Value q0; + std::tie(tensor, q0) = b.qtensorExtract(tensor, 0); + + SmallVector values(63); + for (auto i = 1; i < 64; i++) { + Value qi; + std::tie(tensor, qi) = b.qtensorExtract(tensor, i); + values[i - 1] = qi; + } + + q0 = b.h(q0); + for (auto i = 1; i < 64; i++) { + std::tie(q0, values[i - 1]) = b.cx(q0, values[i - 1]); } } @@ -2335,6 +2345,42 @@ void qtensorInsertExtractSameIndex(QCOProgramBuilder& b) { b.qtensorInsert(q2, extractOutTensor1, 0); } +void qtensorChain(QCOProgramBuilder& b) { + Value q0; + Value q1; + Value q2; + auto qtensor = b.qtensorAlloc(3); + std::tie(qtensor, q0) = b.qtensorExtract(qtensor, 0); + std::tie(qtensor, q1) = b.qtensorExtract(qtensor, 1); + std::tie(qtensor, q2) = b.qtensorExtract(qtensor, 2); + q0 = b.h(q0); + q1 = b.h(q1); + std::tie(q1, q2) = b.cx(q1, q2); + + qtensor = b.qtensorInsert(q2, qtensor, 2); + qtensor = b.qtensorInsert(q1, qtensor, 1); + qtensor = b.qtensorInsert(q0, qtensor, 0); + b.qtensorDealloc(qtensor); +} + +void qtensorAlternativeChain(QCOProgramBuilder& b) { + Value q0; + Value q1; + Value q2; + auto qtensor = b.qtensorAlloc(3); + std::tie(qtensor, q0) = b.qtensorExtract(qtensor, 0); + q0 = b.h(q0); + std::tie(qtensor, q1) = b.qtensorExtract(qtensor, 1); + q1 = b.h(q1); + std::tie(qtensor, q2) = b.qtensorExtract(qtensor, 2); + std::tie(q1, q2) = b.cx(q1, q2); + + qtensor = b.qtensorInsert(q0, qtensor, 0); + qtensor = b.qtensorInsert(q1, qtensor, 1); + qtensor = b.qtensorInsert(q2, qtensor, 2); + b.qtensorDealloc(qtensor); +} + void simpleWhileReset(QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto q1 = b.h(q0); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index 6f6323db22..1a5f5ce229 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -1122,4 +1122,13 @@ void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b); /// Inserts a qubit into a tensor and extracts it immediately at the same index. void qtensorInsertExtractSameIndex(QCOProgramBuilder& b); +/// Extracts three qubits with ascending index (0, 1, 2), performs a +/// computation, and finally inserts the qubits in ascending order (0, 1, 2). +void qtensorChain(QCOProgramBuilder& b); + +/// Performs the same computation as the `qtensorChain` function, but uses +/// qubits immediately after the extract and inserts the qubits in descending +/// order (2, 1, 0). +void qtensorAlternativeChain(QCOProgramBuilder& b); + } // namespace mlir::qco