diff --git a/include/jeff/IR/JeffOps.td b/include/jeff/IR/JeffOps.td index f5045d1..f76626b 100644 --- a/include/jeff/IR/JeffOps.td +++ b/include/jeff/IR/JeffOps.td @@ -1882,9 +1882,87 @@ def FloatArrayCreateOp : FloatArray_Op<"float_array_create", []> { class SCF_Op traits = []> : Jeff_Op; -def SwitchOp : SCF_Op<"switch", []> { - let summary = ""; - let description = [{}]; +def SwitchOp : SCF_Op<"switch", [SingleBlockImplicitTerminator<"jeff::YieldOp">]> { + let summary = "Multi-way branch on an integer selector"; + let description = [{ + The `jeff.switch` operation dispatches on an integer selector to one of + several branch regions or, if no case matches, an optional default + region. Branch labels are positional: the i-th branch region runs when + `selection` equals `i`. If `selection` does not match any branch label, + control flows to the `default` region. + + The selector operand `selection` must be a `SupportedIntType` (one of + `i1`, `i8`, `i16`, `i32`, `i64`). + + The op carries a list of in-values (`in_values`) shared across all + branches and the default region: each region declares its own local + block arguments, bound positionally to the in-values on entry. Each + region must terminate with a `jeff.yield` whose operands provide the + op's results for the selected branch. + + The default region is optional. When present, it covers any selector + value not matched by an explicit branch. + + Unlike a `for` or `while` loop, a `switch` has no iteration semantics: + it dispatches once. The op's result count and types are therefore not + related to the in-value count and types — branches may yield any + number of values of any type, provided every branch (and the default, + if present) yields the same shape. + + Schematic syntax: + ``` + jeff.switch ( $selection, $in_values ) : ( $selection_type, $in_value_types ) + -> ( $result_types ) + ( case $i args ( $local_names ) { $branch_body } )* + ( default args ( $local_names ) { $default_body } )? 
+ ``` + + Example: + ```mlir + %r1, %r2 = jeff.switch (%sel, %a, %b) : (i32, i32, i64) -> (i32, i64) + case 0 args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + case 1 args(%x, %y) { + // ... compute %next_x, %next_y ... + jeff.yield %next_x, %next_y : i32, i64 + } + default args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + ``` + + Implicit inferences in the syntax: + - Each region's `args(%x, %y)` lists *only the local block-argument + names*, not assignments. The block arguments are bound positionally + to the op's `in_values` (in the example above, `%x` to `%a`, `%y` + to `%b`) — the operands are already named in the op's header. + - Block-argument types are not written: they are inferred from the + in-value types in the header. + + Invariants: + - `selection` is a `SupportedIntType`. + - Case labels are contiguous integers starting at 0: the i-th + branch region is labeled `case i`. The parser enforces this so + that print-then-parse is faithful. + - Every region's block argument count and types match `in_values`. + - Every region's yielded value count and types match the op's + result types. + + Underspecified: + - The semantics when `selection` does not match any branch label and + the default region is absent. + - Structurally, the op admits zero branches and an empty default; + such a form is legal but meaningless and is not constrained out. + + Differences from `scf.index_switch`: + - Selector type is restricted to `SupportedIntType` (no `index`). + - Case labels are implicit and positional (`case 0`, `case 1`, ...) + rather than arbitrary integer literals. + - The default region is optional rather than mandatory. + - Supports in-values shared across regions, decoupled from the + op's results (`scf.index_switch` has no shared-value mechanism). 
+ }]; let arguments = (ins SupportedIntType:$selection, @@ -1901,13 +1979,87 @@ def SwitchOp : SCF_Op<"switch", []> { ); let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; let hasRegionVerifier = 1; } -def ForOp : SCF_Op<"for", []> { - let summary = ""; - let description = [{}]; +def ForOp : SCF_Op<"for", [SingleBlockImplicitTerminator<"jeff::YieldOp">]> { + let summary = "Counted for loop with iteration arguments"; + let description = [{ + The `jeff.for` operation represents a counted loop over a half-open + integer range `[start, stop)`, advancing by `step` on each iteration. + The body executes once for every value the induction variable takes + while it remains less than `stop`, starting at `start`. + + The op takes three SSA integer operands — `start`, `stop`, and + `step`, each of which must be a `SupportedIntType` (one of `i1`, + `i8`, `i16`, `i32`, `i64`). The induction variable has the same + type as `start`, `stop`, and `step`. + + The op may optionally carry a list of loop-carried values + (`in_values`). These are bound to the corresponding body block + arguments on entry to each iteration and replaced at the end of + each iteration by the operands of the terminating `jeff.yield`. + The op's results (`out_values`) hold the final values of the + carried variables after the last iteration; their types must match + the types of `in_values`. + + The body is a single block whose arguments are the induction + variable followed by one argument per loop-carried value. It must + terminate with a `jeff.yield` whose operand types and count match + `in_values`. When no values are carried, the terminator is inserted + implicitly by the parser. + + Schematic syntax: + ``` + jeff.for $induction_variable = $start to $stop step $step + ( args ( $assignments ) -> ( $result_types ) )? 
+ : $induction_variable_type + { $body } + ``` + + The body's block-argument types are not printed explicitly: they + are recovered from `-> ( $result_types )` and must equal the iter-arg types. + + Example without carried values: + ```mlir + jeff.for %i = %lo to %hi step %s : i32 { + // body + } + ``` + + Example with carried values: + ```mlir + %r1, %r2 = jeff.for %i = %lo to %hi step %s + args(%x = %a, %y = %b) -> (i32, i64) : i32 { + // ... compute %next_x, %next_y ... + jeff.yield %next_x, %next_y : i32, i64 + } + ``` + + Invariants enforced by the verifier: + - `start`, `stop`, and `step` are each a `SupportedIntType`. + - The induction variable has the same type as `start`, `stop`, and `step`. + - The number of `in_values` equals the number of results. + - Each `in_value`, the corresponding body block argument, and the + corresponding result have matching types. + + Underspecified (not currently enforced): + - Comparison signedness (the operand types are signless MLIR + integers). + - That `step` is non-zero and positive. + + Differences from `scf.for`: + - Bounds and step are restricted to `SupportedIntType` + (no `index`, no widths outside `{1, 8, 16, 32, 64}`). + - Uses the keyword `args` instead of `iter_args`, with a trailing + `-> ( result_types )` instead of the SCF arrow placement. + - The induction-variable type annotation `:` is always printed, + even when no values are carried. + - Does not implement `LoopLikeOpInterface`, + `RegionBranchOpInterface`, `RecursiveMemoryEffects`, or + `AutomaticAllocationScope`, so upstream loop-aware passes + (LICM, peeling, generic unrolling) do not apply. 
+ }]; let arguments = (ins SupportedIntType:$start, @@ -1927,9 +2079,90 @@ def ForOp : SCF_Op<"for", []> { let hasRegionVerifier = 1; } -def WhileOp : SCF_Op<"while", []> { - let summary = ""; - let description = [{}]; +def WhileOp : SCF_Op<"while", [SingleBlockImplicitTerminator<"jeff::YieldOp">]> { + let summary = "While loop with iteration arguments"; + let description = [{ + The `jeff.while` operation represents a loop that repeats while a + condition holds. It has two regions: + + - The `condition` region computes the loop's continuation predicate. + It must terminate with a `jeff.yield` whose only operand is an + `i1`. By the `jeff` language spec, the condition region may not + modify quantum state (qubits or qubit registers); it may only + inspect classical metadata. + - The `body` region is executed when the condition yields `true`. + It must terminate with a `jeff.yield` whose operands become the + next iteration's input values, or — on the last iteration — the + op's results. + + The op carries a list of `in_values` shared between both regions. + Each region declares its own local block arguments, bound to those + `in_values` on entry. Because the condition region does not modify + state, the same iteration values are visible to the body region + (formally, they are forwarded from the condition's inputs). + + The types of `in_values`, the body region's yielded values, and the + op's results all match. There is a single signature for the loop. + + Syntax: + ``` + [ $results = ] jeff.while [ : ( $types ) ] + args ( $cond_assignments ) $cond_body + args ( $body_names ) $body_body + ``` + + The `: ( $types )` annotation is present iff `$types` is non-empty. + Both `args(...)` clauses are always emitted (possibly empty), + serving as a structural separator between the regions. + + Implicit inferences in the syntax: + - Result types are not written: they are the same as the `: (...)` + types (Jeff's single-signature design choice, T1 = T2). 
+ - Block-argument types inside the `args(...)` clauses are not + written: they are inferred from the same `: (...)` annotation. + - The body region's `args(...)` lists only local block-argument + names, not assignments. The operands are stated once in the + condition's `args(...)`; the body's block arguments are bound + positionally to those same operands. + + Example with loop-carried values: + ```mlir + %r1, %r2 = jeff.while : (i32, i64) args(%c_x = %a, %c_y = %b) { + // condition + jeff.yield %pred : i1 + } args(%b_x, %b_y) { + // body + jeff.yield %next_x, %next_y : i32, i64 + } + ``` + + Example with no loop-carried values: + ```mlir + jeff.while args() { + jeff.yield %pred : i1 + } args() { + // body (no values to yield) + } + ``` + + Invariants enforced by the verifier: + - The number of `in_values` equals the number of results. + - Each `in_value`, the corresponding block argument of each region, + and the corresponding result have matching types. + - Both regions terminate with a `jeff.yield` (per + `SingleBlockImplicitTerminator`). + + Differences from `scf.while`: + - No `do` keyword between the regions; `args(...)` is the + structural separator. + - The condition region yields only `i1`; it does not also yield + values to the body region (in SCF, `scf.condition` yields an + `i1` plus the carried values). + - Both regions share a single op-level operand list; in SCF the + operand list feeds only the `before` region. + - Input types and result types are required to match (single + signature), reflecting the `jeff` language design choice of T1 = T2. 
+ }]; let arguments = (ins Variadic:$in_values @@ -1949,9 +2182,91 @@ def WhileOp : SCF_Op<"while", []> { let hasRegionVerifier = 1; } -def DoWhileOp : SCF_Op<"doWhile", []> { - let summary = ""; - let description = [{}]; +def DoWhileOp : SCF_Op<"doWhile", [SingleBlockImplicitTerminator<"jeff::YieldOp">]> { + let summary = "Do-while loop with iteration arguments"; + let description = [{ + The `jeff.doWhile` operation represents a loop that repeats while a + condition holds. Unlike the `jeff.while` operation, an initial + iteration of the loop occurs before the condition is checked for the + first time. It has two regions: + + - The `condition` region computes the loop's continuation predicate. + It must terminate with a `jeff.yield` whose only operand is an + `i1`. By the `jeff` language spec, the condition region may not + modify quantum state (qubits or qubit registers); it may only + inspect classical metadata. + - The `body` region executes unconditionally on the first iteration, + then repeats while the condition yields `true`. + It must terminate with a `jeff.yield` whose operands become the + next iteration's input values, or — on the last iteration — the + op's results. + + The op carries a list of `in_values` shared between both regions. + Each region declares its own local block arguments, bound to those + `in_values` on entry. On the first iteration the body's block + arguments are bound directly from the op's operands; on subsequent + iterations they are bound from the body's previous yield. Because + the condition region does not modify state, the same iteration + values are visible to both regions. + + The types of `in_values`, the body region's yielded values, and the + op's results all match. There is a single signature for the loop. + + Syntax: + ``` + [ $results = ] jeff.doWhile [ : ( $types ) ] + args ( $body_assignments ) $body_body + args ( $cond_names ) $cond_body + ``` + + The `: ( $types )` annotation is present iff `$types` is non-empty. 
+ Both `args(...)` clauses are always emitted (possibly empty), + serving as a structural separator between the regions. + + Implicit inferences in the syntax: + - Result types are not written: they are the same as the `: (...)` + types (Jeff's single-signature design choice, T1 = T2). + - Block-argument types inside the `args(...)` clauses are not + written: they are inferred from the same `: (...)` annotation. + - The condition region's `args(...)` lists only local + block-argument names, not assignments. The operands are stated + once in the body's `args(...)`; the condition's block arguments + are bound positionally to those same operands. + + Example with loop-carried values: + ```mlir + %r1, %r2 = jeff.doWhile : (i32, i64) args(%b_x = %a, %b_y = %b) { + // body + jeff.yield %next_x, %next_y : i32, i64 + } args(%c_x, %c_y) { + // condition + jeff.yield %pred : i1 + } + ``` + + Example with no loop-carried values: + ```mlir + jeff.doWhile args() { + } args() { + jeff.yield %pred : i1 + } + ``` + + Invariants enforced by the verifier: + - The number of `in_values` equals the number of results. + - Each `in_value`, the corresponding block argument of each region, + and the corresponding result have matching types. + - Both regions terminate with a `jeff.yield` (per + `SingleBlockImplicitTerminator`). + + Differences from `jeff.while`: + - The body runs unconditionally on the first iteration, before the + condition is ever evaluated. The body region is therefore + declared before the condition region in the op's region list and + in the printed form. + + The differences from `scf.while` are the same as for `jeff.while`. 
+ }]; let arguments = (ins Variadic:$in_values diff --git a/lib/IR/JeffOps.cpp b/lib/IR/JeffOps.cpp index 63ecbb8..8b1f9b4 100644 --- a/lib/IR/JeffOps.cpp +++ b/lib/IR/JeffOps.cpp @@ -27,10 +27,11 @@ #include #include #include +#include #include #include -#include +#include using namespace mlir; using namespace mlir::jeff; @@ -158,69 +159,215 @@ void IntBinaryOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRCo } void SwitchOp::print(OpAsmPrinter& p) { - auto inValues = getInValues(); - - p << '(' << getSelection() << ")"; - p << " : " << getSelection().getType(); - p << " -> (" << inValues.getTypes() << ") "; - - auto branches = getBranches(); - for (size_t i = 0; i < branches.size(); ++i) { + // The op's operand list is `selection` followed by `in_values`, in that + // order, so we can stream both together. + auto operands = (*this)->getOperands(); + auto resultTypes = getResultTypes(); + + // Header: `(%sel, %a, %b) : (i32, T_in...) -> (T_out...)`. + p << " ("; + llvm::interleaveComma(operands, p); + p << ") : ("; + llvm::interleaveComma(operands.getTypes(), p); + p << ") -> (" << resultTypes << ")"; + + auto printRegionWithArgs = [&](Region& region) { + p << " args("; + llvm::interleaveComma(region.getArguments(), p); + p << ") "; + p.printRegion(region, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/!resultTypes.empty()); + }; + + for (auto [i, branch] : llvm::enumerate(getBranches())) { p.printNewline(); - p << "case " << i << ' '; - auto& branch = branches[i]; - auto regionArgs = branch.getArguments(); - printInitializationList(p, regionArgs, inValues, "args"); - p << ' '; - p.printRegion(branch, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/!inValues.empty()); + p << "case " << i; + printRegionWithArgs(branch); } auto& defaultRegion = getDefault(); if (!defaultRegion.empty()) { p.printNewline(); - p << "default "; - auto regionArgs = defaultRegion.getArguments(); - printInitializationList(p, regionArgs, inValues, 
"args"); - p << ' '; - p.printRegion(defaultRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/!inValues.empty()); + p << "default"; + printRegionWithArgs(defaultRegion); } p.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult SwitchOp::parse(OpAsmParser& /*parser*/, OperationState& /*result*/) { - // TODO: Implement this - llvm::report_fatal_error("SwitchOp::parse is not implemented yet"); -} +ParseResult SwitchOp::parse(OpAsmParser& parser, OperationState& result) { + auto& builder = parser.getBuilder(); -LogicalResult SwitchOp::verify() { - if (getInValues().size() != getNumResults()) { - return emitOpError("mismatch in number of input and output values"); + // The op declares `$default` first and `$branches` after, + // so the default region must occupy index 0. + // Pre-allocate it; populate later if present. + Region* defaultRegion = result.addRegion(); + + // Parse `(%sel, %a, %b)` — selector first, then in-values. + llvm::SmallVector operands; + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseOperand(operands.emplace_back()); + })) { + return failure(); + } + if (operands.empty()) { + return parser.emitError(parser.getNameLoc(), "expected at least the selector operand"); + } + + // Parse `: (T_sel, T_in...)` — operand types, in the same order. + llvm::SmallVector operandTypes; + if (parser.parseColon() || parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseType(operandTypes.emplace_back()); + })) { + return failure(); + } + if (operandTypes.size() != operands.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << operands.size() << " operand types but got " + << operandTypes.size(); + } + + // Parse `-> (T_out...)` — result types, independent of operand types. 
+ llvm::SmallVector resultTypes; + if (parser.parseArrow() || parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseType(resultTypes.emplace_back()); + })) { + return failure(); + } + + if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), + result.operands)) { + return failure(); + } + + // In-value types are everything after the selector. + auto inValueTypes = llvm::ArrayRef(operandTypes).drop_front(1); + + // Helper that parses `args(%x, %y) { ... }` into a region. + auto parseRegionWithArgs = [&](Region& region) -> ParseResult { + llvm::SmallVector regionArgs; + if (parser.parseKeyword("args") || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseArgument(regionArgs.emplace_back()); + })) { + return failure(); + } + if (regionArgs.size() != inValueTypes.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << inValueTypes.size() << " region arguments but got " + << regionArgs.size(); + } + for (auto [arg, ty] : llvm::zip_equal(regionArgs, inValueTypes)) { + arg.type = ty; + } + if (parser.parseRegion(region, regionArgs)) { + return failure(); + } + SwitchOp::ensureTerminator(region, builder, result.location); + return success(); + }; + + // Parse `case N args(...) { ... }` while the keyword is present. + // Case labels are positional; require them to be 0, 1, 2, ... + // so that print-then-parse round-trips faithfully. + int64_t expectedCase = 0; + while (succeeded(parser.parseOptionalKeyword("case"))) { + int64_t caseValue = 0; + if (parser.parseInteger(caseValue)) { + return failure(); + } + if (caseValue != expectedCase) { + return parser.emitError(parser.getNameLoc()) + << "expected `case " << expectedCase << "` but got `case " << caseValue << "`"; + } + ++expectedCase; + + Region* branch = result.addRegion(); + if (parseRegionWithArgs(*branch)) { + return failure(); + } + } + + // Optional `default args(...) { ... }`. 
+ if (succeeded(parser.parseOptionalKeyword("default"))) { + if (parseRegionWithArgs(*defaultRegion)) { + return failure(); + } + } + + result.addTypes(resultTypes); + + if (parser.parseOptionalAttrDict(result.attributes)) { + return failure(); } return success(); } +/** + * @brief Verifies the 'case' and 'default' regions of a `jeff.switch`. + * + * `in_values` and `out_values` are independent for switch: there is no count or type relationship + * between them. Each region's block arguments mirror `in_values`, and each region's `jeff.yield` + * mirrors the op's results - two separate checks. + */ LogicalResult SwitchOp::verifyRegions() { - llvm::SmallVector regions; - auto branches = getBranches(); - regions.reserve(1 + branches.size()); - regions.push_back(&getDefault()); - for (auto& branch : branches) { - regions.push_back(&branch); - } + auto inValueTypes = getInValues().getTypes(); + auto resultTypes = getResultTypes(); + + // Helper that verifies one region (a `case` body or the `default`). + auto verifyRegion = [&](Region& region, const llvm::Twine& name) -> LogicalResult { + // The parser always pre-allocates the default region, + // leaving it empty when the source has no `default { ... }` clause. + if (region.empty()) { + return success(); + } - auto inValues = getInValues(); - auto outValues = getOutValues(); + // Block-argument count matches the `in_values` count. + auto regionArgs = region.getArguments(); + if (regionArgs.size() != inValueTypes.size()) { + return emitOpError() << name << " region has " << regionArgs.size() + << " block arguments but op has " << inValueTypes.size() + << " in-values"; + } + // Block-argument types match the `in_values` types. 
+ for (auto [i, regionArg, inTy] : llvm::enumerate(regionArgs, inValueTypes)) { + if (regionArg.getType() != inTy) { + return emitOpError() << name << " region block argument " << i + << " type does not match the corresponding in-value type"; + } + } - for (auto& region : regions) { - auto regionArgs = region->getArguments(); - if (verifyRegionArgs(*this, inValues, outValues, regionArgs).failed()) { + // `jeff.yield` is present. + auto yield = dyn_cast(region.front().back()); + if (!yield) { + return emitOpError() << name << " region must terminate with `jeff.yield`"; + } + // `jeff.yield` operand count matches the op's result count. + if (yield.getNumOperands() != resultTypes.size()) { + return emitOpError() << name << " region yields " << yield.getNumOperands() + << " values but op has " << resultTypes.size() << " results"; + } + // `jeff.yield` operand types match the op's result types. + for (auto [i, yieldOp, resTy] : llvm::enumerate(yield.getOperands(), resultTypes)) { + if (yieldOp.getType() != resTy) { + return emitOpError() << name << " region yield operand " << i + << " type does not match the corresponding result type"; + } + } + return success(); + }; + + // Verify `case` regions. + for (auto [i, branch] : llvm::enumerate(getBranches())) { + if (verifyRegion(branch, "case " + llvm::Twine(i)).failed()) { return failure(); } } + // Verify `default` region. 
+ if (verifyRegion(getDefault(), "default").failed()) { + return failure(); + } return success(); } @@ -239,11 +386,7 @@ void ForOp::print(OpAsmPrinter& p) { p << " -> (" << inValues.getTypes() << ')'; } - if (Type t = inductionVar.getType(); !t.isIndex()) { - p << " : " << t << ' '; - } else { - p << ' '; - } + p << " : " << inductionVar.getType() << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, @@ -251,9 +394,85 @@ void ForOp::print(OpAsmPrinter& p) { p.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult ForOp::parse(OpAsmParser& /*parser*/, OperationState& /*result*/) { - // TODO: Implement this - llvm::report_fatal_error("ForOp::parse is not implemented yet"); +// Adapted from +// https://github.com/llvm/llvm-project/blob/a58268a77cdbfeb0b71f3e76d169ddd7edf7a4df/mlir/lib/Dialect/SCF/IR/SCF.cpp#L516 +ParseResult ForOp::parse(OpAsmParser& parser, OperationState& result) { + auto& builder = parser.getBuilder(); + Type type; + + OpAsmParser::Argument inductionVar; + OpAsmParser::UnresolvedOperand start; + OpAsmParser::UnresolvedOperand stop; + OpAsmParser::UnresolvedOperand step; + + // Parse the induction variable followed by '='. + if (parser.parseOperand(inductionVar.ssaName) || parser.parseEqual() || + // Parse loop bounds. + parser.parseOperand(start) || parser.parseKeyword("to") || parser.parseOperand(stop) || + parser.parseKeyword("step") || parser.parseOperand(step)) { + return failure(); + } + + // Parse the optional initial iteration arguments. + llvm::SmallVector regionArgs; + llvm::SmallVector operands; + regionArgs.push_back(inductionVar); + + // Parse assignment list and result types list. 
+ bool hasArgs = succeeded(parser.parseOptionalKeyword("args")); + if (hasArgs) { + if (parser.parseAssignmentList(regionArgs, operands) || + parser.parseArrowTypeList(result.types)) { + return failure(); + } + } + + if (regionArgs.size() != result.types.size() + 1) { + return parser.emitError(parser.getNameLoc(), + "mismatch in number of loop-carried values and defined values"); + } + + // Parse type. + if (parser.parseColon() || parser.parseType(type)) { + return failure(); + } + + // Set block argument types so that they are known when parsing the region. + regionArgs.front().type = type; + for (auto [arg, argType] : llvm::zip_equal(llvm::drop_begin(regionArgs), result.types)) { + arg.type = argType; + } + + // Parse the body region. + Region* body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) { + return failure(); + } + ForOp::ensureTerminator(*body, builder, result.location); + + // Resolve input operands. + if (parser.resolveOperand(start, type, result.operands) || + parser.resolveOperand(stop, type, result.operands) || + parser.resolveOperand(step, type, result.operands)) { + return failure(); + } + if (hasArgs) { + for (auto argOperandType : + llvm::zip_equal(llvm::drop_begin(regionArgs), operands, result.types)) { + Type argOpType = std::get<2>(argOperandType); + std::get<0>(argOperandType).type = argOpType; + if (parser.resolveOperand(std::get<1>(argOperandType), argOpType, result.operands)) { + return failure(); + } + } + } + + // Parse the optional attribute list. 
+ if (parser.parseOptionalAttrDict(result.attributes)) { + return failure(); + } + + return success(); } // Adapted from @@ -270,8 +489,15 @@ LogicalResult ForOp::verify() { // https://github.com/llvm/llvm-project/blob/a58268a77cdbfeb0b71f3e76d169ddd7edf7a4df/mlir/lib/Dialect/SCF/IR/SCF.cpp#L359 LogicalResult ForOp::verifyRegions() { auto inductionVar = getBody().getArgument(0); - if (inductionVar.getType() != getStart().getType()) { - return emitOpError("expected induction variable to be same type as bounds and step"); + auto inductionVarType = inductionVar.getType(); + if (inductionVarType != getStart().getType()) { + return emitOpError("expected induction variable to be same type as start"); + } + if (inductionVarType != getStop().getType()) { + return emitOpError("expected induction variable to be same type as stop"); + } + if (inductionVarType != getStep().getType()) { + return emitOpError("expected induction variable to be same type as step"); } auto inValues = getInValues(); @@ -284,29 +510,117 @@ LogicalResult ForOp::verifyRegions() { return success(); } +// Adapted from +// https://github.com/llvm/llvm-project/blob/a58268a77cdbfeb0b71f3e76d169ddd7edf7a4df/mlir/lib/Dialect/SCF/IR/SCF.cpp#L3343 void WhileOp::print(OpAsmPrinter& p) { auto inValues = getInValues(); + // Emit `: ( types )` only when there are in-values. + if (!inValues.empty()) { + p << " : (" << inValues.getTypes() << ")"; + } + + // Condition region: `args ( $assignments )`. + // Full assignments, since this is where the op's operands are introduced. auto& condition = getCondition(); auto conditionArgs = condition.getArguments(); printInitializationList(p, conditionArgs, inValues, " args"); - p << " -> (" << IntegerType::get(getContext(), 1) << ") "; + p << ' '; p.printRegion(condition, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/!inValues.empty()); + /*printBlockTerminators=*/true); + // Body region: `args ( $names )`. + // Names only. 
The operands are already stated in the condition's `args(...)`. auto& body = getBody(); auto bodyArgs = body.getArguments(); - printInitializationList(p, bodyArgs, inValues, " args"); - p << " -> (" << inValues.getTypes() << ") "; + p << " args("; + llvm::interleaveComma(bodyArgs, p); + p << ") "; p.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!inValues.empty()); p.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult WhileOp::parse(OpAsmParser& /*parser*/, OperationState& /*result*/) { - // TODO: Implement this - llvm::report_fatal_error("WhileOp::parse is not implemented yet"); +// Adapted from +// https://github.com/llvm/llvm-project/blob/a58268a77cdbfeb0b71f3e76d169ddd7edf7a4df/mlir/lib/Dialect/SCF/IR/SCF.cpp#L3303 +ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { + auto& builder = parser.getBuilder(); + + Region* condition = result.addRegion(); + Region* body = result.addRegion(); + + // Parse optional `: ( types )`. + // Omitted when there are no in-values. + llvm::SmallVector types; + if (succeeded(parser.parseOptionalColon())) { + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseType(types.emplace_back()); + })) { + return failure(); + } + } + + // Parse the condition region's `args ( $assignments )`. 
+ llvm::SmallVector condRegionArgs; + llvm::SmallVector condOperands; + if (parser.parseKeyword("args") || parser.parseAssignmentList(condRegionArgs, condOperands)) { + return failure(); + } + + if (condRegionArgs.size() != types.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << types.size() << " condition arguments but got " + << condRegionArgs.size(); + } + + for (auto [arg, ty] : llvm::zip_equal(condRegionArgs, types)) { + arg.type = ty; + } + + if (parser.parseRegion(*condition, condRegionArgs)) { + return failure(); + } + WhileOp::ensureTerminator(*condition, builder, result.location); + + // Parse the body region's `args ( $names )`. + // Names only. The operands are inherited from the condition's `args(...)`. + llvm::SmallVector bodyRegionArgs; + if (parser.parseKeyword("args") || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseArgument(bodyRegionArgs.emplace_back()); + })) { + return failure(); + } + + if (bodyRegionArgs.size() != types.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << types.size() << " body arguments but got " + << bodyRegionArgs.size(); + } + + for (auto [arg, ty] : llvm::zip_equal(bodyRegionArgs, types)) { + arg.type = ty; + } + + if (parser.parseRegion(*body, bodyRegionArgs)) { + return failure(); + } + WhileOp::ensureTerminator(*body, builder, result.location); + + // Resolve operands from the condition's `args(...)`. + if (parser.resolveOperands(condOperands, types, parser.getCurrentLocation(), result.operands)) { + return failure(); + } + + // Op results have the same types as in-values. 
+ result.addTypes(types); + + if (parser.parseOptionalAttrDict(result.attributes)) { + return failure(); + } + + return success(); } LogicalResult WhileOp::verify() { @@ -337,26 +651,110 @@ LogicalResult WhileOp::verifyRegions() { void DoWhileOp::print(OpAsmPrinter& p) { auto inValues = getInValues(); + // Emit `: ( types )` only when there are in-values. + if (!inValues.empty()) { + p << " : (" << inValues.getTypes() << ")"; + } + + // Body region: `args ( $assignments )`. + // Full assignments, since this is where the op's operands are introduced. auto& body = getBody(); auto bodyArgs = body.getArguments(); printInitializationList(p, bodyArgs, inValues, " args"); - p << " -> (" << inValues.getTypes() << ") "; + p << ' '; p.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!inValues.empty()); + // Condition region: `args ( $names )`. + // Names only. The operands are already stated in the body's `args(...)`. auto& condition = getCondition(); auto conditionArgs = condition.getArguments(); - printInitializationList(p, conditionArgs, inValues, " args"); - p << " -> (" << IntegerType::get(getContext(), 1) << ") "; + p << " args("; + llvm::interleaveComma(conditionArgs, p); + p << ") "; p.printRegion(condition, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/!inValues.empty()); + /*printBlockTerminators=*/true); p.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult DoWhileOp::parse(OpAsmParser& /*parser*/, OperationState& /*result*/) { - // TODO: Implement this - llvm::report_fatal_error("DoWhileOp::parse is not implemented yet"); +ParseResult DoWhileOp::parse(OpAsmParser& parser, OperationState& result) { + auto& builder = parser.getBuilder(); + + Region* body = result.addRegion(); + Region* condition = result.addRegion(); + + // Parse optional `: ( types )`. + // Omitted when there are no in-values. 
+ llvm::SmallVector types; + if (succeeded(parser.parseOptionalColon())) { + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseType(types.emplace_back()); + })) { + return failure(); + } + } + + // Parse the body region's `args ( $assignments )`. + llvm::SmallVector bodyRegionArgs; + llvm::SmallVector bodyOperands; + if (parser.parseKeyword("args") || parser.parseAssignmentList(bodyRegionArgs, bodyOperands)) { + return failure(); + } + + if (bodyRegionArgs.size() != types.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << types.size() << " body arguments but got " + << bodyRegionArgs.size(); + } + + for (auto [arg, ty] : llvm::zip_equal(bodyRegionArgs, types)) { + arg.type = ty; + } + + if (parser.parseRegion(*body, bodyRegionArgs)) { + return failure(); + } + DoWhileOp::ensureTerminator(*body, builder, result.location); + + // Parse the condition region's `args ( $names )`. + // Names only. The operands are inherited from the body's `args(...)`. + llvm::SmallVector condRegionArgs; + if (parser.parseKeyword("args") || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + return parser.parseArgument(condRegionArgs.emplace_back()); + })) { + return failure(); + } + + if (condRegionArgs.size() != types.size()) { + return parser.emitError(parser.getNameLoc()) + << "expected " << types.size() << " condition arguments but got " + << condRegionArgs.size(); + } + + for (auto [arg, ty] : llvm::zip_equal(condRegionArgs, types)) { + arg.type = ty; + } + + if (parser.parseRegion(*condition, condRegionArgs)) { + return failure(); + } + DoWhileOp::ensureTerminator(*condition, builder, result.location); + + // Resolve operands from the body's `args(...)`. + if (parser.resolveOperands(bodyOperands, types, parser.getCurrentLocation(), result.operands)) { + return failure(); + } + + // Op results have the same types as in-values. 
+ result.addTypes(types); + + if (parser.parseOptionalAttrDict(result.attributes)) { + return failure(); + } + + return success(); } LogicalResult DoWhileOp::verify() { diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index b4f031f..b1986d0 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Conversion) +add_subdirectory(IR) add_subdirectory(Translation) diff --git a/unittests/IR/CMakeLists.txt b/unittests/IR/CMakeLists.txt new file mode 100644 index 0000000..714631e --- /dev/null +++ b/unittests/IR/CMakeLists.txt @@ -0,0 +1,23 @@ +set(test_name "jeff-mlir-parse-test") + +if(NOT TARGET ${test_name}) + add_executable(${test_name} + test_parse_do_while_op.cpp + test_parse_for_op.cpp + test_parse_switch_op.cpp + test_parse_while_op.cpp + ) + + target_link_libraries(${test_name} PRIVATE + GTest::gtest_main + MLIRJeff + MLIRIR + MLIRParser + MLIRFuncDialect + MLIRSupport + ) + + set_target_properties(${test_name} PROPERTIES FOLDER unittests) + + gtest_discover_tests(${test_name} DISCOVERY_TIMEOUT 60) +endif() diff --git a/unittests/IR/test_parse_do_while_op.cpp b/unittests/IR/test_parse_do_while_op.cpp new file mode 100644 index 0000000..71a1ff8 --- /dev/null +++ b/unittests/IR/test_parse_do_while_op.cpp @@ -0,0 +1,197 @@ +#include "jeff/IR/JeffDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace mlir; + +class DoWhileOpTest : public ::testing::Test { + protected: + MLIRContext ctx; + ScopedDiagnosticHandler handler{&ctx, [](Diagnostic&) { return success(); }}; + + void SetUp() override { ctx.loadDialect(); } +}; + +// jeff SCF regions are isolated from above: every value used inside a region must come +// from a block argument or be computed locally. +// Tests with carried values pass the loop predicate through as an additional in-value. 
+ +// === Valid tests === + +TEST_F(DoWhileOpTest, NoArgs) { + const std::string src = R"MLIR( + func.func @f() { + jeff.doWhile args() { + } args() { + %c_pred = jeff.int_const1(true) : i1 + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(DoWhileOpTest, WithArgsSingle) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) -> i32 { + %r1, %r2 = jeff.doWhile : (i32, i1) args(%b_x = %a, %b_pred = %pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return %r1 : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(DoWhileOpTest, WithArgsMultiple) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) -> (i32, i64) { + %r1, %r2, %r3 = jeff.doWhile : (i32, i64, i1) args(%b_x = %a, %b_y = %b, %b_pred = %pred) { + jeff.yield %b_x, %b_y, %b_pred : i32, i64, i1 + } args(%c_x, %c_y, %c_pred) { + jeff.yield %c_pred : i1 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(DoWhileOpTest, Nested) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) -> i32 { + %r1, %r2 = jeff.doWhile : (i32, i1) args(%b_x = %a, %b_pred = %pred) { + %s1, %s2 = jeff.doWhile : (i32, i1) args(%bb = %b_x, %bbp = %b_pred) { + jeff.yield %bb, %bbp : i32, i1 + } args(%cc, %ccp) { + jeff.yield %ccp : i1 + } + jeff.yield %s1, %s2 : i32, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return %r1 : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// `DoWhileOp::print` elides the empty yield for the body when there are no in-values, +// but bare `jeff.yield` is still valid input. It can come from `YieldOp`'s own printer, +// generic-form output (`-mlir-print-op-generic`), or handwritten MLIR. +// The parser must accept this shape. 
+TEST_F(DoWhileOpTest, ExplicitEmptyBodyYield) { + const std::string src = R"MLIR( + func.func @f() { + jeff.doWhile args() { + jeff.yield + } args() { + %c_pred = jeff.int_const1(true) : i1 + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Parse → print → parse → print, then assert idempotent. +// The first round normalizes (whitespace, SSA names). +// The second round must be a no-op. +TEST_F(DoWhileOpTest, RoundTripIdempotent) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) -> (i32, i64) { + %r1, %r2, %r3 = jeff.doWhile : (i32, i64, i1) args(%b_x = %a, %b_y = %b, %b_pred = %pred) { + jeff.yield %b_x, %b_y, %b_pred : i32, i64, i1 + } args(%c_x, %c_y, %c_pred) { + jeff.yield %c_pred : i1 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + + const auto module1 = parseSourceString(src, &ctx); + ASSERT_TRUE(module1); + std::string printed1; + llvm::raw_string_ostream(printed1) << *module1; + + const auto module2 = parseSourceString(printed1, &ctx); + ASSERT_TRUE(module2); + std::string printed2; + llvm::raw_string_ostream(printed2) << *module2; + + EXPECT_EQ(printed1, printed2); +} + +// === Invalid syntax tests (parse-level) === + +TEST_F(DoWhileOpTest, InvalidMissingArgsKeyword) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.doWhile : (i32, i1) (%b_x = %a, %b_pred = %pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// In-values count and body args count differ. 
+TEST_F(DoWhileOpTest, InvalidArgCountMismatchWithTypes) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.doWhile : (i32, i1, i64) args(%b_x = %a, %b_pred = %pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Body args count and condition args count differ. +TEST_F(DoWhileOpTest, InvalidBodyCondArgCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) { + jeff.doWhile : (i32, i64, i1) args(%b_x = %a, %b_y = %b, %b_pred = %pred) { + jeff.yield %b_x, %b_y, %b_pred : i32, i64, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Missing in-value types with non-empty body args. +TEST_F(DoWhileOpTest, InvalidMissingTypeAnnotation) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.doWhile args(%b_x = %a, %b_pred = %pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } args(%c_x, %c_pred) { + jeff.yield %c_pred : i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} diff --git a/unittests/IR/test_parse_for_op.cpp b/unittests/IR/test_parse_for_op.cpp new file mode 100644 index 0000000..636d6b7 --- /dev/null +++ b/unittests/IR/test_parse_for_op.cpp @@ -0,0 +1,278 @@ +#include "jeff/IR/JeffDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace mlir; + +class ForOpTest : public ::testing::Test { + protected: + MLIRContext ctx; + ScopedDiagnosticHandler handler{&ctx, [](Diagnostic&) { return success(); }}; + + void SetUp() override { ctx.loadDialect(); } +}; + +// === Valid tests === + +TEST_F(ForOpTest, BasicFormI32) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i = %lo to %hi step %s : i32 { + jeff.yield + } + return + } + )MLIR"; + 
ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, BasicFormI64) { + const std::string src = R"MLIR( + func.func @f(%lo: i64, %hi: i64, %s: i64) { + jeff.for %i = %lo to %hi step %s : i64 { + jeff.yield + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Body has no explicit `jeff.yield`. +// `SingleBlockImplicitTerminator` should auto-insert one +// via `ForOp::ensureTerminator`. +TEST_F(ForOpTest, ImplicitYield) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i = %lo to %hi step %s : i32 {} + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, WithArgsSingle) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32, %init: i32) -> i32 { + %r = jeff.for %i = %lo to %hi step %s args(%acc = %init) -> (i32) : i32 { + jeff.yield %acc : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, WithArgsMultiple) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32, %a: i32, %b: i64) -> (i32, i64) { + %r1, %r2 = jeff.for %i = %lo to %hi step %s args(%x = %a, %y = %b) -> (i32, i64) : i32 { + jeff.yield %x, %y : i32, i64 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Inner `jeff.for` cannot see %lo, %hi, and %s from the enclosing function, +// so the outer `jeff.for` has to pass them in as args. 
+TEST_F(ForOpTest, Nested) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) -> (i32, i32, i32) { + %r1, %r2, %r3 = jeff.for %i = %lo to %hi step %s args(%lo_arg = %lo, %hi_arg = %hi, %s_arg = %s) -> (i32, i32, i32) : i32 { + jeff.for %j = %lo_arg to %hi_arg step %s_arg : i32 {} + jeff.yield %lo_arg, %hi_arg, %s_arg : i32, i32, i32 + } + return %r1, %r2, %r3 : i32, i32, i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// `ForOp::print` elides the empty yield, but bare `jeff.yield` is still valid input. +// It can come from `YieldOp`'s own printer, generic-form output (`-mlir-print-op-generic`), or +// handwritten MLIR. +// The parser must accept this shape. +TEST_F(ForOpTest, ExplicitEmptyYield) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i = %lo to %hi step %s : i32 { + jeff.yield + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Parse → print → parse → print, then assert idempotent. +// The first round normalizes (whitespace, SSA names). +// The second round must be a no-op. 
+TEST_F(ForOpTest, RoundTripIdempotent) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32, %init: i32) -> i32 { + %r = jeff.for %i = %lo to %hi step %s args(%acc = %init) -> (i32) : i32 { + jeff.yield %acc : i32 + } + return %r : i32 + } + )MLIR"; + + const auto module1 = parseSourceString(src, &ctx); + ASSERT_TRUE(module1); + std::string printed1; + llvm::raw_string_ostream(printed1) << *module1; + + const auto module2 = parseSourceString(printed1, &ctx); + ASSERT_TRUE(module2); + std::string printed2; + llvm::raw_string_ostream(printed2) << *module2; + + EXPECT_EQ(printed1, printed2); +} + +// === Invalid syntax tests (parse-level) === + +TEST_F(ForOpTest, InvalidMissingEquals) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i %lo to %hi step %s : i32 {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, InvalidMissingTo) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i = %lo %hi step %s : i32 {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, InvalidMissingType) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32) { + jeff.for %i = %lo to %hi step %s {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, InvalidArgsWithoutArrow) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32, %init: i32) { + jeff.for %i = %lo to %hi step %s args(%acc = %init) : i32 { + jeff.yield %acc : i32 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// 2 region args, 1 result type. +// Caught by the explicit size check in `parse`. 
+TEST_F(ForOpTest, InvalidArgCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i32, %x: i32, %y: i32) { + jeff.for %i = %lo to %hi step %s args(%a = %x, %b = %y) -> (i32) : i32 { + jeff.yield %a : i32 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// === Invalid semantics tests (parse-level) === + +// `stop` is declared with a different type than the loop's type annotation. +// The parser's `resolveOperand` catches the mismatch. +TEST_F(ForOpTest, InvalidStopTypeMismatch) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i64, %s: i32) { + jeff.for %i = %lo to %hi step %s : i32 {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// `step` is declared with a different type than the loop's type annotation. +// The parser's `resolveOperand` catches the mismatch. +TEST_F(ForOpTest, InvalidStepTypeMismatch) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i64) { + jeff.for %i = %lo to %hi step %s : i32 {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// === Invalid semantics tests (verify-level) === + +// `index` is rejected by SupportedIntType. +TEST_F(ForOpTest, InvalidIndexType) { + const std::string src = R"MLIR( + func.func @f(%lo: index, %hi: index, %s: index) { + jeff.for %i = %lo to %hi step %s : index {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Floating-point types are rejected by SupportedIntType. +TEST_F(ForOpTest, InvalidFloatType) { + const std::string src = R"MLIR( + func.func @f(%lo: f32, %hi: f32, %s: f32) { + jeff.for %i = %lo to %hi step %s : f32 {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Generic form bypasses the custom parser, letting start/stop/step have distinct types. +// The verifier check in `ForOp::verifyRegions` is the only thing that catches the mismatch. 
+TEST_F(ForOpTest, InvalidStopTypeMismatchGenericForm) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i64, %s: i32) { + "jeff.for"(%lo, %hi, %s) ({ + ^bb0(%i: i32): + "jeff.yield"() : () -> () + }) : (i32, i64, i32) -> () + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(ForOpTest, InvalidStepTypeMismatchGenericForm) { + const std::string src = R"MLIR( + func.func @f(%lo: i32, %hi: i32, %s: i64) { + "jeff.for"(%lo, %hi, %s) ({ + ^bb0(%i: i32): + "jeff.yield"() : () -> () + }) : (i32, i32, i64) -> () + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} diff --git a/unittests/IR/test_parse_switch_op.cpp b/unittests/IR/test_parse_switch_op.cpp new file mode 100644 index 0000000..1abc88c --- /dev/null +++ b/unittests/IR/test_parse_switch_op.cpp @@ -0,0 +1,291 @@ +#include "jeff/IR/JeffDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace mlir; + +class SwitchOpTest : public ::testing::Test { + protected: + MLIRContext ctx; + ScopedDiagnosticHandler handler{&ctx, [](Diagnostic&) { return success(); }}; + + void SetUp() override { ctx.loadDialect(); } +}; + +// === Valid tests === + +// Structurally legal: a switch with zero branches and no default. +// (See "Underspecified" in the op description.) 
+TEST_F(SwitchOpTest, ZeroCasesAndNoDefault) { + const std::string src = R"MLIR( + func.func @f(%sel: i32) { + jeff.switch (%sel) : (i32) -> () + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, NoInValuesNoResults) { + const std::string src = R"MLIR( + func.func @f(%sel: i32) { + jeff.switch (%sel) : (i32) -> () + case 0 args() {} + case 1 args() {} + default args() {} + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, WithInValuesAndResults) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32, %b: i64) -> (i32, i64) { + %r1, %r2 = jeff.switch (%sel, %a, %b) : (i32, i32, i64) -> (i32, i64) + case 0 args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + case 1 args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + default args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, NoDefault) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) -> i32 { + %r = jeff.switch (%sel, %a) : (i32, i32) -> (i32) + case 0 args(%x) { + jeff.yield %x : i32 + } + case 1 args(%x) { + jeff.yield %x : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, OnlyDefault) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) -> i32 { + %r = jeff.switch (%sel, %a) : (i32, i32) -> (i32) + default args(%x) { + jeff.yield %x : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// In-value types and result types are independent: the in-values are (i32, i64, i1), +// but the op yields a single i1. +// The yielded value comes from `%q`, which is the i1 in-value passed in. +// Regions are isolated from above, so `%p` from the function scope cannot be used directly inside +// the regions. 
+TEST_F(SwitchOpTest, DecoupledInValueAndResultTypes) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32, %b: i64, %p: i1) -> i1 { + %r = jeff.switch (%sel, %a, %b, %p) : (i32, i32, i64, i1) -> (i1) + case 0 args(%x, %y, %q) { + jeff.yield %q : i1 + } + default args(%x, %y, %q) { + jeff.yield %q : i1 + } + return %r : i1 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, Nested) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) -> i32 { + %r = jeff.switch (%sel, %a) : (i32, i32) -> (i32) + case 0 args(%x) { + %s = jeff.switch (%sel, %x) : (i32, i32) -> (i32) + case 0 args(%xx) { + jeff.yield %xx : i32 + } + default args(%xx) { + jeff.yield %xx : i32 + } + jeff.yield %s : i32 + } + default args(%x) { + jeff.yield %x : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Parse → print → parse → print, then assert idempotent. +TEST_F(SwitchOpTest, RoundTripIdempotent) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32, %b: i64) -> (i32, i64) { + %r1, %r2 = jeff.switch (%sel, %a, %b) : (i32, i32, i64) -> (i32, i64) + case 0 args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + case 1 args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + default args(%x, %y) { + jeff.yield %x, %y : i32, i64 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + + const auto module1 = parseSourceString(src, &ctx); + ASSERT_TRUE(module1); + std::string printed1; + llvm::raw_string_ostream(printed1) << *module1; + + const auto module2 = parseSourceString(printed1, &ctx); + ASSERT_TRUE(module2); + std::string printed2; + llvm::raw_string_ostream(printed2) << *module2; + + EXPECT_EQ(printed1, printed2); +} + +// === Invalid syntax tests (parse-level) === + +// Case labels must start at 0 and increase by 1. 
+TEST_F(SwitchOpTest, InvalidNonContiguousCaseLabels) { + const std::string src = R"MLIR( + func.func @f(%sel: i32) { + jeff.switch (%sel) : (i32) -> () + case 0 args() {} + case 2 args() {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, InvalidCaseLabelNotStartingAtZero) { + const std::string src = R"MLIR( + func.func @f(%sel: i32) { + jeff.switch (%sel) : (i32) -> () + case 1 args() {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, InvalidMissingArgsKeyword) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) { + jeff.switch (%sel, %a) : (i32, i32) -> () + case 0 (%x) {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Region arg count must match the number of in-values. +TEST_F(SwitchOpTest, InvalidRegionArgCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32, %b: i64) { + jeff.switch (%sel, %a, %b) : (i32, i32, i64) -> () + case 0 args(%x) {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Operand-type count must match operand count. +TEST_F(SwitchOpTest, InvalidOperandTypeCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) { + jeff.switch (%sel, %a) : (i32) -> () + case 0 args(%x) {} + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// The selector must be a `SupportedIntType` (i1/i8/i16/i32/i64). +// `index` is rejected by the ODS-generated verifier. 
+TEST_F(SwitchOpTest, InvalidSelectorTypeIndex) { + const std::string src = R"MLIR( + func.func @f(%sel: index) { + jeff.switch (%sel) : (index) -> () + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +TEST_F(SwitchOpTest, InvalidSelectorTypeFloat) { + const std::string src = R"MLIR( + func.func @f(%sel: f32) { + jeff.switch (%sel) : (f32) -> () + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// === Invalid verifier tests === + +// Yield operand count must match the op's result count. +TEST_F(SwitchOpTest, InvalidYieldCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32) -> i32 { + %r = jeff.switch (%sel, %a) : (i32, i32) -> (i32) + case 0 args(%x) { + jeff.yield + } + default args(%x) { + jeff.yield %x : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Yield operand type must match the op's result type. +TEST_F(SwitchOpTest, InvalidYieldTypeMismatch) { + const std::string src = R"MLIR( + func.func @f(%sel: i32, %a: i32, %b: i64) -> i32 { + %r = jeff.switch (%sel, %a, %b) : (i32, i32, i64) -> (i32) + case 0 args(%x, %y) { + jeff.yield %y : i64 + } + default args(%x, %y) { + jeff.yield %x : i32 + } + return %r : i32 + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} diff --git a/unittests/IR/test_parse_while_op.cpp b/unittests/IR/test_parse_while_op.cpp new file mode 100644 index 0000000..24a502d --- /dev/null +++ b/unittests/IR/test_parse_while_op.cpp @@ -0,0 +1,197 @@ +#include "jeff/IR/JeffDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace mlir; + +class WhileOpTest : public ::testing::Test { + protected: + MLIRContext ctx; + ScopedDiagnosticHandler handler{&ctx, [](Diagnostic&) { return success(); }}; + + void SetUp() override { ctx.loadDialect(); } +}; + +// jeff SCF regions are isolated from above: every value used inside a region must come +// from 
a block argument or be computed locally. +// Tests with carried values pass the loop predicate through as an additional in-value. + +// === Valid tests === + +TEST_F(WhileOpTest, NoArgs) { + const std::string src = R"MLIR( + func.func @f() { + jeff.while args() { + %c_pred = jeff.int_const1(true) : i1 + jeff.yield %c_pred : i1 + } args() { + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(WhileOpTest, WithArgsSingle) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) -> i32 { + %r1, %r2 = jeff.while : (i32, i1) args(%c_x = %a, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } + return %r1 : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(WhileOpTest, WithArgsMultiple) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) -> (i32, i64) { + %r1, %r2, %r3 = jeff.while : (i32, i64, i1) args(%c_x = %a, %c_y = %b, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_y, %b_pred) { + jeff.yield %b_x, %b_y, %b_pred : i32, i64, i1 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +TEST_F(WhileOpTest, Nested) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) -> i32 { + %r1, %r2 = jeff.while : (i32, i1) args(%c_x = %a, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + %s1, %s2 = jeff.while : (i32, i1) args(%cc = %b_x, %ccp = %b_pred) { + jeff.yield %ccp : i1 + } args(%bb, %bbp) { + jeff.yield %bb, %bbp : i32, i1 + } + jeff.yield %s1, %s2 : i32, i1 + } + return %r1 : i32 + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// `WhileOp::print` elides the empty yield for the body when there are no in-values, +// but bare `jeff.yield` is still valid input. It can come from `YieldOp`'s own printer, +// generic-form output (`-mlir-print-op-generic`), or handwritten MLIR. 
+// The parser must accept this shape. +TEST_F(WhileOpTest, ExplicitEmptyBodyYield) { + const std::string src = R"MLIR( + func.func @f() { + jeff.while args() { + %c_pred = jeff.int_const1(true) : i1 + jeff.yield %c_pred : i1 + } args() { + jeff.yield + } + return + } + )MLIR"; + ASSERT_TRUE(parseSourceString(src, &ctx)); +} + +// Parse → print → parse → print, then assert idempotent. +// The first round normalizes (whitespace, SSA names). +// The second round must be a no-op. +TEST_F(WhileOpTest, RoundTripIdempotent) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) -> (i32, i64) { + %r1, %r2, %r3 = jeff.while : (i32, i64, i1) args(%c_x = %a, %c_y = %b, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_y, %b_pred) { + jeff.yield %b_x, %b_y, %b_pred : i32, i64, i1 + } + return %r1, %r2 : i32, i64 + } + )MLIR"; + + const auto module1 = parseSourceString(src, &ctx); + ASSERT_TRUE(module1); + std::string printed1; + llvm::raw_string_ostream(printed1) << *module1; + + const auto module2 = parseSourceString(printed1, &ctx); + ASSERT_TRUE(module2); + std::string printed2; + llvm::raw_string_ostream(printed2) << *module2; + + EXPECT_EQ(printed1, printed2); +} + +// === Invalid syntax tests (parse-level) === + +TEST_F(WhileOpTest, InvalidMissingArgsKeyword) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.while : (i32, i1) (%c_x = %a, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// In-values count and condition args count differ. 
+TEST_F(WhileOpTest, InvalidArgCountMismatchWithTypes) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.while : (i32, i1, i64) args(%c_x = %a, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Condition args count and body args count differ. +TEST_F(WhileOpTest, InvalidCondBodyArgCountMismatch) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %b: i64, %pred: i1) { + jeff.while : (i32, i64, i1) args(%c_x = %a, %c_y = %b, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + jeff.yield %b_x, %b_pred : i64, i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +} + +// Missing in-value types with non-empty condition args. +TEST_F(WhileOpTest, InvalidMissingTypeAnnotation) { + const std::string src = R"MLIR( + func.func @f(%a: i32, %pred: i1) { + jeff.while args(%c_x = %a, %c_pred = %pred) { + jeff.yield %c_pred : i1 + } args(%b_x, %b_pred) { + jeff.yield %b_x, %b_pred : i32, i1 + } + return + } + )MLIR"; + ASSERT_FALSE(parseSourceString(src, &ctx)); +}