@tyb0807 tyb0807 commented Dec 28, 2025

Creates a new LowerWaveControlFlowPass that specifically handles wave.iterate -> scf.for conversion, avoiding the issue where mixed lowering was creating unlegalizable IR.

The problem arises when LowerWaveToMLIRPass tries to do both of the following in a single pass:

  1. Control flow lowering (wave.iterate -> scf.for)
  2. Operation lowering (wave.read, wave.mma, etc. -> standard ops)

In that case, wave.iterate bodies are cloned into scf.for loops with unconverted wave operations still inside, creating IR that cannot be legalized.

The solution separates the pipeline into two phases:

  1. LowerWaveControlFlowPass: wave.iterate -> scf.for (leaves other wave ops unchanged)
  2. LowerWaveToMLIRPass: Remaining wave ops -> standard MLIR ops

This enables a clean two-phase lowering pipeline where control flow is handled first, then remaining wave operations are lowered in the standard scf.for context.
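
The two phases above can be illustrated with a toy model. This is only a sketch under stated assumptions: it uses plain C++ with no MLIR APIs, the `Op` struct and the helper functions are hypothetical, and the `lowered.*` names are placeholders rather than the ops the real pass emits.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR operation: a name plus a nested body.
struct Op {
  std::string name;
  std::vector<Op> body;
};

// Phase 1 (models LowerWaveControlFlowPass): rewrite only the control-flow
// op and leave every other wave op unchanged.
void lowerControlFlow(Op &op) {
  if (op.name == "wave.iterate")
    op.name = "scf.for";
  for (Op &nested : op.body)
    lowerControlFlow(nested);
}

// Phase 2 (models LowerWaveToMLIRPass): rewrite the remaining wave ops.
// The "lowered.*" targets are placeholders, not the real lowering results.
void lowerRemainingOps(Op &op) {
  if (op.name == "wave.read")
    op.name = "lowered.read";
  else if (op.name == "wave.mma")
    op.name = "lowered.mma";
  for (Op &nested : op.body)
    lowerRemainingOps(nested);
}

// Returns true if any op in the tree still belongs to the wave dialect.
bool hasWaveOps(const Op &op) {
  if (op.name.rfind("wave.", 0) == 0)
    return true;
  for (const Op &nested : op.body)
    if (hasWaveOps(nested))
      return true;
  return false;
}
```

After phase 1 the loop is an scf.for whose body still holds wave ops; only after phase 2 does the toy IR contain no wave ops at all.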

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
@tyb0807 tyb0807 requested review from ftynse and tgymnich December 28, 2025 10:04

 def LowerWaveToMLIRPass : Pass<"lower-wave-to-mlir"> {
-  let summary = "Lower Wave dialect to upstream MLIR dialects";
+  let summary = "Lower Wave dialect operations to upstream MLIR dialects";

You know control flow operations are also operations?

   let description = [{
-    This pass lowers operations from the Wave dialect to upstream MLIR
-    dialects, notably arith and vector.
+    This pass lowers remaining Wave dialect operations to upstream MLIR

It is unclear what "remaining" means here.

target.addLegalDialect<
// clang-format off
arith::ArithDialect,
func::FuncDialect,

Is this needed?

Comment on lines +182 to +185
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, scf::SCFDialect>();
}


You already listed a different set of dialects in the tablegen definition. This is a huge footgun!

@@ -0,0 +1,148 @@
// RUN: water-opt %s -allow-unregistered-dialect -lower-wave-control-flow --mlir-print-local-scope --split-input-file --verify-diagnostics | FileCheck %s

Do we need unregistered dialect? -verify-diagnostics? Tests should be minimal.


return
}
}
\ No newline at end of file

Trailing newline. If you run pre-commit, these should be fixed for you.

Comment on lines +231 to +234
// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, <register>>

Why is this test changed here?

Comment on lines +257 to +259
// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>

Ditto.

ftynse commented Dec 29, 2025

I don't quite understand the problem from the description. We clone the body with operations that weren't converted; why are they not converted later on in the same pass? Is the infra somehow not revisiting them?
