[Midend] Add INT8 support to AMX matmul pass with unified dtype detection#633
Jiajun-Ji wants to merge 13 commits into buddy-compiler:main
Properly pack B matrix into VNNI format with interleaved rows for correct AMX BF16 tile operations.
Enable hardware-accelerated matrix multiplication for DeepSeek-R1 using Intel AMX instructions with proper VNNI layout conversion.
This directory was added accidentally and should not be part of the project.
…nalysis in AMX matmul test for precision loss comparison
… dtype detection.
Pull request overview
This PR extends the AMX matmul optimization pass to support INT8 data types alongside BF16, creating a unified -matmul-amx pass that automatically detects input data types and applies appropriate AMX tile operations with VNNI packing.
Key changes:
- Added automatic dtype detection (bf16×bf16→f32 or i8×i8→i32) in the AMX matmul pass
- Implemented INT8-specific VNNI packing and dimension validation for AMX operations
- Extended DeepSeekR1 example with BF16 support including prefill/decode phases
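To make the VNNI packing mentioned in the key changes concrete, here is a minimal host-side sketch in plain C++ rather than the MLIR the pass actually emits. The helper name `packVnni` and its `group` parameter are hypothetical. The sketch assumes the usual AMX convention that each 32-bit lane holds consecutive K elements of one column: two for BF16 (consistent with the `k / 2` and `n * 2` index computations in the diff) and four for INT8.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not taken from the PR.
// Packs a row-major K x N matrix B into a VNNI-style layout.
// `group` is the number of consecutive K elements stored side by side:
// 2 for 16-bit BF16 data, 4 for 8-bit INT8 data (each group fills 32 bits).
// The result is row-major with shape (K / group) x (N * group), where
//   packed[k / group][n * group + k % group] = B[k][n].
// Assumes K is a multiple of `group`; returns an empty vector otherwise.
template <typename T>
std::vector<T> packVnni(const std::vector<T> &b, std::size_t K, std::size_t N,
                        std::size_t group) {
  if (group == 0 || K % group != 0 || b.size() != K * N)
    return {};
  std::vector<T> packed(K * N);
  const std::size_t packedCols = N * group;
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t n = 0; n < N; ++n)
      packed[(k / group) * packedCols + n * group + (k % group)] =
          b[k * N + n];
  return packed;
}
```

For a 4×2 matrix with rows `[0,1]`, `[2,3]`, `[4,5]`, `[6,7]`, a group of 2 yields the 2×4 layout `[0,2,1,3]`, `[4,6,5,7]`, while a group of 4 yields the single row `[0,2,4,6,1,3,5,7]`.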
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| riscv-ime-extension-spec | Added subproject commit reference |
| midend/lib/Conversion/MatMulOptimization/MatmulAMX.cpp | Unified AMX pass with BF16/INT8 dtype detection, separate VNNI packing functions, and dimension validators |
| frontend/Python/ops/tosa.py | Added bf16 support in addmm_op and fixed attribute access in scaled_dot_product |
| examples/BuddyMatmul/test-amx-int8-pass.mlir | New test file demonstrating INT8 AMX conversion via unified pass |
| examples/BuddyMatmul/test-amx-bf16-pass.mlir | New test file demonstrating BF16 AMX conversion via unified pass |
| examples/BuddyMatmul/makefile | Updated build targets for INT8/BF16 testing and added missing linker flag |
| examples/BuddyMatmul/linalg-to-amx-example.mlir | Removed deprecated example file |
| examples/BuddyMatmul/linalg-bf16-matmul.mlir | Removed deprecated test file |
| examples/BuddyMatmul/amx-wrapper.c | Removed redundant AMX_DATA permission request |
| examples/BuddyMatmul/amx-int8-matmul.mlir | Added comprehensive INT8 AMX matmul example with VNNI packing |
| examples/BuddyMatmul/amx-bf16-matmul.mlir | Enhanced BF16 example with cyclic initialization and VNNI packing |
| examples/BuddyDeepSeekR1/import-deepseek-r1.py | Extended import script for BF16 with prefill/decode phases |
| examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp | Implemented BF16 inference with AMX support, KV cache management, and prefill/decode separation |
| examples/BuddyDeepSeekR1/CMakeLists.txt | Updated build pipeline for BF16 with AMX flags and separate prefill/decode compilation |
| matmul_op = tosa.ClampOp( | ||
| matmul_op.result.type, | ||
| matmul_op.type, |
Changed from matmul_op.result.type to matmul_op.type. Verify that matmul_op is a Value with a .type attribute, not an Operation with a .result.type attribute. If matmul_op is an operation, this will fail at runtime.
| constexpr size_t ParamsSize = 1777088064; | ||
| constexpr size_t MaxVocabSize = 151936; | ||
| constexpr size_t MaxTokenLength = 40; | ||
| constexpr size_t MaxTokenLength = 32; |
MaxTokenLength was reduced from 40 to 32 without a corresponding update to comments or documentation. Ensure the new value matches the model's cache configuration (lines 121 and 124 reference max_cache_len=32) and update any remaining references to the old value of 40.
| const std::string vocabDir = deepSeekR1Dir + "vocab.txt"; | ||
| const std::string paramsDir = deepSeekR1BuildDir + "arg0-bf16.data"; |
The forward-slash separator between directory and filename was removed. Path concatenation will produce a wrong path if deepSeekR1Dir or deepSeekR1BuildDir does not end with a trailing slash. The compile definitions at lines 685-686 now add trailing slashes, but this creates a fragile dependency.
| const std::string vocabDir = deepSeekR1Dir + "vocab.txt"; | |
| const std::string paramsDir = deepSeekR1BuildDir + "arg0-bf16.data"; | |
| const std::string vocabDir = | |
| (std::filesystem::path(deepSeekR1Dir) / "vocab.txt").string(); | |
| const std::string paramsDir = | |
| (std::filesystem::path(deepSeekR1BuildDir) / "arg0-bf16.data").string(); |
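As a standalone illustration of why the suggested `std::filesystem::path` form is less fragile, the sketch below wraps the join in a hypothetical helper `joinPath` (not part of the PR). It relies on `operator/` inserting a separator only when the left-hand side needs one, and uses `generic_string()` so the result uses forward slashes on every platform.

```cpp
#include <filesystem>
#include <string>

// Hypothetical helper: join a directory and a file name regardless of
// whether `dir` carries a trailing separator.
std::string joinPath(const std::string &dir, const std::string &file) {
  // operator/ adds a separator only when `dir` does not already end in one.
  return (std::filesystem::path(dir) / file).generic_string();
}
```

With this helper, `joinPath("models", "vocab.txt")` and `joinPath("models/", "vocab.txt")` produce the same path. One caveat: if `file` is itself an absolute path, `operator/` replaces the left-hand side instead of appending.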
| SmallVector<Value> dynamicSizes; | ||
| SmallVector<int64_t> staticShape; | ||
| Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0); | ||
| Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1); | ||
| Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2); | ||
The constant c2 is created but c1 is already defined at line 171. Consider defining c2 alongside other constants at the beginning of the function for consistency with the existing pattern of defining c0, c1 together.
| SmallVector<Value> dynamicSizes; | |
| SmallVector<int64_t> staticShape; | |
| Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0); | |
| Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1); | |
| Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2); | |
| Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0); | |
| Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1); | |
| Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2); | |
| SmallVector<Value> dynamicSizes; | |
| SmallVector<int64_t> staticShape; |
| config=model.config, max_cache_len=32 | ||
| ) | ||
| past_key_values_decode = StaticCache( | ||
| config=model.config, max_cache_len=20 | ||
| config=model.config, max_cache_len=32 |
The max_cache_len is hardcoded to 32 (changed from 20). Consider extracting this as a constant or configuration parameter to avoid inconsistencies with MaxTokenLength in the C++ code.
| @${BUDDY_OPT} ./test-amx-bf16-pass.mlir \ | ||
| -matmul-amx \ |
Target name 'linalg-bf16-matmul-lower' now operates on 'test-amx-bf16-pass.mlir' instead of 'linalg-bf16-matmul.mlir', which may confuse users expecting the target name to match the input file pattern.
| maxIdx = i; | ||
| /// Find the index of the max value in BF16 array | ||
| int findMaxIndex(const uint16_t *start, const uint16_t *end) { | ||
| size_t size = end - start; |
The function signature changed from findMaxIndex(const uint16_t *start, size_t length) to a pointer range (start, end). This is a valid design, but the function now computes size by pointer subtraction without validating that end >= start, so an incorrect call would produce a huge unsigned size.
| size_t size = end - start; | |
| // Validate pointer range to avoid undefined behavior from invalid ranges. | |
| // Require a non-null, non-empty range with end strictly after start. | |
| if (start == nullptr || end == nullptr || end <= start) { | |
| return -1; | |
| } | |
| size_t size = static_cast<size_t>(end - start); |
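For context, a complete version of such a function might look like the sketch below. Only the signature and the range check come from the diff and the suggestion. The loop body and the hypothetical `bf16ToFloat` helper are assumptions about how BF16 values would be compared, based on BF16 being the upper 16 bits of an IEEE-754 binary32 value.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical helper: widen a raw BF16 bit pattern to float by placing it
// in the upper 16 bits of a 32-bit word.
static float bf16ToFloat(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float value;
  std::memcpy(&value, &wide, sizeof(value));
  return value;
}

// Return the index of the largest BF16 value in [start, end), or -1 when the
// range is null or empty. The first maximum wins on ties.
int findMaxIndex(const uint16_t *start, const uint16_t *end) {
  if (start == nullptr || end == nullptr || end <= start)
    return -1;
  const std::size_t size = static_cast<std::size_t>(end - start);
  std::size_t maxIdx = 0;
  float maxVal = bf16ToFloat(start[0]);
  for (std::size_t i = 1; i < size; ++i) {
    const float v = bf16ToFloat(start[i]);
    if (v > maxVal) {
      maxVal = v;
      maxIdx = i;
    }
  }
  return static_cast<int>(maxIdx);
}
```

Decoding before comparing matters for negative values: the bit pattern `0xBF80` (-1.0) is numerically larger than `0x4000` (2.0) when compared as raw `uint16_t`, so a raw integer comparison would pick the wrong index.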
| Value c2Val = | ||
| builder.create<arith::ConstantIndexOp>(loc, 2); | ||
| Value kDiv2 = builder.create<arith::DivUIOp>(loc, k, c2Val); | ||
| Value nMul2 = builder.create<arith::MulIOp>(loc, n, c2Val); |
This creates a new constant c2Val inside the innermost loop. Hoist the constant creation outside the loop structure so the same constant is not re-created in each iteration.
This PR extends the AMX matmul optimization pass to support INT8 data types alongside BF16.
The unified -matmul-amx pass now automatically detects input data types (bf16×bf16→f32 or i8×i8→i32) and applies the appropriate AMX tile operations with VNNI packing.
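The detection rule described above can be sketched in plain C++ as follows. The enums `Elem` and `AmxKind` and the function `detectAmxKind` are hypothetical stand-ins for illustration only; the real pass inspects MLIR element types, and its actual structure may differ.

```cpp
#include <optional>

// Hypothetical stand-ins for the element types the pass would inspect.
enum class Elem { BF16, F32, I8, I32, Other };
enum class AmxKind { BF16, INT8 };

// Map the element types of A, B, and the accumulator C to an AMX path.
// Returns std::nullopt for any combination other than
// bf16 x bf16 -> f32 or i8 x i8 -> i32.
std::optional<AmxKind> detectAmxKind(Elem a, Elem b, Elem c) {
  if (a == Elem::BF16 && b == Elem::BF16 && c == Elem::F32)
    return AmxKind::BF16;
  if (a == Elem::I8 && b == Elem::I8 && c == Elem::I32)
    return AmxKind::INT8;
  return std::nullopt;
}
```

Returning an empty optional for unsupported combinations lets a caller leave the original matmul untouched instead of failing.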