Decompose memrefs early #647
base: main
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
This reverts commit 81d9747. Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
This reverts commit 548b353. Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Pull request overview
This PR introduces an early memref decomposition pass that explicitly materializes index calculations, LLVM GEPs, and loads/stores. This enables better integer range and uniformity analyses by exposing index arithmetic early in the compilation pipeline. The changes demonstrate a tangible benefit with VGPR count reduction from 160 to 140 in the mxfp gemm test case.
Key changes:
- Adds `water-memref-decomposition` pass that converts multi-dimensional memref operations to explicit pointer arithmetic with affine maps
- Integrates integer narrowing pass (`arith-int-range-narrowing`) into the compilation pipeline to leverage exposed index calculations
- Reorders pipeline to run memref decomposition before affine lowering, enabling better optimization opportunities
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| wave_lang/kernel/wave/water.py | Updates compilation pipeline to add memref decomposition before affine lowering and integrates integer range optimizations |
| water/tools/water-opt/water-opt.cpp | Registers the arithmetic integer range narrowing pass |
| water/test/Transforms/memref-decomposition.mlir | Comprehensive test suite covering load/store, vector operations, reinterpret_cast, and AMD GPU-specific operations |
| water/lib/Transforms/MemrefDecomposition.cpp | Core implementation of memref decomposition pass with type converter and pattern rewriters |
| water/lib/Transforms/CMakeLists.txt | Adds new source file and required dependencies (AMDGPU, SCFTransforms) |
| water/include/water/Transforms/Passes.td | Defines the new pass with documentation and dialect dependencies |
| tests/kernel/wave_gemm_mxfp_test.py | Updates expected VGPR count from 160 to 140 and adjusts waitcount expectations reflecting optimization impact |
    using namespace mlir;

    namespace {
Nit: there is no point in having static functions inside an anonymous namespace. LLVM style says to prefer static functions and only use namespaces for classes.
    namespace {

    static Value getValue(OpBuilder &rewriter, Location loc, OpFoldResult in) {
Please add documentation to all top-level entities.
    static SmallVector<Value> getValues(OpBuilder &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> in) {
      SmallVector<Value> result;
Nit: reserve before pushing back in a loop.
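The nit above can be illustrated with a minimal, self-contained sketch. It uses `std::vector` as a stand-in for `SmallVector`, and the `toStrings` helper is hypothetical (it is not part of the PR); the point is only that the output size is known up front, so a single `reserve` call avoids repeated reallocation inside the loop.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for a helper like `getValues`: maps every input
// element to an output element, so the result size is known in advance.
inline std::vector<std::string> toStrings(const std::vector<int> &in) {
  std::vector<std::string> result;
  // Allocate once instead of growing the buffer while pushing back.
  result.reserve(in.size());
  for (int value : in)
    result.push_back(std::to_string(value));
  return result;
}
```

`SmallVector` also has a `reserve` member, so the same shape should carry over to the helper under review.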
    static SmallVector<Value> flatten(ArrayRef<ValueRange> values) {
      SmallVector<Value> result;
      for (ValueRange value : values)
        llvm::append_range(result, value);

      return result;
    }
Would llvm::concat<Value> work instead? It avoids allocation/copy.
    static std::tuple<LogicalResult, Value, SmallVector<OpFoldResult>,
                      SmallVector<OpFoldResult>>
It's strange to have LogicalResult as part of the tuple instead of FailureOr. I'd consider having vectors as SmallVectorImpl & operands and using null Value as error marker. Multi-element tuples tend to be unreadable at callsites.
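To make the readability argument concrete, here is a rough sketch in plain C++ of one alternative shape: a named struct wrapped in `std::optional`, which plays the role `FailureOr` would play in MLIR code. The `DecomposedMemref` struct, its fields, and the `decompose` function are assumptions made for illustration and do not come from the PR; the reviewer's own suggestion (output `SmallVectorImpl &` parameters plus a null `Value`) is a different but related option.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical result type: named fields instead of tuple positions.
struct DecomposedMemref {
  std::int64_t base;
  std::vector<std::int64_t> sizes;
  std::vector<std::int64_t> strides;
};

// Returns std::nullopt on failure, so the error state cannot be separated
// from the payload the way a LogicalResult inside a tuple can.
inline std::optional<DecomposedMemref>
decompose(const std::vector<std::int64_t> &shape) {
  for (std::int64_t dim : shape)
    if (dim < 0) // This toy version rejects unknown (negative) dimensions.
      return std::nullopt;

  DecomposedMemref result{0, shape,
                          std::vector<std::int64_t>(shape.size(), 1)};
  // Row-major strides: the innermost stride is 1, each outer stride is the
  // product of all inner sizes.
  for (std::size_t i = shape.size(); i-- > 1;)
    result.strides[i - 1] = result.strides[i] * shape[i];
  return result;
}
```

At the call site this reads as `if (auto d = decompose(shape)) use(d->strides);` rather than unpacking four positional tuple elements.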
    }

    /// Generate a GEP op with the given buffer and byte offset.
    static Value GEP(OpBuilder &builder, Location loc, Value buffer, Value offset) {
Nit: createGEP would not break the naming style guide, and it is generally nicer to indicate when some IR may be created.
    /// adjusted pointer.
    static Value getFlattenMemref(OpBuilder &rewriter, Location loc, Value source,
                                  Type loadType, ArrayRef<OpFoldResult> sizes,
                                  unsigned typeBit, ArrayRef<OpFoldResult> strides,
What is typeBit? Is it the bit width of the element type? If so, it should be named accordingly. Why in bits? What happens if it is not divisible by 8?
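The divisibility concern can be shown with plain integer arithmetic. The sketch below is only an illustration: the function names are hypothetical, and rounding up is shown as one possible convention, not as the fix the PR should adopt (packed sub-byte element types would likely need different handling altogether).

```cpp
#include <cassert>
#include <cstdint>

// Truncating conversion, the same arithmetic as `typeBit / 8`.
constexpr unsigned bytesTruncating(unsigned elementBitWidth) {
  return elementBitWidth / 8;
}

// Hypothetical alternative: round up to whole bytes.
constexpr unsigned bytesRoundedUp(unsigned elementBitWidth) {
  return (elementBitWidth + 7) / 8;
}

// Byte offset of element `index`; only well defined for byte-sized widths,
// so other widths are rejected explicitly instead of silently truncated.
inline std::int64_t elementByteOffset(std::int64_t index,
                                      unsigned elementBitWidth) {
  assert(elementBitWidth % 8 == 0 && "width is not a whole number of bytes");
  return index * static_cast<std::int64_t>(elementBitWidth / 8);
}
```

With truncation, a 4-bit element yields a size of 0 bytes, so every index would map to the same address; a 12-bit element yields 1 byte instead of the 2 bytes a rounded-up size would give.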
        zero, sizes, strides,
        getAsOpFoldResult(indices));

      AffineExpr mul = rewriter.getAffineSymbolExpr(0) * (typeBit / 8);
Nit: can this use some more descriptive name?
      unsigned alignment = loadOp.getAlignment().value_or(0);
      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, loadType, ptr, alignment,
                                                /*volatile_*/ false,
Nit: can't we just carry over the volatile flag?
      unsigned alignment = storeOp.getAlignment().value_or(0);
      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, valueToStore, ptr,
                                                 alignment, /*volatile_*/ false,
Ditto
Decompose memrefs early, materializing index calculations, LLVM GEPs, and LLVM loads/stores explicitly. This makes index calculations more amenable to integer range and (future) uniformity analyses. Add an integer narrowing pass to the pipeline to see the reduced VGPR count on mxfp gemm.
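Some operations with non-trivial lowering (buffer casts and `amdgpu.gather-to-lds`) are kept in memref land but converted to 0D memrefs, to expose indexing as well.

This pass is basically an alternative memref-to-llvm lowering, and it is compatible with the existing upstream one: for operations it does not lower, it will generate `MemrefDescriptor` shims.

- [...] `MemrefDescriptor`s, hiding all index calculations.
- [...] `* sizeof(T)` part of index calculations to be materialized explicitly. Early POC was using `memref<?xi8>` and `memref.view`, but it required a lot of memref casts back and forth.
- [...] `ptr` dialect? I had a `ptr` dialect POC as well, but the `ptr` dialect is currently incomplete, and at this point it is a carbon copy of the `llvm` dialect and provides no useful abstraction.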
MemrefDescriptorshims.MemrefDescriptors, hiding all index calculations.* sizeof(T)part of index calculations to be materialized explicitly. Early POC was usingmemref<?xi8>andmemref.viewbut it required a lot of memref casts back and forth.ptrdialect? I had aptrdialect POC as well butptrdialect is currently incomplete and also at this point it is a carbon copy ofllvmdialect and provides no useful abstraction.