Separate wave control flow lowering into dedicated pass #646
base: main
Creates a new LowerWaveControlFlowPass that specifically handles the wave.iterate -> scf.for conversion, avoiding the issue where mixed lowering was creating unlegalizable IR.

The problem arises when LowerWaveToMLIRPass tries to do both:
1. Control flow lowering (wave.iterate -> scf.for)
2. Operation lowering (wave.read, wave.mma, etc. -> standard ops)

This causes `wave.iterate` bodies to be cloned with unconverted wave operations inside `scf.for` loops, creating IR that cannot be legalized.

The solution separates the pipeline into two phases:
1. LowerWaveControlFlowPass: wave.iterate -> scf.for (leaves other wave ops unchanged)
2. LowerWaveToMLIRPass: remaining wave ops -> standard MLIR ops

This enables a clean two-phase lowering pipeline where control flow is handled first, then the remaining wave operations are lowered in the standard `scf.for` context.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
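To make the two-phase idea in the description concrete, here is a minimal toy model in plain C++. It is only a sketch and does not use the MLIR API: the `Op` struct, the function names, and the `lowered.` prefix are all hypothetical stand-ins. It shows that after the control-flow phase the loop body still holds wave ops, and that they only disappear once the second phase runs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an IR operation: a name plus a nested body (region).
// This is NOT the MLIR API; every name here is hypothetical.
struct Op {
  std::string name;
  std::vector<Op> body;
};

// Returns true if `op`, or anything nested inside it, is a toy "wave." op.
bool containsWaveOp(const Op &op) {
  if (op.name.rfind("wave.", 0) == 0)
    return true;
  for (const Op &nested : op.body)
    if (containsWaveOp(nested))
      return true;
  return false;
}

// Phase 1: rewrite only wave.iterate into scf.for.
// Nested wave ops are deliberately left untouched.
void lowerControlFlow(Op &op) {
  if (op.name == "wave.iterate")
    op.name = "scf.for";
  for (Op &nested : op.body)
    lowerControlFlow(nested);
}

// Phase 2: rewrite the remaining wave ops.
// Assumes phase 1 already ran, so no wave.iterate may be left.
void lowerRemainingOps(Op &op) {
  assert(op.name != "wave.iterate" && "control flow must be lowered first");
  if (op.name.rfind("wave.", 0) == 0)
    op.name = "lowered." + op.name.substr(5);
  for (Op &nested : op.body)
    lowerRemainingOps(nested);
}
```

Calling `lowerControlFlow` and then `lowerRemainingOps` on a function whose body is a `wave.iterate` containing `wave.read` and `wave.mma` leaves an `scf.for` with no wave ops inside. The real passes operate on MLIR's conversion infrastructure, which this model does not attempt to reproduce.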
def LowerWaveToMLIRPass : Pass<"lower-wave-to-mlir"> {
  let summary = "Lower Wave dialect to upstream MLIR dialects";
  let summary = "Lower Wave dialect operations to upstream MLIR dialects";
You know control flow operations are also operations?
  let description = [{
    This pass lowers operations from the Wave dialect to upstream MLIR
    dialects, notably arith and vector.
    This pass lowers remaining Wave dialect operations to upstream MLIR
It is unclear what "remaining" means here.
target.addLegalDialect<
    // clang-format off
    arith::ArithDialect,
    func::FuncDialect,
Is this needed?
void getDependentDialects(DialectRegistry &registry) const override {
  registry.insert<arith::ArithDialect, scf::SCFDialect>();
}
You already listed a different set of dialects in the tablegen definition. This is a huge footgun!
@@ -0,0 +1,148 @@
// RUN: water-opt %s -allow-unregistered-dialect -lower-wave-control-flow --mlir-print-local-scope --split-input-file --verify-diagnostics | FileCheck %s
Do we need unregistered dialect? -verify-diagnostics? Tests should be minimal.
    return
  }
}
No newline at end of file
Missing trailing newline. If you run pre-commit, these should be fixed for you.
// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, <register>>
Why is this test changed here?
// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
Ditto.
I don't quite understand the problem from the description. We clone the body with operations that weren't converted; why are they not converted later on in the same pass? Is the infra somehow not revisiting them?